Advertisement
Guest User

Untitled

a guest
Aug 9th, 2017
93
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.63 KB | None | 0 0
  1. #
  2. # database.py
  3. #
  4. # Created by Adi Unnithan on 5/5/17.
  5. # Copyright © 2017 Adi Unnithan. All rights reserved.
  6. #
  7.  
  8. from typing import List
  9. import psycopg2
  10. import psycopg2.extras
  11. from psycopg2 import sql
  12. from psycopg2 import extensions as ext
  13.  
  14. # pylint: disable=R0903
  15. class DatabasePredicateItem(object):
  16. def __init__(self, key, op, val):
  17. self.key = key
  18. self.op = op
  19. self.val = val
  20.  
  21. # Generates: 'id = %s'
  22. def get_sql(self):
  23. return sql.SQL("{0}{1}{2}").format(
  24. sql.SQL(self.key),
  25. sql.SQL(self.op),
  26. sql.Placeholder()
  27. )
  28.  
  29. class DatabasePredicate(object):
  30.  
  31. limitAll = ext.AsIs('ALL')
  32. limitOne = '1'
  33.  
  34. def __init__(self, predicateItems: List[DatabasePredicateItem], limit):
  35. self.predicateItems = predicateItems
  36. self.limit = limit
  37.  
  38. # Generates: 'id = %s and name = %s... limit %s'
  39. def get_sql(self):
  40. if self.predicateItems.count == 0:
  41. return "1=1"
  42. if self.limit:
  43. return sql.SQL('{0} limit {1}').format(
  44. sql.SQL('and ').join(pi.get_sql() for pi in self.predicateItems),
  45. sql.Placeholder()
  46. )
  47. else:
  48. return sql.SQL('{0}').format(
  49. sql.SQL('and ').join(pi.get_sql() for pi in self.predicateItems)
  50. )
  51.  
  52. def get_placeholder_values(self):
  53. vals = tuple(pi.val for pi in self.predicateItems)
  54. if self.limit:
  55. return vals + (self.limit,)
  56. else:
  57. return vals
  58.  
  59. def set_limit_one(self):
  60. self.limit = self.limitOne
  61.  
  62. def is_limit_one(self):
  63. return self.limit == self.limitOne
  64.  
  65. @staticmethod
  66. def empty(limit=limitAll):
  67. return DatabasePredicate([], limit)
  68.  
  69. @staticmethod
  70. def withTuple(tupleInst: tuple, limit=limitAll):
  71. return DatabasePredicate.withTuples([tupleInst], limit)
  72.  
  73. @staticmethod
  74. def withTuples(tuples: List[tuple], limit=limitAll):
  75. items = [] # type: List[DatabasePredicateItem]
  76. for t in tuples:
  77. items.append(DatabasePredicateItem(t[0], t[1], t[2]))
  78. return DatabasePredicate(items, limit)
  79.  
  80. class DatabaseConnection(object):
  81. def __init__(self, config):
  82. self.__dbconn = DatabaseConnection.__create_database_connection(config)
  83. self.__cursor = None # type: psycopg2.psycopg1.cursor
  84.  
  85. def open_cursor(self):
  86. self.__cursor = self.__dbconn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
  87.  
  88. def close_cursor(self):
  89. self.__cursor.close()
  90.  
  91. # TODO: allow *args (list of tuples) to be passed in to define the predicate
  92. def select_one(self, table_name, predicate: DatabasePredicate):
  93. predicate.set_limit_one()
  94. return self.select(table_name, predicate)
  95.  
  96. def select(self, table_name, predicate: DatabasePredicate):
  97. assert self.__cursor is not None
  98. sqlformat = sql.SQL("select * from {0} where {1}").format(
  99. sql.Identifier(table_name),
  100. predicate.get_sql()
  101. )
  102. query = self.__cursor.mogrify(sqlformat, predicate.get_placeholder_values())
  103. print(query)
  104. self.__cursor.execute(query)
  105. result = self.__cursor.fetchall()
  106. if result and predicate.is_limit_one():
  107. return result[0]
  108. return result
  109.  
  110. def insert(self, table_name, obj):
  111. assert self.__cursor is not None
  112. sqlformat = sql.SQL("insert into {0} ({1}) values ({2}) returning *").format(
  113. sql.Identifier(table_name),
  114. sql.SQL(', ').join(map(sql.Identifier, obj.keys())),
  115. sql.SQL(', ').join(sql.Placeholder() * len(obj.values())))
  116. query = self.__cursor.mogrify(sqlformat, tuple(obj.values()))
  117. print(query)
  118. self.__cursor.execute(query)
  119. result = self.__cursor.fetchone()
  120. return result
  121.  
  122. def update(self, table_name, obj, predicate: DatabasePredicate):
  123. assert self.__cursor is not None
  124. sqlformat = sql.SQL("update {0} set {1} where {2}").format(
  125. sql.Identifier(table_name),
  126. sql.SQL('AND ').join(sql.SQL("{0} = {1}").format(sql.Identifier(k), sql.Placeholder()) for k, v in obj.items()),
  127. predicate.get_sql()
  128. )
  129. query = self.__cursor.mogrify(sqlformat, tuple(obj.values()) + predicate.get_placeholder_values())
  130. print(query)
  131. self.__cursor.execute(query)
  132. return self.select_one(table_name, predicate)
  133.  
  134.  
  135. @staticmethod
  136. def __create_database_connection(config):
  137. connection = psycopg2.connect(host=config.db_host, database=config.db_name, user=config.db_user, password=config.db_pass)
  138. connection.set_session(autocommit=True)
  139. return connection
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement