Not a member of Pastebin yet?
Sign up; it unlocks many cool features!
- import json
- from sqlalchemy import insert, select, Column, Text, Integer, String, \
- func, event, create_engine, cast
- from sqlalchemy.dialects import postgresql
- from sqlalchemy.ext.declarative import declarative_base
- from sqlalchemy.exc import ProgrammingError, OperationalError
- from sqlalchemy.sql.expression import literal, BinaryExpression
- from sqlalchemy.types import TypeDecorator, Boolean
- from sqlalchemy.engine import Engine
- from sqlalchemy.sql.operators import custom_op
- from sqlalchemy.ext.compiler import compiles
- import sqlite3
@compiles(BinaryExpression, 'sqlite')
def sqlite_extend(element, compiler, **kwargs):
    """Compile the JSON access operators ``->`` / ``->>`` for SQLite.

    SQLite has no such operators, so a ``custom_op`` with opstring ``->`` is
    rewritten to ``access(left, right)`` and ``->>`` to
    ``access_astext(left, right)``; those SQL functions are registered per
    connection by ``sqlite_engine_connect``. Every other binary expression
    (including other custom operators) is compiled by SQLite's default rules.

    :param element: the ``BinaryExpression`` being compiled.
    :param compiler: the active SQL compiler.
    :param kwargs: compile-state flags, forwarded unchanged.
    :returns: the SQL string for ``element``.
    """
    operator = element.operator
    if isinstance(operator, custom_op):
        if operator.opstring == '->':
            call = func.access(element.left, element.right)
            return compiler.process(call, **kwargs)
        if operator.opstring == '->>':
            call = func.access_astext(element.left, element.right)
            return compiler.process(call, **kwargs)
    # Previously the non-JSON custom_op path discarded this result and fell
    # through to an unbound ``acc_func``; always return the default rendering.
    return compiler.visit_binary(element, **kwargs)
@event.listens_for(Engine, 'connect')
def sqlite_engine_connect(dbapi_connection, connection_record):
    """Register the JSON accessor SQL functions on each new SQLite connection.

    ``access(doc, key)`` and ``access_astext(doc, key)`` back the ``->`` and
    ``->>`` operators, which the SQLite ``@compiles`` hook in this module
    rewrites into calls to these functions. Connections from any other
    DB-API driver are left untouched.

    :param dbapi_connection: the raw DB-API connection that was just opened.
    :param connection_record: pool bookkeeping record (unused here).
    """
    if not isinstance(dbapi_connection, sqlite3.Connection):
        return

    def _lookup(doc_text, key):
        # Python value stored at ``key``. None when the column is SQL NULL,
        # the key/index is absent, or the document is neither object nor array.
        if doc_text is None:
            return None
        doc = json.loads(doc_text)
        if isinstance(doc, dict):
            return doc.get(key)
        if (isinstance(doc, list) and isinstance(key, int)
                and -len(doc) <= key < len(doc)):
            # Negative indexes count from the end, as PostgreSQL's -> does.
            return doc[key]
        return None

    def access(doc_text, key):
        # Approximates ``->``. Scalars come back as native SQLite values so they
        # compare equal to bound Python scalars. Objects and arrays come back
        # as JSON text because a sqlite3 UDF cannot return dicts or lists.
        value = _lookup(doc_text, key)
        if isinstance(value, (dict, list)):
            return json.dumps(value)
        return value

    def access_astext(doc_text, key):
        # Approximates ``->>``. A missing key or JSON null gives SQL NULL (not
        # the string "None"), and strings are returned unquoted. Anything else
        # is rendered as JSON text, e.g. ``true`` or ``4``.
        value = _lookup(doc_text, key)
        if value is None or isinstance(value, str):
            return value
        return json.dumps(value)

    dbapi_connection.create_function('access', 2, access)
    dbapi_connection.create_function('access_astext', 2, access_astext)
class DefaultJSONB(TypeDecorator):
    """JSON document column: native JSONB on PostgreSQL, JSON text elsewhere.

    On PostgreSQL, values pass straight through to ``postgresql.JSONB``. On
    every other dialect (e.g. SQLite), the column is ``Text``: values are
    written with ``json.dumps`` and read back with ``json.loads``.
    """

    impl = Text
    # The type carries no per-instance state, so statements using it are safe
    # to cache (SQLAlchemy 1.4+ warns when this flag is missing).
    cache_ok = True

    def load_dialect_impl(self, dialect):
        """Return JSONB for PostgreSQL and Text for any other dialect."""
        if dialect.name == 'postgresql':
            native = postgresql.JSONB()
        else:
            native = Text()
        return dialect.type_descriptor(native)

    def coerce_compared_value(self, op, value):
        """Choose the bind type for the other side of a comparison.

        Scalars are coerced exactly as ``postgresql.JSONB`` would; for
        example, an int binds as Integer. When that coercion would produce a
        JSONB type itself (e.g. comparing the whole column to a dict), return
        ``self`` instead. That way ``process_bind_param`` below serializes the
        value on non-PostgreSQL dialects exactly as it was stored. A bare
        ``postgresql.JSONB`` would bypass that hook.
        """
        coerced = postgresql.JSONB().coerce_compared_value(op, value)
        if isinstance(coerced, postgresql.JSONB):
            return self
        return coerced

    class Comparator(postgresql.JSONB.comparator_factory):
        """Mixed comparator_factory.

        Keeps the JSONB comparator API, but ``col[key]`` emits a plain custom
        ``->`` operator. The SQLite ``@compiles`` hook for ``BinaryExpression``
        can then rewrite that operator.
        """

        @property
        def astext(self):
            # Casts the expression to String.
            # NOTE(review): on PostgreSQL, casting a jsonb string keeps its
            # JSON quotes, unlike ``->>``; confirm this is the intended form.
            return cast(self.expr, String)

        def __getitem__(self, other):
            # ``col[key]`` compiles to ``col -> key``.
            return self.op('->')(other)

    comparator_factory = Comparator

    def process_bind_param(self, value, dialect):
        """Serialize ``value`` to JSON text unless the dialect is PostgreSQL.

        None is passed through unchanged so it binds as SQL NULL.
        """
        if value is not None and dialect.name != 'postgresql':
            value = json.dumps(value)
        return value

    def process_result_value(self, value, dialect):
        """Decode JSON text back to Python unless the dialect is PostgreSQL.

        None (SQL NULL) is passed through unchanged.
        """
        if value is not None and dialect.name != 'postgresql':
            value = json.loads(value)
        return value
# Column type used for JSON documents. The dialect-aware wrapper works on
# both backends; ``postgresql.JSONB`` can be used directly on PostgreSQL only.
SQLJSONB = DefaultJSONB
# SQLJSONB = postgresql.JSONB

# Declarative base for the ORM models in this module.
Base = declarative_base()

# Engine (echo=True logs every emitted SQL statement).
# PostgreSQL alternative:
# db = create_engine("postgresql://127.0.0.1:5432/test", echo=True)
# SQLite (active):
db = create_engine("sqlite:////tmp/db.db", echo=True)
class Users(Base):
    """ORM model for the ``users`` table: an id, a short name and a JSON blob."""

    __tablename__ = "users"

    # Column names are spelled out explicitly; they match the attribute names.
    id = Column("id", Integer, primary_key=True)
    name = Column("name", String(length=32))
    infos = Column("infos", SQLJSONB)
def main():
    """Recreate ``users``, insert one sample row and select it back.

    The SELECT ANDs together every JSON-access flavour offered by
    ``SQLJSONB``: item access, ``.astext``, and explicit ``->`` / ``->>``
    operators. It also adds a whole-column equality check. All statements run
    against the module-level engine ``db``.

    :returns: the result of executing the SELECT; iterate it for the rows.
    """
    table = Users.__table__
    # checkfirst=True skips the DROP when the table does not exist. The old
    # blanket ``except (ProgrammingError, OperationalError): pass`` also hid
    # genuine failures such as a refused connection.
    table.drop(bind=db, checkfirst=True)
    table.create(bind=db)

    infos = {"1": 1, "4": 4, "tab": "stab"}
    values_dict = {'name': 'user1', 'infos': infos}
    db.execute(insert(Users).values(values_dict))

    conditions = [
        # Flavours of field match.
        Users.infos["4"] == infos["4"],
        Users.infos["4"].astext == str(infos["4"]),
        Users.infos.op('->>')("4") == str(infos["4"]),
        Users.infos.op('->')("tab") == infos["tab"],
        Users.infos.op('->>')("tab") == str(infos["tab"]),
        # Whole-column match. If the column type does not serialize the dict
        # on SQLite, comparing against json.dumps(infos) is the workaround.
        Users.infos == infos,
    ]
    slct = select([Users])
    for condition in conditions:
        # Successive .where() calls are combined with AND.
        slct = slct.where(condition)

    result = db.execute(slct)
    return result
if __name__ == '__main__':
    result = main()
    # Blank line before the rows. The original bare ``print`` was a Python 2
    # statement and is a no-op expression in Python 3.
    print()
    for e in result:
        print(e)
Add Comment
Please sign in to add a comment.