Advertisement
Guest User

Untitled

a guest
Mar 31st, 2020
121
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 28.86 KB | None | 0 0
  1. """
  2. Name: Josh Nahum
  3. Time To Completion: 3 hours
  4. Comments:
  5.  
  6. Sources:
  7. """
  8. import string
  9. from operator import itemgetter
  10. from collections import namedtuple
  11. import itertools
  12. from copy import deepcopy
  13.  
# Process-wide registry of Database objects keyed by the filename passed to
# Connection; connections opened with the same filename share one Database.
_ALL_DATABASES = {}


# A parsed "WHERE <column> <operator> <constant>" condition.
WhereClause = namedtuple("WhereClause", ["col_name", "operator", "constant"])
# One "<column> = <constant>" assignment from an UPDATE ... SET list.
UpdateClause = namedtuple("UpdateClause", ["col_name", "constant"])
# The FROM part of a SELECT. The last three fields are None when the
# statement has no LEFT OUTER JOIN.
FromJoinClause = namedtuple("FromJoinClause", ["left_table_name",
                                               "right_table_name",
                                               "left_join_col_name",
                                               "right_join_col_name"])
  23.  
  24.  
  25. class Connection(object):
  26.  
  27.     def __init__(self, filename, timeout=0.1, isolation_level=None):
  28.         """
  29.        Takes a filename, but doesn't do anything with it.
  30.        (The filename will be used in a future project).
  31.        """
  32.         if filename in _ALL_DATABASES:
  33.             self.database = _ALL_DATABASES[filename]
  34.         else:
  35.             self.database = Database(filename)
  36.             _ALL_DATABASES[filename] = self.database
  37.         self.autocommit  = True
  38.         self.mode = 'DEFERRED'
  39.         self.localdb = deepcopy(self.database)
  40.         # print(id(self))
  41.  
  42.     def execute(self, statement):
  43.         """
  44.        Takes a SQL statement.
  45.        Returns a list of tuples (empty unless select statement
  46.        with rows to return).
  47.        """
  48.         def create_table(tokens):
  49.             """
  50.            Determines the name and column information from tokens add
  51.            has the database create a new table within itself.
  52.            """
  53.             pop_and_check(tokens, "CREATE")
  54.             pop_and_check(tokens, "TABLE")
  55.             assert self.autocommit == True
  56.             if_flag = False
  57.             if_or_table_name = tokens.pop(0)
  58.             if if_or_table_name == 'IF':
  59.                 pop_and_check(tokens, "NOT")
  60.                 pop_and_check(tokens, "EXISTS")
  61.                 if_flag = True
  62.                 table_name = tokens.pop(0)
  63.             else:
  64.                 table_name = if_or_table_name
  65.             if table_name in self.database.tables:
  66.                 if if_flag:
  67.                     return
  68.                 else:
  69.                     raise Exception("Table already exists")
  70.                     return
  71.             pop_and_check(tokens, "(")
  72.             column_name_type_pairs = []
  73.             while True:
  74.                 column_name = tokens.pop(0)
  75.                 qual_col_name = QualifiedColumnName(column_name, table_name)
  76.                 column_type = tokens.pop(0)
  77.                 assert column_type in {"TEXT", "INTEGER", "REAL"}
  78.                 column_name_type_pairs.append((qual_col_name, column_type))
  79.                 comma_or_close = tokens.pop(0)
  80.                 if comma_or_close == ")":
  81.                     break
  82.                 assert comma_or_close == ','
  83.             self.database.create_new_table(table_name, column_name_type_pairs)
  84.            
  85.        
  86.         def drop_table(tokens):
  87.             pop_and_check(tokens, "DROP")
  88.             pop_and_check(tokens, "TABLE")
  89.             assert self.autocommit == True
  90.             if_flag = False
  91.             if_or_table_name = tokens.pop(0)
  92.             if if_or_table_name == 'IF':
  93.                 pop_and_check(tokens, "EXISTS")
  94.                 if_flag = True
  95.                 table_name = tokens.pop(0)
  96.             else:
  97.                 table_name = if_or_table_name
  98.             if table_name not in self.database.tables:
  99.                 if if_flag:
  100.                     return
  101.                 else:
  102.                     raise Exception("Table not exists")
  103.                     return
  104.             else:
  105.                 self.database.lock.acquire_lock('EXCLUSIVE',id(self))
  106.                 self.database.tables.pop(table_name)
  107.                 self.database.lock.release_lock(id(self))
  108.  
  109.  
  110.         def begin(tokens):
  111.             pop_and_check(tokens, "BEGIN")
  112.             mode_or_transaction = tokens.pop(0)
  113.             if mode_or_transaction == 'DEFERRED':
  114.                 self.mode = 'DEFERRED'
  115.                 pop_and_check(tokens, "TRANSACTION")
  116.             elif mode_or_transaction == 'IMMEDIATE':
  117.                 self.mode = 'IMMEDIATE'
  118.                 self.database.lock.acquire_lock('RESERVED', id(self))
  119.                 pop_and_check(tokens, "TRANSACTION")
  120.             elif mode_or_transaction == 'EXCLUSIVE':
  121.                 self.mode = 'EXCLUSIVE'
  122.                 self.database.lock.acquire_lock('EXCLUSIVE', id(self))
  123.                 pop_and_check(tokens, "TRANSACTION")
  124.             if self.autocommit == True:
  125.                 self.localdb = deepcopy(self.database)
  126.                 self.autocommit  = False
  127.             else:
  128.                 raise Exception("Transaction already begin")
  129.            
  130.            
  131.         def commit(tokens):
  132.             # print(id(self))
  133.             # print(self.database.lock.locking_state)
  134.             pop_and_check(tokens, "COMMIT")
  135.             pop_and_check(tokens, "TRANSACTION")
  136.             if self.autocommit == False:
  137.                 if self.database.lock.locking_state[2] == id(self):
  138.                     self.database.lock.acquire_lock('EXCLUSIVE', id(self))
  139.                 self.database.tables = deepcopy(self.localdb.tables)
  140.                 self.autocommit = True
  141.                
  142.                 self.database.lock.release_lock(id(self))
  143.                 # print(self.database.lock.locking_state)
  144.             else:
  145.                 raise Exception("Transaction can not commit in autocommit mode")
  146.        
  147.         def rollback(tokens):
  148.             pop_and_check(tokens, "ROLLBACK")
  149.             pop_and_check(tokens, "TRANSACTION")
  150.             if self.autocommit == False:
  151.                 self.database.lock.release_lock(id(self))
  152.                 self.autocommit = True
  153.             else:
  154.                 raise Exception("Transaction can not rollback in autocommit mode")
  155.        
  156.  
  157.        
  158.         def insert(tokens):
  159.             """
  160.            Determines the table name and row values to add.
  161.            """
  162.             # print(self.database.lock.locking_state)
  163.             if self.mode == 'DEFERRED' or self.autocommit:
  164.                 self.database.lock.acquire_lock('RESERVED', id(self))
  165.  
  166.             def get_comma_seperated_contents(tokens):
  167.                 contents = []
  168.                 pop_and_check(tokens, "(")
  169.                 while True:
  170.                     item = tokens.pop(0)
  171.                     contents.append(item)
  172.                     comma_or_close = tokens.pop(0)
  173.                     if comma_or_close == ")":
  174.                         return contents
  175.                     assert comma_or_close == ',', comma_or_close
  176.  
  177.             pop_and_check(tokens, "INSERT")
  178.             pop_and_check(tokens, "INTO")
  179.             table_name = tokens.pop(0)
  180.             if tokens[0] == "(":
  181.                 col_names = get_comma_seperated_contents(tokens)
  182.                 qual_col_names = [QualifiedColumnName(col_name, table_name)
  183.                                   for col_name in col_names]
  184.             else:
  185.                 qual_col_names = None
  186.             pop_and_check(tokens, "VALUES")
  187.             while tokens:
  188.                 row_contents = get_comma_seperated_contents(tokens)
  189.                 if qual_col_names:
  190.                     assert len(row_contents) == len(qual_col_names)
  191.                 if self.autocommit:
  192.                     self.database.insert_into(table_name,
  193.                                             row_contents,
  194.                                             qual_col_names=qual_col_names)
  195.                     self.database.lock.release_lock(id(self))
  196.                 else:
  197.                     self.localdb.insert_into(table_name,
  198.                                             row_contents,
  199.                                             qual_col_names=qual_col_names)
  200.                 if tokens:
  201.                     pop_and_check(tokens, ",")
  202.  
  203.         def get_qualified_column_name(tokens):
  204.             """
  205.            Returns comsumes tokens to  generate tuples to create
  206.            a QualifiedColumnName.
  207.            """
  208.             possible_col_name = tokens.pop(0)
  209.             if tokens and tokens[0] == '.':
  210.                 tokens.pop(0)
  211.                 actual_col_name = tokens.pop(0)
  212.                 table_name = possible_col_name
  213.                 return QualifiedColumnName(actual_col_name, table_name)
  214.             return QualifiedColumnName(possible_col_name)
  215.  
  216.         def update(tokens):
  217.             if self.mode == 'DEFERRED' or self.autocommit:
  218.                 self.database.lock.acquire_lock('RESERVED', id(self))
  219.             pop_and_check(tokens, "UPDATE")
  220.             table_name = tokens.pop(0)
  221.             pop_and_check(tokens, "SET")
  222.             update_clauses = []
  223.             while tokens:
  224.                 qual_name = get_qualified_column_name(tokens)
  225.                 if not qual_name.table_name:
  226.                     qual_name.table_name = table_name
  227.                 pop_and_check(tokens, '=')
  228.                 constant = tokens.pop(0)
  229.                 update_clause = UpdateClause(qual_name, constant)
  230.                 update_clauses.append(update_clause)
  231.                 if tokens:
  232.                     if tokens[0] == ',':
  233.                         tokens.pop(0)
  234.                         continue
  235.                     elif tokens[0] == "WHERE":
  236.                         break
  237.  
  238.             where_clause = get_where_clause(tokens, table_name)
  239.             if self.autocommit:
  240.                 self.database.update(table_name, update_clauses, where_clause)
  241.                 self.database.lock.release_lock(id(self))
  242.             else:
  243.                 self.localdb.update(table_name, update_clauses, where_clause)
  244.  
  245.         def delete(tokens):
  246.             if self.mode == 'DEFERRED' or self.autocommit:
  247.                 self.database.lock.acquire_lock('RESERVED', id(self))
  248.             pop_and_check(tokens, "DELETE")
  249.             pop_and_check(tokens, "FROM")
  250.             table_name = tokens.pop(0)
  251.             where_clause = get_where_clause(tokens, table_name)
  252.             if self.autocommit:
  253.                 self.database.delete(table_name, where_clause)
  254.                 self.database.lock.release_lock(id(self))
  255.             else:
  256.                 self.localdb.delete(table_name, where_clause)
  257.  
  258.         def get_where_clause(tokens, table_name):
  259.             if not tokens or tokens[0] != "WHERE":
  260.                 return None
  261.             tokens.pop(0)
  262.             qual_col_name = get_qualified_column_name(tokens)
  263.             if not qual_col_name.table_name:
  264.                 qual_col_name.table_name = table_name
  265.             operators = {">", "<", "=", "!=", "IS"}
  266.             found_operator = tokens.pop(0)
  267.             assert found_operator in operators
  268.             if tokens[0] == "NOT":
  269.                 tokens.pop(0)
  270.                 found_operator += " NOT"
  271.             constant = tokens.pop(0)
  272.             if constant is None:
  273.                 assert found_operator in {"IS", "IS NOT"}
  274.             if found_operator in {"IS", "IS NOT"}:
  275.                 assert constant is None
  276.             return WhereClause(qual_col_name, found_operator, constant)
  277.  
  278.         def select(tokens):
  279.             """
  280.            Determines the table name, output_columns, and order_by_columns.
  281.            """
  282.             if self.mode == 'DEFERRED' or self.autocommit:
  283.                 # print(id(self))
  284.                 # print(self.database.lock.locking_state)
  285.                 self.database.lock.acquire_lock('SHARED', id(self))
  286.                
  287.             def get_from_join_clause(tokens):
  288.                 left_table_name = tokens.pop(0)
  289.                 if tokens[0] != "LEFT":
  290.                     return FromJoinClause(left_table_name, None, None, None)
  291.                 pop_and_check(tokens, "LEFT")
  292.                 pop_and_check(tokens, "OUTER")
  293.                 pop_and_check(tokens, "JOIN")
  294.                 right_table_name = tokens.pop(0)
  295.                 pop_and_check(tokens, "ON")
  296.                 left_col_name = get_qualified_column_name(tokens)
  297.                 pop_and_check(tokens, "=")
  298.                 right_col_name = get_qualified_column_name(tokens)
  299.                 return FromJoinClause(left_table_name,
  300.                                       right_table_name,
  301.                                       left_col_name,
  302.                                       right_col_name)
  303.  
  304.             pop_and_check(tokens, "SELECT")
  305.  
  306.             is_distinct = tokens[0] == "DISTINCT"
  307.             if is_distinct:
  308.                 tokens.pop(0)
  309.  
  310.             output_columns = []
  311.             while True:
  312.                 qual_col_name = get_qualified_column_name(tokens)
  313.                 output_columns.append(qual_col_name)
  314.                 comma_or_from = tokens.pop(0)
  315.                 if comma_or_from == "FROM":
  316.                     break
  317.                 assert comma_or_from == ','
  318.  
  319.             # FROM or JOIN
  320.             from_join_clause = get_from_join_clause(tokens)
  321.             table_name = from_join_clause.left_table_name
  322.  
  323.             # WHERE
  324.             where_clause = get_where_clause(tokens, table_name)
  325.  
  326.             # ORDER BY
  327.             pop_and_check(tokens, "ORDER")
  328.             pop_and_check(tokens, "BY")
  329.             order_by_columns = []
  330.             while True:
  331.                 qual_col_name = get_qualified_column_name(tokens)
  332.                 order_by_columns.append(qual_col_name)
  333.                 if not tokens:
  334.                     break
  335.                 pop_and_check(tokens, ",")
  336.             if self.autocommit:
  337.                 self.database.lock.release_lock(id(self))
  338.                 return self.database.select(
  339.                     output_columns,
  340.                     order_by_columns,
  341.                     from_join_clause=from_join_clause,
  342.                     where_clause=where_clause,
  343.                     is_distinct=is_distinct)
  344.             else:
  345.                 return self.localdb.select(
  346.                     output_columns,
  347.                     order_by_columns,
  348.                     from_join_clause=from_join_clause,
  349.                     where_clause=where_clause,
  350.                     is_distinct=is_distinct)
  351.  
  352.         tokens = tokenize(statement)
  353.         last_semicolon = tokens.pop()
  354.         assert last_semicolon == ";"
  355.  
  356.         if tokens[0] == "CREATE":
  357.             create_table(tokens)
  358.             return []
  359.         elif tokens[0] == "DROP":
  360.             drop_table(tokens)
  361.             return []
  362.         elif tokens[0] == "INSERT":
  363.             insert(tokens)
  364.             return []
  365.         elif tokens[0] == "UPDATE":
  366.             update(tokens)
  367.             return []
  368.         elif tokens[0] == "DELETE":
  369.             delete(tokens)
  370.             return []
  371.         elif tokens[0] == "SELECT":
  372.             return select(tokens)
  373.         elif tokens[0] == "BEGIN":
  374.             begin(tokens)
  375.             return []
  376.         elif tokens[0] == "COMMIT":
  377.             commit(tokens)
  378.             return[]
  379.         elif tokens[0] == "ROLLBACK":
  380.             rollback(tokens)
  381.             return []
  382.         else:
  383.             raise AssertionError(
  384.                 "Unexpected first word in statements: " + tokens[0])
  385.  
  386.     def close(self):
  387.         """
  388.        Empty method that will be used in future projects
  389.        """
  390.         pass
  391.  
  392.  
  393. def connect(filename, timeout=0.1, isolation_level=None):
  394.     """
  395.    Creates a Connection object with the given filename
  396.    """
  397.     return Connection(filename, timeout, isolation_level)
  398.  
  399.  
  400. class QualifiedColumnName:
  401.  
  402.     def __init__(self, col_name, table_name=None):
  403.         self.col_name = col_name
  404.         self.table_name = table_name
  405.  
  406.     def __str__(self):
  407.         return "QualifiedName({}.{})".format(
  408.             self.table_name, self.col_name)
  409.  
  410.     def __eq__(self, other):
  411.         same_col = self.col_name == other.col_name
  412.         if not same_col:
  413.             return False
  414.         both_have_tables = (self.table_name is not None and
  415.                             other.col_name is not None)
  416.         if not both_have_tables:
  417.             return True
  418.         return self.table_name == other.table_name
  419.  
  420.     def __ne__(self, other):
  421.         return not (self == other)
  422.  
  423.     def __hash__(self):
  424.         return hash((self.col_name, self.table_name))
  425.  
  426.     def __repr__(self):
  427.         return str(self)
  428.  
  429.  
  430. class Database:
  431.  
  432.     def __init__(self, filename):
  433.         self.filename = filename
  434.         self.tables = {}
  435.         self.lock = LockTable()
  436.  
  437.     def create_new_table(self, table_name, column_name_type_pairs):
  438.         assert table_name not in self.tables
  439.         self.tables[table_name] = Table(table_name, column_name_type_pairs)
  440.         return []
  441.  
  442.     def insert_into(self, table_name, row_contents, qual_col_names=None):
  443.         assert table_name in self.tables
  444.         table = self.tables[table_name]
  445.         table.insert_new_row(row_contents, qual_col_names=qual_col_names)
  446.         return []
  447.  
  448.     def update(self, table_name, update_clauses, where_clause):
  449.         assert table_name in self.tables
  450.         table = self.tables[table_name]
  451.         table.update(update_clauses, where_clause)
  452.  
  453.     def delete(self, table_name, where_clause):
  454.         assert table_name in self.tables
  455.         table = self.tables[table_name]
  456.         table.delete(where_clause)
  457.  
  458.     def select(self, output_columns, order_by_columns,
  459.                from_join_clause,
  460.                where_clause=None, is_distinct=False):
  461.         assert from_join_clause.left_table_name in self.tables
  462.         if from_join_clause.right_table_name:
  463.             assert from_join_clause.right_table_name in self.tables
  464.             left_table = self.tables[from_join_clause.left_table_name]
  465.             right_table = self.tables[from_join_clause.right_table_name]
  466.             all_columns = itertools.chain(
  467.                 zip(left_table.column_names, left_table.column_types),
  468.                 zip(right_table.column_names, right_table.column_types))
  469.             left_col = from_join_clause.left_join_col_name
  470.             right_col = from_join_clause.right_join_col_name
  471.             join_table = Table("", all_columns)
  472.             combined_rows = []
  473.             for left_row in left_table.rows:
  474.                 left_value = left_row[left_col]
  475.                 found_match = False
  476.                 for right_row in right_table.rows:
  477.                     right_value = right_row[right_col]
  478.                     if left_value is None:
  479.                         break
  480.                     if right_value is None:
  481.                         continue
  482.                     if left_row[left_col] == right_row[right_col]:
  483.                         new_row = dict(left_row)
  484.                         new_row.update(right_row)
  485.                         combined_rows.append(new_row)
  486.                         found_match = True
  487.                         continue
  488.                 if left_value is None or not found_match:
  489.                     new_row = dict(left_row)
  490.                     new_row.update(zip(right_row.keys(),
  491.                                        itertools.repeat(None)))
  492.                     combined_rows.append(new_row)
  493.  
  494.             join_table.rows = combined_rows
  495.             table = join_table
  496.         else:
  497.             table = self.tables[from_join_clause.left_table_name]
  498.         return table.select_rows(output_columns, order_by_columns,
  499.                                  where_clause=where_clause,
  500.                                  is_distinct=is_distinct)
  501.    
  502.  
  503.  
  504. class Table:
  505.  
  506.     def __init__(self, name, column_name_type_pairs):
  507.         self.name = name
  508.         self.column_names, self.column_types = zip(*column_name_type_pairs)
  509.         self.rows = []
  510.  
  511.     def insert_new_row(self, row_contents, qual_col_names=None):
  512.         if not qual_col_names:
  513.             qual_col_names = self.column_names
  514.         assert len(qual_col_names) == len(row_contents)
  515.         row = dict(zip(qual_col_names, row_contents))
  516.         for null_default_col in set(self.column_names) - set(qual_col_names):
  517.             row[null_default_col] = None
  518.         self.rows.append(row)
  519.  
  520.     def update(self, update_clauses, where_clause):
  521.         for row in self.rows:
  522.             if self._row_match_where(row, where_clause):
  523.                 for update_clause in update_clauses:
  524.                     row[update_clause.col_name] = update_clause.constant
  525.  
  526.     def delete(self, where_clause):
  527.         self.rows = [row for row in self.rows
  528.                      if not self._row_match_where(row, where_clause)]
  529.  
  530.     def _row_match_where(self, row, where_clause):
  531.         if not where_clause:
  532.             return True
  533.         new_rows = []
  534.         value = row[where_clause.col_name]
  535.  
  536.         op = where_clause.operator
  537.         cons = where_clause.constant
  538.         if ((op == "IS NOT" and (value is not cons)) or
  539.                 (op == "IS" and value is cons)):
  540.             return True
  541.  
  542.         if value is None:
  543.             return False
  544.  
  545.         if ((op == ">" and value > cons) or
  546.             (op == "<" and value < cons) or
  547.             (op == "=" and value == cons) or
  548.                 (op == "!=" and value != cons)):
  549.             return True
  550.         return False
  551.  
  552.     def select_rows(self, output_columns, order_by_columns,
  553.                     where_clause=None, is_distinct=False):
  554.         def expand_star_column(output_columns):
  555.             new_output_columns = []
  556.             for col in output_columns:
  557.                 if col.col_name == "*":
  558.                     new_output_columns.extend(self.column_names)
  559.                 else:
  560.                     new_output_columns.append(col)
  561.             return new_output_columns
  562.  
  563.         def check_columns_exist(columns):
  564.             assert all(col in self.column_names
  565.                        for col in columns)
  566.  
  567.         def ensure_fully_qualified(columns):
  568.             for col in columns:
  569.                 if col.table_name is None:
  570.                     col.table_name = self.name
  571.  
  572.         def sort_rows(rows, order_by_columns):
  573.             return sorted(rows, key=itemgetter(*order_by_columns))
  574.  
  575.         def generate_tuples(rows, output_columns):
  576.             for row in rows:
  577.                 yield tuple(row[col] for col in output_columns)
  578.  
  579.         def remove_duplicates(tuples):
  580.             seen = set()
  581.             uniques = []
  582.             for row in tuples:
  583.                 if row in seen:
  584.                     continue
  585.                 seen.add(row)
  586.                 uniques.append(row)
  587.             return uniques
  588.  
  589.         expanded_output_columns = expand_star_column(output_columns)
  590.  
  591.         check_columns_exist(expanded_output_columns)
  592.         ensure_fully_qualified(expanded_output_columns)
  593.         check_columns_exist(order_by_columns)
  594.         ensure_fully_qualified(order_by_columns)
  595.  
  596.         filtered_rows = [row for row in self.rows
  597.                          if self._row_match_where(row, where_clause)]
  598.         sorted_rows = sort_rows(filtered_rows, order_by_columns)
  599.  
  600.         list_of_tuples = generate_tuples(sorted_rows, expanded_output_columns)
  601.         if is_distinct:
  602.             return remove_duplicates(list_of_tuples)
  603.         return list_of_tuples
  604.  
  605. class LockTable:
  606.    
  607.     def __init__(self):
  608.         # self.LOCK_TO_LEVEL = {'UNLOCKED':0, 'SHARED':1, 'RESERVED':2, 'EXCLUSIVE':3}
  609.         self.locking_state = {0:None, 1:set(), 2:None, 3:None}
  610.         self.current_lock = 0
  611.  
  612.     def acquire_lock(self, lock_type, transaction_id):
  613.         if lock_type == 'EXCLUSIVE':
  614.             if self.current_lock == 3:
  615.                 if self.locking_state[3] != transaction_id:
  616.                     raise Exception("Can not acquire E-lock")
  617.             elif self.current_lock == 2:
  618.                 if self.locking_state[2] != transaction_id:
  619.                     raise Exception("Can not acquire E-lock")
  620.                 else:
  621.                     if self.locking_state[1]:
  622.                         raise Exception("Can not acquire E-lock")
  623.                     else:
  624.                         self.locking_state[2] = None
  625.                         self.locking_state[3] = transaction_id
  626.                         self.current_lock = 3
  627.             elif self.current_lock == 1:
  628.                 self.locking_state[1].discard(transaction_id)
  629.                 if self.locking_state[1]:
  630.                     raise Exception("Can not acquire E-lock")
  631.                 else:
  632.                     self.locking_state[3] = transaction_id
  633.                     self.current_lock = 3
  634.             else:
  635.                 self.locking_state[3] = transaction_id
  636.                 self.current_lock = 3
  637.         elif lock_type == 'RESERVED':
  638.             if self.current_lock == 3:
  639.                 if self.locking_state[3] != transaction_id:
  640.                     raise Exception("Can not acquire R-lock")
  641.             elif self.current_lock == 2:
  642.                 if self.locking_state[2] != transaction_id:
  643.                     raise Exception("Can not acquire R-lock")
  644.             elif self.current_lock == 1:
  645.                 self.locking_state[1].discard(transaction_id)
  646.                 self.locking_state[2] = transaction_id
  647.                 self.current_lock = 2
  648.             else:
  649.                 self.locking_state[2] = transaction_id
  650.                 self.current_lock = 2
  651.         elif lock_type == 'SHARED':
  652.             if self.current_lock == 3:
  653.                 if self.locking_state[3] != transaction_id:
  654.                     raise Exception("Can not acquire S-lock")
  655.             elif self.current_lock == 2:
  656.                 if self.locking_state[2] != transaction_id:
  657.                     self.locking_state[1].add(transaction_id)
  658.             elif self.current_lock == 1:
  659.                 self.locking_state[1].add(transaction_id)
  660.             else:
  661.                 self.locking_state[1].add(transaction_id)
  662.                 self.current_lock = 1
  663.    
  664.     def release_lock(self, transaction_id):
  665.         if self.locking_state[3] == transaction_id:
  666.             self.locking_state[3] = None
  667.         elif self.locking_state[2] == transaction_id:
  668.             self.locking_state[2] = None
  669.         self.locking_state[1].discard(transaction_id)
  670.         self.current_lock = 0
  671.         for i in range(3, 0, -1):
  672.             if self.locking_state[i]:
  673.                 self.current_lock = i
  674.                 break
  675.  
  676.  
  677.  
  678. def pop_and_check(tokens, same_as):
  679.     item = tokens.pop(0)
  680.     assert item == same_as, "{} != {}".format(item, same_as)
  681.  
  682.  
  683. def collect_characters(query, allowed_characters):
  684.     letters = []
  685.     for letter in query:
  686.         if letter not in allowed_characters:
  687.             break
  688.         letters.append(letter)
  689.     return "".join(letters)
  690.  
  691.  
  692. def remove_leading_whitespace(query, tokens):
  693.     whitespace = collect_characters(query, string.whitespace)
  694.     return query[len(whitespace):]
  695.  
  696.  
  697. def remove_word(query, tokens):
  698.     word = collect_characters(query,
  699.                               string.ascii_letters + "_" + string.digits)
  700.     if word == "NULL":
  701.         tokens.append(None)
  702.     else:
  703.         tokens.append(word)
  704.     return query[len(word):]
  705.  
  706.  
  707. def remove_text(query, tokens):
  708.     if (query[0] == "'"):
  709.         delimiter = "'"
  710.     else:
  711.         delimiter = '"'
  712.     query = query[1:]
  713.     end_quote_index = query.find(delimiter)
  714.     while query[end_quote_index + 1] == delimiter:
  715.         # Remove Escaped Quote
  716.         query = query[:end_quote_index] + query[end_quote_index + 1:]
  717.         end_quote_index = query.find(delimiter, end_quote_index + 1)
  718.     text = query[:end_quote_index]
  719.     tokens.append(text)
  720.     query = query[end_quote_index + 1:]
  721.     return query
  722.  
  723.  
  724. def remove_integer(query, tokens):
  725.     int_str = collect_characters(query, string.digits)
  726.     tokens.append(int_str)
  727.     return query[len(int_str):]
  728.  
  729.  
  730. def remove_number(query, tokens):
  731.     query = remove_integer(query, tokens)
  732.     if query[0] == ".":
  733.         whole_str = tokens.pop()
  734.         query = query[1:]
  735.         query = remove_integer(query, tokens)
  736.         frac_str = tokens.pop()
  737.         float_str = whole_str + "." + frac_str
  738.         tokens.append(float(float_str))
  739.     else:
  740.         int_str = tokens.pop()
  741.         tokens.append(int(int_str))
  742.     return query
  743.  
  744.  
  745. def tokenize(query):
  746.     tokens = []
  747.     while query:
  748.         old_query = query
  749.  
  750.         if query[0] in string.whitespace:
  751.             query = remove_leading_whitespace(query, tokens)
  752.             continue
  753.  
  754.         if query[0] in (string.ascii_letters + "_"):
  755.             query = remove_word(query, tokens)
  756.             continue
  757.  
  758.         if query[:2] == "!=":
  759.             tokens.append(query[:2])
  760.             query = query[2:]
  761.             continue
  762.  
  763.         if query[0] in "(),;*.><=":
  764.             tokens.append(query[0])
  765.             query = query[1:]
  766.             continue
  767.  
  768.         if query[0] in {"'", '"'}:
  769.             query = remove_text(query, tokens)
  770.             continue
  771.  
  772.         if query[0] in string.digits:
  773.             query = remove_number(query, tokens)
  774.             continue
  775.  
  776.         if len(query) == len(old_query):
  777.             raise AssertionError(
  778.                 "Query didn't get shorter. query = {}".format(query))
  779.  
  780.     return tokens
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement