Guest User

Untitled

a guest
May 23rd, 2018
94
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.12 KB | None | 0 0
  1. import json
  2.  
  3. from sqlalchemy import insert, select, Column, Text, Integer, String, \
  4. func, event, create_engine, cast
  5. from sqlalchemy.dialects import postgresql
  6. from sqlalchemy.ext.declarative import declarative_base
  7. from sqlalchemy.exc import ProgrammingError, OperationalError
  8. from sqlalchemy.sql.expression import literal, BinaryExpression
  9. from sqlalchemy.types import TypeDecorator, Boolean
  10. from sqlalchemy.engine import Engine
  11. from sqlalchemy.sql.operators import custom_op
  12. from sqlalchemy.ext.compiler import compiles
  13. import sqlite3
  14.  
  15.  
  16. @compiles(BinaryExpression, 'sqlite')
  17. def sqlite_extend(element, compiler, **kwargs):
  18. if not isinstance(element.operator, custom_op):
  19. return compiler.visit_binary(element)
  20. if element.operator.opstring == '->':
  21. acc_func = getattr(func, 'access')
  22. elif element.operator.opstring == '->>':
  23. acc_func = getattr(func, 'access_astext')
  24. else:
  25. compiler.visit_binary(element)
  26. acc_func_call = acc_func(element.left, element.right)
  27. return compiler.process(acc_func_call)
  28.  
  29. @event.listens_for(Engine, 'connect')
  30. def sqlite_engine_connect(dbapi_connection, connection_record):
  31. if not isinstance(dbapi_connection, sqlite3.Connection):
  32. return
  33. def access(d, k):
  34. return json.loads(d).get(k)
  35. dbapi_connection.create_function('access', 2, access)
  36. def access_astext(d, k):
  37. return str(json.loads(d).get(k))
  38. dbapi_connection.create_function('access_astext', 2, access_astext)
  39.  
  40. class DefaultJSONB(TypeDecorator):
  41.  
  42. impl = Text
  43.  
  44. def load_dialect_impl(self, dialect):
  45. if dialect.name == 'postgresql':
  46. return dialect.type_descriptor(postgresql.JSONB())
  47. else:
  48. return dialect.type_descriptor(Text())
  49.  
  50. def coerce_compared_value(self, op, value):
  51. return postgresql.JSONB().coerce_compared_value(op, value)
  52.  
  53. class Comparator(postgresql.JSONB.comparator_factory):
  54. """Mixed comparator_factory."""
  55. @property
  56. def astext(self):
  57. return cast(self.expr, String)
  58. def __getitem__(self, other):
  59. return self.op('->')(other)
  60. comparator_factory = Comparator
  61.  
  62. def process_bind_param(self, value, dialect):
  63. if value is None:
  64. return value
  65. if dialect.name == 'postgresql':
  66. pass
  67. else:
  68. value = json.dumps(value)
  69. return value
  70.  
  71. def process_result_value(self, value, dialect):
  72. if value is None:
  73. return value
  74. if dialect.name == 'postgresql':
  75. pass
  76. else:
  77. value = json.loads(value)
  78. return value
  79.  
# Column type used by the model below: the portable DefaultJSONB decorator.
# Swap in the commented-out line to use the native PostgreSQL JSONB type.
SQLJSONB = DefaultJSONB
#SQLJSONB = postgresql.JSONB
# Declarative base class for the ORM models defined in this file.
Base = declarative_base()

## Postgresql
# Alternative engine; uncomment to run against a local PostgreSQL server.
#db = create_engine("postgresql://127.0.0.1:5432/test", echo=True)

## Sqlite
# Active engine: a SQLite database file at /tmp/db.db, created with echo=True.
db = create_engine("sqlite:////tmp/db.db", echo=True)
  89.  
class Users(Base):
    """Demo ORM model mapped to the ``users`` table."""
    __tablename__ = "users"
    # Integer primary key.
    id = Column(Integer, primary_key=True)
    # User name, declared as String(32).
    name = Column(String(32))
    # JSON document column, stored with the SQLJSONB type selected above.
    infos = Column(SQLJSONB)
  95.  
  96. def main():
  97. try:
  98. Users.__table__.drop(bind=db)
  99. except (ProgrammingError, OperationalError):
  100. pass
  101. Users.__table__.create(bind=db)
  102. insrt = insert(Users)
  103. values_dict = {'name': 'user1', 'infos': {"1": 1, "4": 4, "tab": "stab"}}
  104. insrt = insrt.values(values_dict)
  105. db.execute(insrt)
  106. #slct = select([Users.infos.astext])
  107. slct = select([Users])
  108. # Flavour of field match
  109. ##slct = slct.where(Users.infos["4"] == 4)
  110. slct = slct.where(Users.infos["4"] == values_dict['infos']["4"])
  111. slct = slct.where(Users.infos["4"].astext == str(values_dict['infos']["4"]))
  112. slct = slct.where(Users.infos.op('->>')("4") == str(values_dict['infos']["4"]))
  113. slct = slct.where(Users.infos.op('->')("tab") == values_dict["infos"]["tab"])
  114. slct = slct.where(Users.infos.op('->>')("tab") == str(values_dict["infos"]["tab"]))
  115. # Column match
  116. slct = slct.where(Users.infos == values_dict['infos'])
  117. # Expected (workaround)
  118. #slct = slct.where(Users.infos == json.dumps({"1": 1, "4": 4, "tab": "stab"}))
  119. row = db.execute(slct)
  120. return row
  121.  
  122. if __name__ == '__main__':
  123. row = main()
  124. print
  125. for e in row:
  126. print(e)
Add Comment
Please, Sign In to add comment