# MIT License
#
# Copyright (c) 2019 Bellhops Inc.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import os
import datetime
import logging
import sys

import botocore.exceptions

from sqlalchemy import MetaData
from sqlalchemy import Table
from geoalchemy2 import Table as GeoTable
from geoalchemy2.types import Geography
from sqlalchemy import Column, NVARCHAR, NUMERIC
from sqlalchemy.schema import CreateTable
from sqlalchemy.sql.sqltypes import TEXT, NullType, ARRAY
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql.base import DOUBLE_PRECISION

from airflow.hooks.postgres_hook import PostgresHook
from airflow.hooks.S3_hook import S3Hook
from airflow.models import Variable
from airflow.operators.subdag_operator import SubDagOperator
from airflow.operators.python_operator import PythonOperator
from airflow import utils as airflow_utils


class Pump(object):
    """Copies a single table from an origin Postgres database to a destination Postgres
    or Redshift database, staging the data as a local CSV (and in S3 for Redshift)."""

    def __init__(self, table_name, origin_conn_id, origin_schema_name,
                 destination_conn_id, destination_schema_name, destination_db_type, s3_conn_id,
                 s3_bucket, s3_directory, drop_destination_table=False):
        self.table_name = table_name
        self.origin_hook = PostgresHook(postgres_conn_id=origin_conn_id)
        self.origin_conn_uri = self.origin_hook.get_uri()
        self.origin_schema_name = origin_schema_name
        self.destination_hook = PostgresHook(postgres_conn_id=destination_conn_id)
        self.destination_schema_name = destination_schema_name
        self.destination_db_type = destination_db_type
        self.s3_hook = S3Hook(aws_conn_id=s3_conn_id)
        self.drop_destination_table = drop_destination_table
        self.s3_bucket = s3_bucket
        self.s3_directory = s3_directory
        self.origin_engine = self.origin_hook.get_sqlalchemy_engine()
        self.origin_metadata = MetaData(schema=self.origin_schema_name)
        self.origin_table_name_with_schema = '.'.join([self.origin_schema_name, self.table_name])
        self.destination_table_name_with_schema = '.'.join([self.destination_schema_name, self.table_name])
        self.bucket_name = self.s3_bucket
        self.directory_name = self.s3_directory
        self.s3_file_name = '{directory_name}/{table_name}'.format(
            directory_name=self.directory_name,
            table_name=table_name) + '_{:%Y-%m-%dT%H%M}'.format(datetime.datetime.now()) + '.csv'
        self.destination_conn_uri = self.destination_hook.get_uri()
        self.destination_engine = self.destination_hook.get_sqlalchemy_engine()
        # _get_credentials is a private Airflow hook method; the first two values are the
        # AWS access key id and secret key used to build the Redshift COPY credentials string.
        self.aws_key, self.aws_pass, _, _ = self.s3_hook._get_credentials(None)
        self.local_file = None

    @property
    def origin_has_table(self):
        return self.origin_engine.has_table(self.table_name, schema=self.origin_schema_name)

    @property
    def destination_has_table(self):
        return self.destination_engine.has_table(self.table_name, schema=self.destination_schema_name)

    @property
    def is_destination_type_redshift(self):
        return self.destination_db_type == 'redshift'

    def get_columns_from_table(self, table):
        """Build the destination column list, coercing types the destination cannot store natively."""
        logging.info("Getting columns for {table}.".format(table=table.name))
        columns = []
        for column in table.columns:
            if self.is_destination_type_redshift:
                logging.info("Changing column constraints for {table}.".format(table=table.name))
                # Redshift treats TEXT as VARCHAR(256) and has no GEOGRAPHY or untyped columns,
                # so widen those to VARCHAR(65535); DOUBLE PRECISION is loaded as NUMERIC(20, 5).
                if column.name == 'id':
                    columns.append(Column(column.name, column.type, primary_key=True, autoincrement=False))
                elif type(column.type) is TEXT:
                    columns.append(Column(column.name, NVARCHAR(65535)))
                elif type(column.type) is DOUBLE_PRECISION:
                    columns.append(Column(column.name, NUMERIC(20, 5)))
                elif type(column.type) is NullType:
                    columns.append(Column(column.name, NVARCHAR(65535)))
                elif type(column.type) is Geography:
                    columns.append(Column(column.name, NVARCHAR(65535)))
                else:
                    columns.append(Column(column.name, column.type))
            else:
                if type(column.type) is DOUBLE_PRECISION:
                    columns.append(Column(column.name, NUMERIC(20, 5)))
                else:
                    columns.append(Column(column.name, column.type))
        return columns

    def execute_psql_command(self, conn_uri, command, password=None):
        """Run a single command through the psql CLI, passing the password via PGPASSWORD when given."""
        if password:
            psql_query = '''export PGPASSWORD='{password}'; psql "{conn_url}" -c "{command}" '''.format(
                conn_url=conn_uri, command=command, password=password)
        else:
            psql_query = '''psql "{conn_url}" -c "{command}" '''.format(
                conn_url=conn_uri, command=command)
        print(psql_query)
        res = os.system(psql_query)
        if res != 0:
            sys.exit(1)

    def create_directory(self, directory):
        os.system('''mkdir -p {directory}'''.format(directory=directory))
        return directory

    def remove_local_file(self):
        command = '''rm -f {file_path}'''.format(file_path=self.local_file)
        logging.info("Removing file with command: {command}".format(command=command))
        os.system(command)

    def copy_table_to_local_csv_command(self, table_name):
        """Build the client-side psql COPY command that dumps the origin table to a local CSV file."""
        pump_tmp_directory = Variable.get('PUMP_TMP_DIRECTORY', default_var='/tmp/pump')
        directory = self.create_directory("{pump_tmp_directory}/{destination_schema_name}".format(
            destination_schema_name=self.destination_schema_name,
            pump_tmp_directory=pump_tmp_directory
        ))
        file_ = ("{directory}/{table_name}".format(directory=directory, table_name=table_name)
                 + '_{:%Y-%m-%dT%H%M}'.format(datetime.datetime.now()) + '.csv')
        command = r'''\COPY (SELECT * FROM {table_name}) TO '{file}' HEADER CSV;'''.format(
            table_name=table_name,
            file=file_
        )
        self.local_file = file_
        logging.info("Dumping table {table_name} to file {file_name}".format(table_name=self.table_name, file_name=self.local_file))
        return command

    def copy_s3_file_to_table_command(self, table_name, schema_name, bucket_name, file_name, aws_key, aws_pass):
        """Build the Redshift COPY command that loads the staged S3 CSV into the destination table."""
        table_name = '.'.join([schema_name, table_name])
        command = '''COPY {table_name} FROM 's3://{bucket_name}/{file}' credentials 'aws_access_key_id={aws_key};aws_secret_access_key={aws_pass}' IGNOREHEADER 1 CSV;'''.format(
            table_name=table_name,
            bucket_name=bucket_name,
            file=file_name,
            aws_key=aws_key,
            aws_pass=aws_pass
        )
        return command

    def copy_local_file_to_table_command(self, table_name, schema_name, file_name):
        table_name = '.'.join([schema_name, table_name])
        command = r'''\COPY {table_name} FROM '{file}' HEADER CSV;'''.format(
            table_name=table_name,
            file=file_name,
        )
        return command

    def copy_table_to_local_file(self):
        command = self.copy_table_to_local_csv_command(self.origin_table_name_with_schema)
        conn = self.origin_hook.get_connection(self.origin_hook.postgres_conn_id)
        # Build a URI without the password; execute_psql_command exports it as PGPASSWORD instead.
        conn_uri = 'postgres://{user}@{host}:{port}/{schema}'.format(user=conn.login,
                                                                     host=conn.host,
                                                                     port=conn.port,
                                                                     schema=conn.schema)
        self.execute_psql_command(conn_uri, command, conn.password)

    def copy_table_to_s3(self):
        logging.info("Copying local file to s3 bucket: {bucket} file: {file}".format(bucket=self.bucket_name,
                                                                                     file=self.s3_file_name))
        self.s3_hook.load_file(
            self.local_file,
            self.s3_file_name,
            bucket_name=self.bucket_name,
            replace=True,
        )

    def remove_s3_file(self):
        logging.info("Removing file from s3 bucket: {bucket} file: {file}".format(bucket=self.bucket_name,
                                                                                  file=self.s3_file_name))
        boto_client = self.s3_hook.get_conn()
        try:
            boto_client.head_object(Bucket=self.s3_bucket, Key=self.s3_file_name)
            boto_client.delete_object(Bucket=self.s3_bucket, Key=self.s3_file_name)
        except botocore.exceptions.ClientError as e:
            logging.info("Could not find file. Error: {}".format(e))

    def copy_s3_file_to_table(self):
        logging.info("Copying file from s3 bucket: {bucket} file: {file} to table {table_name}".format(
            bucket=self.bucket_name,
            file=self.s3_file_name,
            table_name=self.table_name))
        command = self.copy_s3_file_to_table_command(table_name=self.table_name,
                                                     schema_name=self.destination_schema_name,
                                                     bucket_name=self.bucket_name,
                                                     file_name=self.s3_file_name,
                                                     aws_key=self.aws_key,
                                                     aws_pass=self.aws_pass)
        self.destination_hook.run(command)

    def copy_local_file_to_table(self):
        logging.info("Copying local file {local_file} to {table_name}".format(
            table_name=self.table_name,
            local_file=self.local_file
        ))
        command = self.copy_local_file_to_table_command(table_name=self.table_name,
                                                        schema_name=self.destination_schema_name,
                                                        file_name=self.local_file)
        self.execute_psql_command(self.destination_conn_uri, command)

    def ddl_statement_create_table(self):
        """Reflect the origin table and compile CREATE TABLE DDL for the destination schema."""
        if self.is_destination_type_redshift:
            table = Table(self.table_name, self.origin_metadata, autoload=True, autoload_with=self.origin_engine)
        else:
            table = GeoTable(self.table_name, self.origin_metadata, autoload=True, autoload_with=self.origin_engine)
        columns = self.get_columns_from_table(table)
        destination_table = Table(self.table_name, MetaData(schema=self.destination_schema_name), *columns)
        created_table = CreateTable(destination_table).compile(dialect=postgresql.dialect())
        created_table_ddl = str(created_table)
        logging.info("Create table DDL: {created_table_ddl}".format(created_table_ddl=created_table_ddl))
        return created_table_ddl

    def ddl_statement_truncate_table(self):
        truncate_table_ddl = '''TRUNCATE TABLE {table_name}'''.format(table_name=self.destination_table_name_with_schema)
        return truncate_table_ddl

    def ddl_statement_drop_table(self):
        drop_table_ddl = '''DROP TABLE IF EXISTS {table_name} CASCADE'''.format(table_name=self.destination_table_name_with_schema)
        return drop_table_ddl

    def create_table(self):
        """Drop, truncate and/or create the destination table depending on configuration and current state."""
        create_ddl_statement = self.ddl_statement_create_table()
        truncate_ddl_statement = self.ddl_statement_truncate_table()
        drop_ddl_statement = self.ddl_statement_drop_table()
        if self.drop_destination_table:
            self.destination_hook.run(drop_ddl_statement)
        elif self.destination_has_table:
            self.destination_hook.run(truncate_ddl_statement)
        if not self.destination_has_table:
            self.destination_hook.run(create_ddl_statement)

    def create_and_load_table(self):
        """Dump the origin table to CSV, (re)create the destination table, and load the data into it."""
        if self.origin_has_table:
            self.copy_table_to_local_file()
            self.create_table()
            if self.is_destination_type_redshift:
                # Redshift loads from S3, so stage the CSV there first.
                self.copy_table_to_s3()
                self.copy_s3_file_to_table()
                self.remove_s3_file()
            else:
                self.copy_local_file_to_table()
            self.remove_local_file()
        else:
            logging.error("Origin hook {origin_hook} does not have table {table} in schema {schema}".format(
                origin_hook=self.origin_hook,
                table=self.table_name,
                schema=self.origin_schema_name
            ))
            sys.exit(1)


class PumpSubDagOperator(SubDagOperator):
    """SubDagOperator that creates one pump task per table, so a group of tables
    can be copied to the destination database as a single parent task."""

    @airflow_utils.apply_defaults
    def __init__(self, dag, task_id, start_date, schedule_interval, default_args, table_names, pump_config, **kwargs):
        self.start_date = start_date
        self.dag_schedule_interval = schedule_interval
        self.default_args = default_args
        self.table_names = table_names
        self.origin_conn_id = pump_config['origin_conn_id']
        self.destination_conn_id = pump_config['destination_conn_id']
        self.origin_schema_name = pump_config['origin_schema_name']
        self.destination_schema_name = pump_config['destination_schema_name']
        self.destination_db_type = pump_config['destination_db_type']
        self.s3_conn_id = pump_config['s3_conn_id']
        self.s3_bucket = pump_config['s3_bucket']
        self.s3_directory = pump_config['s3_directory']
        self.drop_destination_table = bool(pump_config.get('drop_destination_table', False))

        from airflow import DAG  # imported here to avoid a circular import

        # A sub-DAG's id must be '<parent_dag_id>.<task_id>'.
        self.sub_dag_name = dag.dag_id + '.' + task_id
        self.subdag = DAG(
            self.sub_dag_name,
            start_date=self.start_date,
            schedule_interval=self.dag_schedule_interval,
            default_args=self.default_args
        )
        self.init_tasks()
        super(PumpSubDagOperator, self).__init__(
            dag=dag,
            subdag=self.subdag,
            task_id=task_id,
            trigger_rule='all_done'
        )

    @property
    def task_type(self):
        # Report the plain SubDagOperator type so the Airflow UI still treats this task as a sub-DAG.
        return 'SubDagOperator'

    def pump_table(self, table_name, **kwargs):
        pump = Pump(
            table_name=table_name,
            origin_conn_id=self.origin_conn_id,
            destination_conn_id=self.destination_conn_id,
            origin_schema_name=self.origin_schema_name,
            destination_schema_name=self.destination_schema_name,
            destination_db_type=self.destination_db_type,
            s3_conn_id=self.s3_conn_id,
            s3_bucket=self.s3_bucket,
            s3_directory=self.s3_directory,
            drop_destination_table=self.drop_destination_table
        )
        pump.create_and_load_table()

    def create_task(self, table_name):
        # The operator registers itself with the sub-DAG on construction.
        PythonOperator(
            task_id="pump_{table_name}".format(table_name=table_name),
            python_callable=self.pump_table,
            op_kwargs={
                "table_name": table_name
            },
            dag=self.subdag
        )

    def init_tasks(self):
        for table_name in self.table_names:
            self.create_task(table_name)
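
# Example (sketch): one way PumpSubDagOperator might be wired into a parent DAG file.
# The DAG id, schedule, connection ids, bucket, and table names below are hypothetical
# placeholders chosen for illustration only.
#
#     from datetime import datetime, timedelta
#     from airflow import DAG
#
#     default_args = {'owner': 'data-eng', 'retries': 1, 'retry_delay': timedelta(minutes=5)}
#
#     dag = DAG('warehouse_sync', start_date=datetime(2019, 1, 1),
#               schedule_interval='@daily', default_args=default_args)
#
#     pump_tables = PumpSubDagOperator(
#         dag=dag,
#         task_id='pump_tables',
#         start_date=dag.start_date,
#         schedule_interval=dag.schedule_interval,
#         default_args=default_args,
#         table_names=['orders', 'customers'],
#         pump_config={
#             'origin_conn_id': 'app_db',
#             'destination_conn_id': 'warehouse_db',
#             'origin_schema_name': 'public',
#             'destination_schema_name': 'analytics',
#             'destination_db_type': 'redshift',
#             's3_conn_id': 'aws_default',
#             's3_bucket': 'example-staging-bucket',
#             's3_directory': 'pump',
#             'drop_destination_table': False,
#         },
#     )
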
if __name__ == '__main__':
    # Ad-hoc manual run; the connection, schema, and bucket names here are placeholders.
    pump = Pump(
        table_name='test',
        origin_conn_id='airflow_origin_conn_name',
        origin_schema_name='origin_schema',
        destination_conn_id='airflow_destination_conn_name',
        destination_schema_name='destination_schema',
        destination_db_type='postgres',
        s3_conn_id='airflow_s3_conn_name',
        s3_bucket='bucket_name',
        s3_directory='directory_name_in_bucket',
        drop_destination_table=True
    )
    # pump.create_and_load_table()  # uncomment to run the copy end-to-end