Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from weakref import WeakSet
- from collections import OrderedDict, namedtuple
- from contextlib import closing
- from datetime import date
- from py4j.java_gateway import JavaGateway, JavaObject
def python_date(java_date):
    """Convert a java.util.Date into a Python ``datetime.date``.

    ``None`` (e.g. a SQL NULL) is passed through unchanged. The millisecond
    epoch value from ``getTime()`` is interpreted in the local timezone of the
    Python process via ``date.fromtimestamp``.
    """
    if java_date is None:
        return None
    seconds_since_epoch = java_date.getTime() / 1000  # Java reports milliseconds
    return date.fromtimestamp(seconds_since_epoch)
class JDBCConnection:
    """Wrapper around a ``java.sql.Connection`` obtained through a py4j gateway.

    The connection is created by :meth:`open` (or by entering the object as a
    context manager) and released by :meth:`close`. Queries created through
    :meth:`query` are tracked in a weak set so that closing the connection
    also closes any of them that are still alive.
    """

    def __init__(self, gateway: 'JavaGateway', url: str, username: str, password: str):
        """
        :param gateway: py4j gateway whose JVM provides ``java.sql.DriverManager``
        :param url: JDBC URL passed to ``DriverManager.getConnection``
        :param username: database user name
        :param password: database password
        """
        self.url = url
        self.username = username
        self.password = password
        self.gateway = gateway
        self.conn = None  # the Java connection while open, otherwise None
        # Weak references: a query the caller has dropped should not be kept
        # alive only so that close() can reach it.
        self.queries = WeakSet()

    def query(self, query: str, params=()) -> 'JDBCQuery':
        """
        Executes a SELECT type SQL query.

        :param query: the SQL to execute
        :param params: parameters for the query, bound positionally (starting
            at 1) with ``setObject``
        :return: the resulting query object; iterate it to fetch the rows
        """
        stmt = self.conn.prepareStatement(query)
        try:
            for index, param in enumerate(params, 1):
                stmt.setObject(index, param)
            result_set = stmt.executeQuery()
            result = JDBCQuery(stmt, result_set)
        except BaseException:
            # Nothing owns the statement yet; release it instead of leaking it.
            stmt.close()
            raise
        self.queries.add(result)
        return result

    def execute(self, query: str, params=None):
        """
        Executes an SQL query.

        The query should be a modifying query (INSERT, UPDATE, DELETE).
        For SELECT queries, :meth:`query` should be used instead.

        :param query: the SQL to execute
        :param params: an iterable of same-length tuples; the statement is
            executed once per tuple. When omitted (``None``) the statement is
            executed exactly once without parameters. An explicitly empty
            iterable executes nothing.
        """
        if params is None:
            # A single execution with no bound parameters.
            params = [()]
        stmt = self.conn.prepareStatement(query)
        with closing(stmt):
            for param_tuple in params:
                for index, value in enumerate(param_tuple, 1):
                    stmt.setObject(index, value)
                stmt.execute()

    def open(self):
        """Open the connection through ``java.sql.DriverManager`` in the gateway's JVM."""
        # NOTE(review): calling open() while already open replaces the previous
        # connection without closing it.
        driver_manager = self.gateway.jvm.java.sql.DriverManager
        self.conn = driver_manager.getConnection(self.url, self.username, self.password)

    def close(self):
        """Close every tracked query, then the connection. Safe to call repeatedly."""
        conn, self.conn = self.conn, None
        if conn is None or conn.isClosed():
            return
        try:
            # Iterate a snapshot: the WeakSet may shrink if a query is collected meanwhile.
            for query in list(self.queries):
                query.close()
        finally:
            # Release the connection even if closing one of the queries failed.
            conn.close()

    def __enter__(self):
        self.open()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()

    def __del__(self):
        # __init__ may not have run (e.g. it was called with bad arguments),
        # so the attribute can be missing when the object is collected.
        if getattr(self, 'conn', None) is not None:
            self.close()
# Describes one result column: its lower-cased name and an optional function
# applied to each raw value read from that column (None means "use as-is").
Column = namedtuple('Column', 'name converter')
class JDBCQuery:
    """Iterable over the rows of an executed JDBC query.

    Each row is yielded as an ``OrderedDict`` keyed by the lower-cased column
    label. Values in columns whose lower-cased SQL type name is a key of
    :attr:`CONVERTERS` are passed through the matching converter. The
    statement and result set are closed when iteration finishes or is
    abandoned, and by :meth:`close`.
    """

    # Lower-cased JDBC column type name -> converter applied to the raw value.
    CONVERTERS = {
        'date': python_date
    }

    def __init__(self, statement: JavaObject, result_set: JavaObject):
        """
        :param statement: the executed statement; owned and closed by this object
        :param result_set: the statement's result set; owned and closed by this object
        """
        self.statement = statement
        self.result_set = result_set
        self.columns = []
        metadata = result_set.getMetaData()
        # JDBC column indices are 1-based.
        for index in range(1, metadata.getColumnCount() + 1):
            # getColumnLabel honours SQL "AS" aliases and otherwise equals
            # getColumnName, so unaliased keys are unchanged.
            column_name = metadata.getColumnLabel(index).lower()
            column_type = metadata.getColumnTypeName(index).lower()
            self.columns.append(Column(column_name, self.CONVERTERS.get(column_type)))

    def close(self):
        """Close the result set and the statement. Safe to call more than once."""
        # getattr: close() is also reached from __del__ on partially built objects.
        result_set, self.result_set = getattr(self, 'result_set', None), None
        statement, self.statement = getattr(self, 'statement', None), None
        try:
            if result_set is not None and not result_set.isClosed():
                result_set.close()
        finally:
            # Close the statement even if closing the result set failed.
            if statement is not None:
                statement.close()

    def __iter__(self):
        """Yield each remaining row as an ``OrderedDict``, then close the query.

        :raises ValueError: if the query has already been closed, for example
            by a previous complete iteration
        """
        if self.result_set is None:
            raise ValueError('cannot iterate a closed JDBCQuery')
        with closing(self):
            while self.result_set.next():
                row = OrderedDict()
                for index, column in enumerate(self.columns, 1):
                    value = self.result_set.getObject(index)
                    if column.converter is not None:
                        value = column.converter(value)
                    row[column.name] = value
                yield row

    def __del__(self):
        self.close()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement