Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #!/usr/bin/env python3
- import os
- import sys
- import subprocess
- import argparse
- import psycopg2
- from app_config import cfg
- from colorama import Fore, Back, Style
class Migrator:
    """Apply SQL scripts from ``cfg['sql_dir']`` to a PostgreSQL database.

    The name of the last successfully applied script is stored in a
    single-row ``version`` table, so ``--current`` runs only the scripts
    whose file names sort after it.
    """

    def __init__(self, conn=None):
        """Open (or adopt) the database connection used by every method.

        :param conn: an existing DB-API connection to use. When omitted, a
            psycopg2 connection to 127.0.0.1 is opened from the ``cfg``
            settings.
        """
        if conn is None:
            # Connect per instance instead of at class-definition time, so
            # importing the module (or running ``--help``) needs no database.
            # Keyword args avoid breaking the DSN on spaces/quotes in values.
            conn = psycopg2.connect(
                dbname=cfg['db_name'],
                user=cfg['db_user'],
                password=cfg['db_pass'],
                host="127.0.0.1",
                port=cfg['db_port'],
            )
        self.conn = conn

    @staticmethod
    def _split_statements(sql_content):
        """Split SQL text on ';' into stripped, non-empty statements.

        Whole lines starting with ``--`` are removed, so a comment in front
        of a statement no longer causes the statement itself to be skipped.
        NOTE: this is a naive split; a ';' inside a string literal or a
        function body is still treated as a statement separator.
        """
        statements = []
        for chunk in sql_content.split(";"):
            kept_lines = [line for line in chunk.splitlines()
                          if not line.lstrip().startswith('--')]
            stmt = "\n".join(kept_lines).strip()
            if stmt:
                statements.append(stmt)
        return statements

    def get_version(self):
        """Return the version stored in the ``version`` table (lowest id).

        :raises LookupError: if the table has no rows.
        Any error is printed with a banner and then re-raised.
        """
        try:
            cur = self.conn.cursor()
            cur.execute("SELECT version from version order by id asc limit 1;")
            row = cur.fetchone()
            if row is None:
                raise LookupError("version table is empty")
            return row[0]
        except Exception as err:
            print("{1}\nERROR Could not get current version number.\nStacktrace ==>\n{0}\n{1}".format(err, "~" * 50))
            raise

    def update_version_number(self, new_version):
        """Store ``new_version`` in the version row (id = 1) and commit.

        The commit also commits any statements already executed on this
        connection in the current transaction.
        """
        cur = self.conn.cursor()
        # Parameterized: script names may contain quotes.
        cur.execute("UPDATE version SET version = %s WHERE id = 1;", (new_version,))
        self.conn.commit()

    def execute_sql(self, scripts):
        """Execute SQL scripts in order, stopping at the first failure.

        Each script's statements and the version bump to that script's name
        are committed together. On failure the open transaction is rolled
        back and no further scripts run.

        :param scripts: file names relative to ``cfg['sql_dir']``.
        :returns: True if every script succeeded, False otherwise.
        """
        conn = self.conn
        cur = conn.cursor()
        try:
            for script in scripts:
                print(Fore.CYAN + "EXECUTING SCRIPT {0}".format(script))
                with open(os.path.join(cfg['sql_dir'], script), 'r') as s:
                    sql_content = s.read()
                for cmd in self._split_statements(sql_content):
                    print(Fore.WHITE + cmd)
                    cur.execute(cmd)
                # Record the full file name (not just its last character):
                # do_update compares file names against this value for
                # --current. update_version_number commits the script's
                # statements and the new version atomically.
                self.update_version_number(script)
                print(Fore.GREEN + "{0}........PASSED".format(script))
            return True
        except Exception as e:
            # Discard the failed (aborted) transaction so the connection
            # stays usable and no partial script is committed.
            conn.rollback()
            print(Fore.RED + "...FAILED ! Investigate the stacktrace above")
            print(e)
            return False
        finally:
            # Reset terminal colours.
            print(Style.RESET_ALL)

    def ensure_version_exists(self):
        """Create the ``version`` table and seed it with '' if it is missing."""
        check_table = """ SELECT EXISTS (
                    SELECT 1
                    FROM pg_tables
                    WHERE schemaname = 'public'
                    AND tablename = 'version'
                ); """
        # NOTE(review): script names longer than 50 characters will not fit
        # in version varchar(50); confirm the naming scheme stays under that.
        create_query = """ CREATE TABLE version(
                    id serial not null,
                    version varchar(50) not null
                );
                insert into version (version) values (''); """
        cur = self.conn.cursor()
        cur.execute(check_table)
        if not cur.fetchone()[0]:
            # Table creation and seed row are committed in one transaction.
            cur.execute(create_query)
            self.conn.commit()

    def do_update(self, args):
        """Select the scripts to run from ``args`` and execute them.

        :param args: dict with keys ``'file'``, ``'all'`` and ``'current'``
            (as built by ``vars(parser.parse_args())``). ``'file'`` takes
            precedence, then ``'all'``, then ``'current'``.
        Exits with an error message for an unknown script or when no mode is
        selected, and with status 1 when a script fails. The connection is
        closed in every case.
        """
        try:
            self.ensure_version_exists()
            sql_dir = cfg['sql_dir']
            if args['file']:
                if not os.path.isfile(os.path.join(sql_dir, args['file'])):
                    sys.exit(Fore.RED + "ERROR : Invalid script name")
                to_execute = [args['file']]
            elif args['all']:
                # Scripts are applied in lexicographic file-name order.
                to_execute = sorted(os.listdir(sql_dir))
            elif args['current']:
                from_version = self.get_version()
                to_execute = sorted(f for f in os.listdir(sql_dir) if f > from_version)
            else:
                sys.exit(Fore.RED + "ERROR : Specify one of --current, --all or --file")
            print(Fore.YELLOW + "Scripts to execute => " + str(to_execute) + Fore.WHITE)
            if to_execute and not self.execute_sql(to_execute):
                sys.exit(1)
        finally:
            self.conn.close()
if __name__ == '__main__':
    # Command-line entry point: parse the mode flags and run the migrator.
    parser = argparse.ArgumentParser(description="Updates DB schemas from given version to latest")
    parser.add_argument('-c', '--current', action='store_true', help="upgrade starting from current version", required=False)
    parser.add_argument('-a', '--all', action='store_true', help="run all the sql scripts", required=False)
    parser.add_argument('-f', '--file', type=str, help="Run a single file by its name", required=False)
    args = vars(parser.parse_args())
    if args['current'] and args['all']:
        sys.exit(Fore.RED + "ERROR : Specify either --current or --all, not both")
    # Without any mode do_update has nothing to run; fail fast, before
    # opening a database connection.
    if not (args['current'] or args['all'] or args['file']):
        sys.exit(Fore.RED + "ERROR : Specify one of --current, --all or --file")
    migrator = Migrator()
    migrator.do_update(args)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement