- ____________________________________
- db_schema.xml
<?xml version="1.0"?>
<!--
  Database schema consumed by DBSchema.py. Each <table> becomes a
  CREATE TABLE whose first column is an implicit <name>ID SERIAL, followed by
  the <field> elements in document order. ref_target records the column a
  field refers to; it is parsed but not currently emitted as a REFERENCES
  constraint. The comment attributes are documentation only and are not read
  by the parser.
-->
<database>
    <table name="bootstrap" comment="This table is only intended to ever have one record">
        <field name="schemaVersion" type="REAL" />
        <field name="dbSchema" type="VARCHAR(1048576)" comment="up to 1048576 characters (about 1 MB of ASCII text)" />
    </table>
    <table name="Symbol">
        <field name="Type" type="SMALLINT" />
        <field name="SymbolSpace" type="INTEGER" />
    </table>
    <table name="PrimitiveSymbolToContent">
        <field name="SymbolID" type="INTEGER" ref_target="Symbol(SymbolID)" />
        <field name="Content" type="VARCHAR(1024)" />
    </table>
    <table name="SymbolComposition">
        <field name="CompositeSymbolID" type="INTEGER" ref_target="Symbol(SymbolID)" />
        <field name="ComponentID" type="INTEGER" ref_target="Symbol(SymbolID)" />
    </table>
    <table name="Memory">
        <field name="Starttime" type="timestamp" />
        <!-- A length of time rather than a point in time: INTERVAL, the same
             type MemoryFrame.TimeOffset uses. Databases created from the
             earlier definition have this column as timestamp. -->
        <field name="Duration" type="INTERVAL" />
    </table>
    <table name="MemoryFrame">
        <field name="MemoryID" type="INTEGER" ref_target="Memory(MemoryID)" />
        <field name="SequenceIndex" type="INTEGER" />
        <field name="TimeOffset" type="INTERVAL" />
    </table>
    <table name="MemorySymbol">
        <field name="MemoryFrameID" type="INTEGER" ref_target="MemoryFrame(MemoryFrameID)" />
        <field name="SymbolID" type="INTEGER" ref_target="Symbol(SymbolID)" />
        <field name="Activation" type="REAL" />
    </table>
    <table name="Recommendation">
        <field name="RecommenderID" type="INTEGER" ref_target="Symbol(SymbolID)" />
        <field name="RecommendeeID" type="INTEGER" ref_target="Symbol(SymbolID)" />
        <field name="Probability" type="REAL" />
    </table>
</database>
- ____________________________________
- DBSchema.py
- #!/usr/bin/env python3.1
- # Created: 9:30 PM 2/1/11
- #
- import xml.dom.minidom;
- import os
- import collections
- import sys
- import postgresql
- import utils
class DBSchema:
    """Parsed form of the XML schema plus generators for the SQL that builds it.

    ``tables`` is a list of ``{'table-name': str, 'fields': OrderedDict}``
    entries in document order. Each ``fields`` value maps a column name to a
    dict with a 'type' key and, when the XML has one, a 'ref_target' key.
    Use getInstance() to share a single parsed schema between modules.
    """

    # Shared instance handed out by getInstance(); created on first use.
    instance = None

    # Version number recorded for this schema in the bootstrap table.
    SCHEMA_VERSION = 5.0

    # File read by loadSchema() when no path is given (resolved against the
    # current working directory).
    DEFAULT_SCHEMA_PATH = 'db_schema.xml'

    @staticmethod
    def getInstance():
        """Return the shared DBSchema, constructing (and loading) it on first call."""
        if DBSchema.instance is None:
            DBSchema.instance = DBSchema()
        return DBSchema.instance

    def __init__(self):
        self.dom = None
        self.tables = []
        self.schema = ''
        self.loadSchema()

    def loadSchema(self, path=None):
        """Read the schema XML from ``path`` (default DEFAULT_SCHEMA_PATH) and parse it.

        The raw text is kept in ``self.schema`` so it can later be stored in
        the bootstrap table.
        """
        if path is None:
            path = self.DEFAULT_SCHEMA_PATH
        with open(path, 'r', encoding='utf-8') as f:
            # read() keeps the file's own line breaks; joining readlines()
            # with '\n' would double every newline in the stored text.
            self.schema = f.read()
        self.parseSchema(self.schema)

    def parseSchema(self, xmlStr):
        """Parse schema XML text into ``self.dom`` and ``self.tables``.

        Raises ValueError if the document has no <database> element, and
        KeyError if a <table> or <field> lacks a required attribute
        (name/type). On error the previously parsed schema is left intact.
        """
        dom = xml.dom.minidom.parseString(xmlStr)
        dbNodes = dom.getElementsByTagName("database")
        if not dbNodes:
            raise ValueError("schema XML has no <database> element")
        tables = []
        for tableNode in dbNodes[0].getElementsByTagName("table"):
            fields = collections.OrderedDict()
            for fieldNode in tableNode.getElementsByTagName("field"):
                attrs = fieldNode.attributes
                field = {'type': attrs['type'].value}
                if fieldNode.hasAttribute('ref_target'):
                    field['ref_target'] = attrs['ref_target'].value
                fields[attrs['name'].value] = field
            tables.append({'table-name': tableNode.attributes['name'].value,
                           'fields': fields})
        # Commit only after the whole document parsed successfully.
        self.dom = dom
        self.tables = tables
        self.schemaVersion = self.SCHEMA_VERSION

    def dropCmd(self):
        """Return one DROP TABLE statement naming every table in the schema."""
        return "DROP TABLE " + ", ".join(table['table-name'] for table in self.tables)

    def createCmd(self):
        """Return a transaction that creates every table in the schema.

        Each table starts with an implicit ``<table>ID SERIAL`` column,
        followed by its fields in document order.
        """
        # NOTE: ref_target is not emitted as a REFERENCES constraint. The
        # generated <table>ID columns are not declared PRIMARY KEY/UNIQUE,
        # which PostgreSQL requires of a referenced column.
        parts = ["BEGIN;\n"]
        for table in self.tables:
            tableName = table['table-name']
            columns = ["\t{0}ID SERIAL ".format(tableName)]
            columns.extend("\t{0} {1}".format(fieldName, field['type'])
                           for fieldName, field in table['fields'].items())
            parts.append("\nCREATE TABLE {0}\n(\n{1}\n);".format(tableName, ",\n".join(columns)))
        parts.append("\nCOMMIT;")
        return "".join(parts)

    def getFieldNamesFromTable(self, tableName):
        """Return ``[<tableName>ID, <field>, ...]`` in column order.

        Raises IndexError if the schema has no table called ``tableName``.
        """
        for table in self.tables:
            if table['table-name'] == tableName:
                return [tableName + 'ID'] + list(table['fields'])
        raise IndexError("no table named {0!r} in schema".format(tableName))

    def fillBootstrapTable(self):
        """Return ``(sql, params)`` that insert this schema into the bootstrap table.

        The version and schema text are passed as $1/$2 parameters instead of
        being formatted into the SQL: the XML contains quotes, so inlining it
        would yield invalid (and injectable) SQL. With a py-postgresql
        connection this can be run as ``db.prepare(sql)(*params)``.
        """
        sql = "INSERT INTO bootstrap (schemaVersion, dbSchema) VALUES ($1, $2);"
        return sql, (self.schemaVersion, self.schema)
def getInstance():
    """Module-level shortcut: return the process-wide DBSchema singleton."""
    shared = DBSchema.getInstance()
    return shared
if __name__ == '__main__':
    # Print the SQL that creates or drops every table in the schema.
    # DBSchema() already reads and parses the schema file, so no separate
    # loadSchema() call is needed.
    dbs = DBSchema()
    args = sys.argv[1:]
    if '--create-cmd' in args:
        print(dbs.createCmd())
    elif '--drop-cmd' in args:
        print(dbs.dropCmd())
    else:
        print("usage: {0} --create-cmd | --drop-cmd".format(sys.argv[0]), file=sys.stderr)
        sys.exit(2)
- ____________________________________
- DBInterface.py
- #!/usr/bin/env python3.1
- #Created: 6:59 PM, 2/2/11
- import postgresql
- import DBSchema
- import utils
- import misc
class DBInterface:
    """Base class for table accessors: connection settings plus a lazily opened connection.

    Subclasses are expected to set ``self.fieldNames`` before calling
    ``createAccessors``.
    """

    def __init__(self):
        self.db_connection = None
        # NOTE(review): development credentials are hard-coded here; callers
        # can override them with setParams(). Consider moving them to config.
        self.setParams('alpha', 'josh', 'secret', 'localhost', '5432')

    def setParams(self, dbname, username, password, server, port):
        """Store the parameters connect() uses to build the pq:// URL."""
        self.dbname = dbname
        self.username = username
        self.password = password
        self.server = server
        self.port = port

    def getFields(self):
        """Re-parse the schema XML stored in the bootstrap table.

        Every row's dbSchema value is handed to the shared DBSchema's
        parseSchema, so if the table holds several rows the last one wins.
        NULL values are skipped.
        """
        self.connect()
        # Select the column by name so the result does not depend on the
        # bootstrap table's column order or column count.
        result = self.query("SELECT dbSchema FROM bootstrap;")
        print("getting fields")
        for row in result:
            value = row[0]
            if value is not None:
                DBSchema.getInstance().parseSchema(str(value))

    def connect(self):
        """Open the connection on first use; later calls reuse it."""
        if self.db_connection is None:
            # NOTE(review): username/password are inserted unescaped, so
            # characters such as '@', ':' or '/' in them would corrupt the URL.
            self.db_connection = postgresql.open('pq://{0}:{1}@{2}:{3}/{4}'.format(
                self.username, self.password, self.server, self.port, self.dbname))

    def query(self, queryString):
        """Prepare ``queryString`` on the open connection and return the statement.

        The statement is only prepared here; callers execute it by iterating
        it or calling its methods. Raises RuntimeError (an Exception
        subclass) if connect() has not been called.
        """
        if self.db_connection is None:
            raise RuntimeError("Cannot execute query; no database connection")
        return self.db_connection.prepare(queryString)

    def createAccessors(self):
        """Create one attribute per name in ``self.fieldNames``, initialised to None."""
        for fieldName in self.fieldNames:
            setattr(self, fieldName, None)
class DBEntryReader(DBInterface):
    """Abstract base class for table entry readers; subclasses will be auto-generated from the database schema.

    On construction every column of ``table`` gets an attribute (initially
    None) and a ``FilterBy<Column>`` method. ``Execute`` runs a SELECT
    restricted by the accumulated filters, and each ``Read`` copies the next
    row's values onto the column attributes.
    """

    def __init__(self, table):
        DBInterface.__init__(self)
        self.table = table
        # Column name -> list of accepted values; columns are combined with AND.
        self.filterList = {}
        self.results = None
        # Parameter values matching the $n placeholders of the last Execute().
        self.params = []
        self.recordGenerator = None
        self.fieldNames = DBSchema.getInstance().getFieldNamesFromTable(self.table)
        self.createAccessors()
        self.createFilterFunctions()

    def createFilterFunctions(self):
        """Attach a ``FilterBy<field>`` method for each column.

        ``FilterByX(val)`` restricts results to rows whose X equals ``val``,
        or is any of the values when ``val`` is a list/tuple/set. Calling it
        again for the same column replaces the earlier filter.
        """
        for fieldName in self.fieldNames:
            setattr(self, 'FilterBy' + fieldName, self._makeFilter(fieldName))

    def _makeFilter(self, fname):
        """Return a function that records the accepted values for column ``fname``."""
        def setFilter(val):
            if isinstance(val, (list, tuple, set, frozenset)):
                self.filterList[fname] = list(val)
            else:
                self.filterList[fname] = [val]
        return setFilter

    def Execute(self):
        """Prepare the filtered SELECT for this table and reset the row cursor.

        Filter values are sent as $1, $2, ... parameters instead of being
        pasted into the SQL text, so strings and other non-numeric values are
        handled by the driver rather than producing invalid SQL. Column names
        (the filterList keys) are still written into the SQL; they come from
        the FilterBy* methods, i.e. from the schema.
        """
        conditions = []
        params = []
        for col, valList in self.filterList.items():
            if not valList:
                # "col in ()" is a syntax error; an empty list matches nothing.
                conditions.append('FALSE')
                continue
            placeholders = []
            for val in valList:
                params.append(val)
                placeholders.append('$' + str(len(params)))
            conditions.append('{0} in ({1})'.format(col, ', '.join(placeholders)))
        queryString = "SELECT * FROM %s" % self.table
        if conditions:
            queryString += " WHERE " + " AND ".join(conditions)
        queryString += ';'
        self.connect()
        print("queryString", queryString)
        if self.recordGenerator is not None:
            self.recordGenerator.close()
        self.results = self.query(queryString)
        self.params = params
        self.recordGenerator = self.RecordGenerator()

    def RecordGenerator(self):
        """Yield True after loading each result row onto the field attributes, then False forever."""
        for row in self.results.rows(*self.params):
            for fieldName, value in zip(self.fieldNames, row):
                setattr(self, fieldName, value)
            yield True
        while True:
            yield False

    def Read(self):
        """Advance to the next row; return True if one was loaded, False once exhausted.

        Raises RuntimeError if Execute() has not been called.
        """
        if self.recordGenerator is None:
            raise RuntimeError("Read() called before Execute()")
        return next(self.recordGenerator)

    def Close(self):
        """Release the row cursor, prepared statement and connection; safe to call repeatedly."""
        if self.recordGenerator is not None:
            self.recordGenerator.close()
            self.recordGenerator = None
        if self.results is not None:
            self.results.close()
            self.results = None
        if self.db_connection is not None:
            self.db_connection.close()
            self.db_connection = None
class DBEntryCreator(DBInterface):
    """Abstract base class for table entry creators; subclasses will be auto-generated from the database schema.

    Every column of ``table`` gets an attribute initialised to None. Set the
    ones you want and call ``Create`` to INSERT a row containing exactly the
    columns that are not None.
    """

    def __init__(self, table):
        DBInterface.__init__(self)
        self.table = table
        self.fieldNames = DBSchema.getInstance().getFieldNamesFromTable(self.table)
        self.createAccessors()

    def Create(self):
        """INSERT one row built from the column attributes that are not None.

        Columns left as None are omitted so the database supplies their
        defaults (e.g. the SERIAL <table>ID column). If every column is None,
        ``INSERT ... DEFAULT VALUES`` is used, because an empty column list
        is invalid SQL. Values are passed as $n parameters. Returns the
        statement's execution result.
        """
        fields = []
        placeholders = []
        values = []
        for fieldName in self.fieldNames:
            value = getattr(self, fieldName)
            if value is None:
                print("continuing on field", fieldName)
                continue
            print("fieldName {0} = {1}".format(fieldName, value))
            fields.append(fieldName)
            values.append(value)
            placeholders.append('$' + str(len(values)))
        print("fields", fields)
        print("vars", placeholders)
        print("values", values)
        if fields:
            queryString = "INSERT INTO {0} ({1}) VALUES ({2});".format(
                self.table, ", ".join(fields), ", ".join(placeholders))
        else:
            queryString = "INSERT INTO {0} DEFAULT VALUES;".format(self.table)
        self.connect()
        print("query", queryString)
        prep = self.query(queryString)
        if values:
            results = prep.load_rows([tuple(values)])
        else:
            results = prep()
        print("Create result", results)
        return results
if __name__ == '__main__':
    # Demo: list Symbol rows whose Type is 0 or 2 and whose SymbolID is 1, 2 or 3.
    reader = DBEntryReader('Symbol')
    reader.FilterByType([0, 2])
    reader.FilterBySymbolID([1, 2, 3])
    reader.Execute()
    while reader.Read():
        for column in ('SymbolID', 'Type', 'SymbolSpace'):
            print(column, getattr(reader, column))
    reader.Close()