Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #
- # database.py
- #
- # Created by Adi Unnithan on 5/5/17.
- # Copyright © 2017 Adi Unnithan. All rights reserved.
- #
import logging
from typing import List

import psycopg2
import psycopg2.extras
from psycopg2 import sql
from psycopg2 import extensions as ext
# pylint: disable=R0903
class DatabasePredicateItem(object):
    """One ``<key> <op> %s`` comparison used inside a DatabasePredicate.

    ``key`` and ``op`` are inserted into the query verbatim via ``sql.SQL``
    (no quoting or escaping), so they must come from trusted code, never
    from user input. ``val`` is supplied separately as a bound parameter.
    """

    def __init__(self, key, op, val):
        self.key = key  # column expression, e.g. 'id'
        self.op = op    # comparison operator, e.g. '=', '<', 'like'
        self.val = val  # value bound to this item's placeholder

    # Generates: 'id = %s'
    def get_sql(self):
        """Return the composable ``key op %s`` fragment for this comparison."""
        # Spaces are required so word operators ('like', 'is', ...) do not
        # fuse with the column name (the old '{0}{1}{2}' gave 'namelike%s').
        return sql.SQL("{0} {1} {2}").format(
            sql.SQL(self.key),
            sql.SQL(self.op),
            sql.Placeholder()
        )
class DatabasePredicate(object):
    """A WHERE condition (items joined with ``and``) plus an optional LIMIT.

    ``get_sql()`` and ``get_placeholder_values()`` are meant to be used as a
    pair: the SQL contains exactly one ``%s`` per returned value, in the same
    order, ready for ``cursor.mogrify``/``cursor.execute``.
    """

    # ``AsIs`` makes psycopg2 emit the text unquoted, giving ``limit ALL``.
    limitAll = ext.AsIs('ALL')
    # Marker used by set_limit_one()/is_limit_one().
    limitOne = '1'

    def __init__(self, predicateItems: List[DatabasePredicateItem], limit):
        """
        :param predicateItems: comparisons to AND together; may be empty, in
            which case the condition matches every row.
        :param limit: value bound to the LIMIT placeholder, or a falsy value
            (e.g. None) to emit no LIMIT clause at all.
        """
        self.predicateItems = predicateItems
        self.limit = limit

    def _has_limit(self):
        # Shared by get_sql() and get_placeholder_values() so the number of
        # placeholders and the number of values can never disagree.
        return bool(self.limit)

    # Generates: 'id = %s and name = %s ... limit %s'
    def get_sql(self):
        """Return the composable WHERE body, followed by ``limit %s`` if set."""
        if self.predicateItems:
            # Spaces on both sides keep 'and' from fusing with the
            # neighbouring placeholder or column name.
            condition = sql.SQL(' and ').join(pi.get_sql() for pi in self.predicateItems)
        else:
            # No items: match every row. This must be an sql.SQL rather than
            # a plain str, because callers pass it to SQL.format().
            condition = sql.SQL('1=1')
        if self._has_limit():
            return sql.SQL('{0} limit {1}').format(condition, sql.Placeholder())
        return condition

    def get_placeholder_values(self):
        """Return the values for get_sql()'s placeholders, in order."""
        vals = tuple(pi.val for pi in self.predicateItems)
        if self._has_limit():
            return vals + (self.limit,)
        return vals

    def set_limit_one(self):
        """Restrict this predicate to a single row (mutates self)."""
        self.limit = self.limitOne

    def is_limit_one(self):
        """True when the limit is the single-row marker set by set_limit_one()."""
        return self.limit == self.limitOne

    @staticmethod
    def empty(limit=limitAll):
        """Return a predicate with no conditions (matches every row)."""
        return DatabasePredicate([], limit)

    @staticmethod
    def withTuple(tupleInst: tuple, limit=limitAll):
        """Build a predicate from a single ``(key, op, val)`` tuple."""
        return DatabasePredicate.withTuples([tupleInst], limit)

    @staticmethod
    def withTuples(tuples: List[tuple], limit=limitAll):
        """Build a predicate from ``(key, op, val)`` tuples; extra elements are ignored."""
        items = [DatabasePredicateItem(t[0], t[1], t[2]) for t in tuples]
        return DatabasePredicate(items, limit)
class DatabaseConnection(object):
    """Thin wrapper around a psycopg2 connection with simple CRUD helpers.

    Call ``open_cursor()`` before ``select``/``insert``/``update``. Rows come
    back from a ``RealDictCursor`` (dict-like rows). The connection is put in
    autocommit mode, so each statement is committed as soon as it runs.
    """

    def __init__(self, config):
        """Connect using ``config.db_host``, ``db_name``, ``db_user`` and ``db_pass``."""
        self.__dbconn = DatabaseConnection.__create_database_connection(config)
        self.__cursor = None  # type: psycopg2.extensions.cursor

    def open_cursor(self):
        """Open a dict-row cursor, closing any cursor this object already holds."""
        if self.__cursor is not None:
            self.__cursor.close()
        self.__cursor = self.__dbconn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)

    def close_cursor(self):
        """Close the current cursor; does nothing if none was ever opened."""
        if self.__cursor is not None:
            self.__cursor.close()

    # TODO: allow *args (list of tuples) to be passed in to define the predicate
    def select_one(self, table_name, predicate: DatabasePredicate):
        """Return the first row matching *predicate*, or an empty list if none.

        A single-row copy of *predicate* is queried; the caller's object is
        not modified.
        """
        one = DatabasePredicate(predicate.predicateItems, predicate.limit)
        one.set_limit_one()
        return self.select(table_name, one)

    def select(self, table_name, predicate: DatabasePredicate):
        """Return rows of *table_name* matching *predicate*.

        Returns a list of rows, except that when ``predicate.is_limit_one()``
        and a row was found, the row itself is returned.
        """
        sqlformat = sql.SQL("select * from {0} where {1}").format(
            sql.Identifier(table_name),
            predicate.get_sql()
        )
        cursor = self.__run(sqlformat, predicate.get_placeholder_values())
        result = cursor.fetchall()
        if result and predicate.is_limit_one():
            return result[0]
        return result

    def insert(self, table_name, obj):
        """Insert *obj* (column -> value mapping) and return the new row."""
        if obj:
            sqlformat = sql.SQL("insert into {0} ({1}) values ({2}) returning *").format(
                sql.Identifier(table_name),
                sql.SQL(', ').join(map(sql.Identifier, obj.keys())),
                sql.SQL(', ').join(sql.Placeholder() * len(obj.values())))
            values = tuple(obj.values())
        else:
            # An empty "()" column list is a syntax error; let every column
            # take its default instead.
            sqlformat = sql.SQL("insert into {0} default values returning *").format(
                sql.Identifier(table_name))
            values = ()
        cursor = self.__run(sqlformat, values)
        return cursor.fetchone()

    def update(self, table_name, obj, predicate: DatabasePredicate):
        """Set the columns in *obj* on rows matching *predicate*.

        Returns the first updated row, or an empty list when nothing matched
        (the same shape select_one() returns). The caller's predicate is not
        modified.
        """
        if not obj:
            # Nothing to set ("set where" is invalid SQL): no-op, so just
            # return the current matching row.
            return self.select_one(table_name, predicate)
        # UPDATE has no LIMIT clause in PostgreSQL, so build the WHERE body
        # from a copy of the predicate without a limit.
        where = DatabasePredicate(predicate.predicateItems, None)
        sqlformat = sql.SQL("update {0} set {1} where {2} returning *").format(
            sql.Identifier(table_name),
            # SET assignments must be comma-separated; joining them with AND
            # parses as a single boolean expression for the first column.
            sql.SQL(', ').join(
                sql.SQL("{0} = {1}").format(sql.Identifier(k), sql.Placeholder())
                for k in obj.keys()
            ),
            where.get_sql()
        )
        cursor = self.__run(sqlformat, tuple(obj.values()) + where.get_placeholder_values())
        # RETURNING yields the rows as updated, even when the update changed
        # a column the predicate filters on (a follow-up SELECT would miss it).
        rows = cursor.fetchall()
        return rows[0] if rows else rows

    def __run(self, sqlformat, params):
        """Bind *params* into *sqlformat*, log and execute it; return the cursor."""
        # Explicit check instead of ``assert`` so it still runs under ``python -O``.
        if self.__cursor is None:
            raise RuntimeError('no open cursor; call open_cursor() first')
        query = self.__cursor.mogrify(sqlformat, params)
        # Debug logging instead of print(): the rendered query contains the
        # bound values and should not go to stdout unconditionally.
        logging.getLogger(__name__).debug('%s', query)
        self.__cursor.execute(query)
        return self.__cursor

    @staticmethod
    def __create_database_connection(config):
        """Open an autocommit psycopg2 connection from *config*'s db_* fields."""
        connection = psycopg2.connect(host=config.db_host, database=config.db_name, user=config.db_user, password=config.db_pass)
        connection.set_session(autocommit=True)
        return connection
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement