Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import re
- from os import path
class DaoEnvironment:
    """Generate Spring Data boilerplate for JPA entity source files.

    For a model ``X.java`` three files are produced: ``XRepository.java`` in
    ``rep_path`` (package ``rep_package``), the ``XService`` interface in
    ``int_path`` (package ``int_package``) and the ``XService`` @Service class
    in ``impl_path`` (package ``impl_package``). Files that already exist are
    never overwritten.
    """

    # Java generics cannot take primitive type arguments, so an @Id field
    # declared as e.g. ``long`` must appear as ``Long`` in CrudRepository<X, Long>.
    _BOXED_TYPES = {
        "boolean": "Boolean",
        "byte": "Byte",
        "char": "Character",
        "short": "Short",
        "int": "Integer",
        "long": "Long",
        "float": "Float",
        "double": "Double",
    }
    _JAVA_IDENTIFIER_RE = re.compile(r"[A-Za-z_$][\w$]*")
    _PACKAGE_RE = re.compile(r"^\s*package\s+([\w.]+)\s*;")
    # Static imports are excluded: they never name a type usable as an id.
    _IMPORT_RE = re.compile(r"^\s*import\s+(?!static\s)([\w.$]+)\s*;")
    # Matches @Id or a qualified @javax.persistence.Id, but not @IdClass.
    _ID_ANNOTATION_RE = re.compile(r"@(?:[\w.]+\.)?Id\b")
    # "<Type> <name>;" optionally with an initializer; group 1 is the type.
    _FIELD_RE = re.compile(r"([\w$.]+)\s+([\w$]+)\s*(?:=[^;]*)?;")

    def __init__(self, rep_path, rep_package, int_path, int_package, impl_path, impl_package):
        self.rep_path = rep_path
        self.rep_package = rep_package
        self.int_path = int_path
        self.int_package = int_package
        self.impl_path = impl_path
        self.impl_package = impl_package

    @staticmethod
    def _write_new_file(file_path, lines, kind, model_name):
        """Create *file_path* containing *lines*; do nothing if it already exists.

        Exclusive-create mode ("x") guarantees an existing, possibly hand-edited,
        file is never clobbered even if it appears between the check and the write.
        """
        if path.exists(file_path):
            return
        try:
            with open(file_path, "xt", encoding="utf-8") as output:
                output.write("\n".join(lines) + "\n")
        except FileExistsError:
            return
        print("Generated {0} file for".format(kind), model_name)

    def generate_rep(self, model_name, model_import, id_class, id_import, paging):
        """Write ``<model_name>Repository.java`` into ``rep_path``.

        :param model_import: fully qualified name of the entity class.
        :param id_class: type argument used for the entity id.
        :param id_import: fully qualified id type to import, or None if the
            type needs no import.
        :param paging: extend PagingAndSortingRepository instead of CrudRepository.
        """
        base = "PagingAndSortingRepository" if paging else "CrudRepository"
        lines = [
            "package {0};".format(self.rep_package),
            "",
            "import {0};".format(model_import),
            "import org.springframework.data.repository.{0};".format(base),
            "import org.springframework.stereotype.Repository;",
        ]
        if id_import is not None:
            lines.append("import {0};".format(id_import))
        lines += [
            "",
            "@Repository",
            "public interface {0}Repository extends {1}<{0}, {2}> {{".format(model_name, base, id_class),
            "",
            "}",
        ]
        self._write_new_file(path.join(self.rep_path, model_name + "Repository.java"),
                             lines, "repository", model_name)

    def generate_int(self, model_name, model_import):
        """Write an empty ``<model_name>Service`` interface into ``int_path``."""
        lines = [
            "package {0};".format(self.int_package),
            "",
            "import {0};".format(model_import),
            "",
            "public interface {0}Service {{".format(model_name),
            "",
            "}",
        ]
        self._write_new_file(path.join(self.int_path, model_name + "Service.java"),
                             lines, "interface", model_name)

    def generate_impl(self, model_name, model_import, log_generation):
        """Write the ``<model_name>Service`` @Service class into ``impl_path``.

        The class has the same simple name as the interface and implements it
        through its fully qualified name. The repository is injected through an
        @Autowired constructor. With *log_generation* a static SLF4J logger
        field is added.
        """
        # NOTE(review): interface and implementation share the file name
        # <Model>Service.java, so int_path and impl_path must differ or the
        # implementation is skipped as "already existing".
        lines = [
            "package {0};".format(self.impl_package),
            "",
            "import {0};".format(model_import),
            "import {0}.{1}Repository;".format(self.rep_package, model_name),
        ]
        if log_generation:
            lines += ["import org.slf4j.Logger;", "import org.slf4j.LoggerFactory;"]
        lines += [
            "import org.springframework.beans.factory.annotation.Autowired;",
            "import org.springframework.stereotype.Service;",
            "",
            "@Service",
            "public class {0}Service implements {1}.{0}Service {{".format(model_name, self.int_package),
        ]
        if log_generation:
            lines.append("\tprivate static final Logger log = LoggerFactory.getLogger({0}Service.class);"
                         .format(model_name))
        lines += [
            "\tprivate {0}Repository repository;".format(model_name),
            "",
            "\t@Autowired",
            "\tpublic {0}Service({0}Repository repository) {{".format(model_name),
            "\t\tthis.repository = repository;",
            "\t}",
            "}",
        ]
        self._write_new_file(path.join(self.impl_path, model_name + "Service.java"),
                             lines, "impl", model_name)

    def _model_name(self, model_path):
        """Return the class name encoded in *model_path*'s file name.

        Both "/" and "\\" are accepted as separators regardless of the host OS,
        and only the last path component is used (directories may contain dots).

        :raises ValueError: if the file name does not start with a Java identifier.
        """
        file_name = re.split(r"[\\/]", model_path)[-1]
        model_name = file_name.partition(".")[0]
        if not self._JAVA_IDENTIFIER_RE.fullmatch(model_name):
            raise ValueError("cannot derive a Java class name from {0!r}".format(model_path))
        return model_name

    def _parse_model(self, model_path, model_name):
        """Scan the entity source for its package, imports and @Id field type.

        :returns: ``(model_import, id_class, id_import)``. ``id_class`` is boxed
            if it was a primitive. ``id_import`` is None when the id type is not
            explicitly imported.
        :raises ValueError: if no package declaration or no @Id field is found.
        """
        imports = {}
        package = None
        id_class = None
        after_id = False
        # utf-8-sig strips a BOM; errors="replace" keeps odd bytes in comments
        # from aborting the scan -- only ASCII Java tokens are inspected.
        with open(model_path, "rt", encoding="utf-8-sig", errors="replace") as source:
            for line in source:
                if package is None:
                    match = self._PACKAGE_RE.match(line)
                    if match:
                        package = match.group(1)
                        continue
                match = self._IMPORT_RE.match(line)
                if match:
                    qualified = match.group(1)
                    imports[qualified.rpartition(".")[2]] = qualified
                    continue
                if not after_id:
                    match = self._ID_ANNOTATION_RE.search(line)
                    if not match:
                        continue
                    after_id = True
                    # The field may sit on the same line: "@Id private Long id;".
                    line = line[match.end():]
                match = self._FIELD_RE.search(line)
                if match:
                    id_class = match.group(1)
                    break
        if package is None:
            raise ValueError("no package declaration found in {0}".format(model_path))
        if id_class is None:
            raise ValueError("no @Id field found in {0}".format(model_path))
        id_class = self._BOXED_TYPES.get(id_class, id_class)
        return package + "." + model_name, id_class, imports.get(id_class)

    def generate(self, model_path: str, log_generation, paging):
        """Generate repository, service interface and service impl for one entity.

        :param model_path: path to the entity's ``.java`` source file.
        :param log_generation: add an SLF4J logger field to the service impl.
        :param paging: base the repository on PagingAndSortingRepository.
        :raises ValueError: if the class name, package or @Id field cannot be
            determined; no files are written in that case.
        """
        model_name = self._model_name(model_path)
        print(model_name)
        model_import, id_class, id_import = self._parse_model(model_path, model_name)
        print(id_class, id_import, model_import)
        self.generate_rep(model_name, model_import, id_class, id_import, paging)
        self.generate_int(model_name, model_import)
        self.generate_impl(model_name, model_import, log_generation)
if __name__ == '__main__':
    # Interactive entry point: configure output locations once, then read one
    # entity source path per line until 'e', end of input, or Ctrl-C.
    in_rep_path = input("rep_path>").strip()
    in_rep_package = input("rep_package>").strip()
    in_int_path = input("int_path>").strip()
    in_int_package = input("int_package>").strip()
    in_impl_path = input("impl_path>").strip()
    in_impl_package = input("impl_package>").strip()
    dao = DaoEnvironment(in_rep_path, in_rep_package, in_int_path, in_int_package, in_impl_path, in_impl_package)
    in_log = input("add log? (y/n)>").strip().lower() == "y"
    in_paging = input("paging? (y/n)>").strip().lower() == "y"
    print("type 'e' for exit")
    while True:
        try:
            # Drop surrounding quotes added by Windows "Copy as path".
            in_path = input().strip().strip('"')
        except (EOFError, KeyboardInterrupt):
            # Previously swallowed, which looped forever on end of input.
            break
        if in_path == 'e':
            break
        if not in_path:
            continue
        try:
            dao.generate(in_path, in_log, in_paging)
        except (OSError, ValueError) as err:
            # Report and keep going with the next path instead of failing silently.
            print("Skipped {0}: {1}".format(in_path, err))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement