Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#!/usr/bin/env python3
"""
Run migrations on the servatrice database.

Reads the `.sql` migrations from the migration directory (by default
`./migrations`, i.e. `servatrice/migrations/` when run from the servatrice
directory; override it with `-D`) and runs them on the database in file-name
order.
Only runs migrations if they are needed. This is determined using
the cockatrice_schema_version table.
Stops running migrations if any fail for any reason.
"""
import os
import sys
from argparse import ArgumentParser
from argparse import ArgumentTypeError

import pymysql
import pymysql.cursors
from pymysql.constants import CLIENT
# Cursor shared by the SQL helpers below; main() assigns it after connecting.
SQL_CONTROLLER = None


def run_sql_command(sql: str, cursor=None) -> list:
    """Run *sql* and return the rows of every result set it produced.

    A migration file may contain several ';'-separated statements, each of
    which produces its own result set.  All result sets are drained here, so
    an error in a later statement is raised by this call.  Otherwise it would
    only surface on the next ``execute``, attributed to the wrong SQL, or
    never, for the last migration.

    :param sql: the SQL text to execute.
    :param cursor: cursor to execute on; defaults to ``SQL_CONTROLLER``.
    :returns: the rows of all result sets, concatenated in order.
    :raises pymysql.err.MySQLError: if any statement fails.
    """
    if cursor is None:
        cursor = SQL_CONTROLLER
    cursor.execute(sql)
    rows = list(cursor.fetchall())
    # nextset() returns a falsy value once there are no more result sets.
    while cursor.nextset():
        rows.extend(cursor.fetchall())
    return rows
def get_all_migrations(args) -> list:
    """Return the paths of every ``.sql`` file in the migration directory.

    :param args: parsed command-line arguments; only
        ``args.migration_directory`` is read.
    :returns: the migration file paths, sorted.
    """
    directory = args.migration_directory
    # os.path.join copes with a trailing separator on the directory argument.
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith('.sql')
    )
def get_schema_version() -> int:
    """Return the schema version currently recorded in the servatrice database.

    Reads the ``version`` column of the first row returned from the
    ``cockatrice_schema_version`` table and converts it to ``int``.
    """
    rows = run_sql_command('SELECT version FROM cockatrice_schema_version;')
    first_row = rows[0]
    return int(first_row['version'])
def _migration_start_version(migration: str) -> int:
    """Return the schema version a migration file upgrades *from*.

    The version is the integer between the first and second underscore of
    the file name (presumably names look like ``servatrice_0026_to_0027.sql``).
    Only the base name is inspected, so underscores in the directory part of
    the path do not matter.

    :raises ValueError: if the file name does not follow that pattern.
    """
    name = os.path.basename(migration)
    parts = name.split('_')
    try:
        return int(parts[1])
    except (IndexError, ValueError):
        raise ValueError(
            f'cannot parse a schema version from migration file name {name!r}'
        ) from None


def valid_migrations(migrations: list, schema_version: int) -> list:
    """Return the migrations that still need to run, in execution order.

    A migration is kept when the version it upgrades from is greater than or
    equal to *schema_version*.  The result is ordered by that start version
    (then by path), so the migrations chain correctly.

    :param migrations: paths of migration files.
    :param schema_version: the database's current schema version.
    :raises ValueError: if a migration file name carries no parsable version.
    """
    pending = [m for m in migrations if _migration_start_version(m) >= schema_version]
    pending.sort(key=lambda m: (_migration_start_version(m), m))
    return pending
def run_migration(migration: str) -> dict:
    """Load one migration file from disk and execute it.

    The file is read as UTF-8 so its text matches the connection's utf8mb4
    charset regardless of the machine's locale encoding.

    :param migration: path to the ``.sql`` file.
    :returns: ``{'success': True, 'result': rows}`` when the SQL ran, or
        ``{'success': False, 'error': exc}`` carrying the
        ``pymysql.err.MySQLError`` that was raised.
    :raises OSError: if the file cannot be read.
    """
    with open(migration, 'r', encoding='utf-8') as sql_file:
        sql = sql_file.read()
    try:
        result = run_sql_command(sql)
    except pymysql.err.MySQLError as exception:
        return {
            'success': False,
            'error': exception
        }
    return {
        'success': True,
        'result': result
    }
def _parse_bool(value: str) -> bool:
    """Convert a command-line word such as 'true' or 'no' to a bool.

    ``type=bool`` cannot be used for this: ``bool('False')`` is ``True``
    because every non-empty string is truthy.
    """
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 'yes', 'on'):
        return True
    if lowered in ('0', 'false', 'no', 'off'):
        return False
    raise ArgumentTypeError(f'expected a boolean value, got {value!r}')


def _parse_args():
    """Build the command-line interface and parse ``sys.argv``."""
    parser = ArgumentParser(
        description='Run migrations on a servatrice database.',
        epilog='Be sure to manually verify migrations _before_ running them!'
    )
    mysql_group = parser.add_argument_group('MySql Server Args')
    mysql_group.add_argument('-u', '--user', required=True)
    mysql_group.add_argument('-p', '--password', '--pass', required=True)
    mysql_group.add_argument('-H', '--host', default='127.0.0.1')
    mysql_group.add_argument('-d', '--database', default='servatrice')
    mysql_group.add_argument('-P', '--port', type=int, default=3306)
    script_group = parser.add_argument_group('Script Args')
    script_group.add_argument('-D', '--migration-directory', default='./migrations')
    # NOTE(review): --safe-mode is parsed but not used anywhere in this script.
    script_group.add_argument('--safe-mode', type=_parse_bool, default=True)
    return parser.parse_args()


def main() -> None:
    """Run every pending migration, stopping at the first failure.

    Each successful migration is committed before the next one starts.  On
    failure, the pending transaction is rolled back, the error is reported on
    stderr, and the process exits with status 1.
    """
    global SQL_CONTROLLER
    args = _parse_args()

    connection = pymysql.connect(
        host=args.host, port=args.port, user=args.user,
        password=args.password, database=args.database, charset='utf8mb4',
        cursorclass=pymysql.cursors.DictCursor,
        # Migration files hold several ';'-separated statements, which PyMySQL
        # only accepts when multi-statement support is requested.
        client_flag=CLIENT.MULTI_STATEMENTS,
    )
    try:
        SQL_CONTROLLER = connection.cursor()
        migrations = get_all_migrations(args)
        schema_version = get_schema_version()
        migrations = valid_migrations(migrations, schema_version)
        for migration in migrations:
            print(f'Applying {migration}')
            status = run_migration(migration)
            if not status['success']:
                # Discard uncommitted work from the failed migration.  MySQL
                # commits DDL implicitly, so schema changes already made by
                # it cannot be undone here.
                connection.rollback()
                error = status['error']
                print(f'Migration {migration} failed: {error}', file=sys.stderr)
                sys.exit(1)
            # PyMySQL connections default to autocommit off; without this the
            # migration's data changes (e.g. the schema version bump) are lost.
            connection.commit()
    finally:
        connection.close()
    # TODO handle ctlaltca's concerns from #2969
# Run the migrations only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
Add Comment
Please, Sign In to add comment