Advertisement
Guest User

traited_orm.py

a guest
May 10th, 2014
451
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 9.83 KB | None | 0 0
""" Utilities for mapping HasTraits classes to a relational database using
SQLAlchemy.

These tools are not declarative, like the Elixir extension. Rather, they just
provide the low-level support for mapping an existing schema to traited classes.
Your classes must subclass from ORMapped. Each mapped trait should have the
"db_storage=True" metadata. Many of the traits have been subclassed here to
provide this by default, e.g. DBInt, DBInstance, DBStr, etc. Several of them
(DBInt, DBIntKey, DBFloat, DBStr) also accept None in order to support SQL
NULLs.

The only collection trait supported is DBList. One cannot currently map Dict or
Set traits.

Instead of using sqlalchemy.orm.mapper() to declare mappers, use trait_mapper().
For 1:N and M:N relations that map to a DBList, use trait_list_relation()
instead of sqlalchemy.orm.relation(). Note that trait_list_relation() always
creates an eager (lazy=False) relation.
"""
  18.  
  19.  
  20. import weakref
  21.  
  22. from sqlalchemy.orm import EXT_CONTINUE, MapperExtension, attributes, mapper, relation, session
  23.  
  24. from enthought.traits.api import (Any, Array, Either, Float, HasTraits,
  25.     Instance, Int, List, Property, Python, Str, TraitListObject, on_trait_change)
  26.  
  27. __all__ = ['MappedTraitListObject', 'DBList', 'DBAny', 'DBArray', 'DBFloat',
  28.     'DBInstance', 'DBInt', 'DBIntKey', 'DBStr', 'ORMapped',
  29.     'trait_list_relation', 'trait_mapper']
  30.  
  31.  
# A unique object to act as a dummy owner for MappedTraitListObjects so we know
# when they have been constructed outside of Traits (MappedTraitListObject()
# called with no arguments uses it as the owning object). _fix_dblist() looks
# for it and rebinds such lists to their real owner. It needs to be a valid
# HasTraits instance, but otherwise is nothing special.
HAS_TRAITS_SENTINEL = HasTraits()
  36.  
  37. class MappedTraitListObject(TraitListObject):
  38.     """ TraitListObject decorated for SQLAlchemy relations.
  39.    """
  40.     __emulates__ = list
  41.  
  42.     def __init__(self, *args, **kwds):
  43.         if not args and not kwds:
  44.             args = (DBList(), HAS_TRAITS_SENTINEL, '__fake', [])
  45.         TraitListObject.__init__(self, *args, **kwds)
  46.  
  47. # FIXME: Fix Traits so we don't need this hack.
  48. class WeirdInt(int):
  49.     """ Work around a missing feature in Traits.
  50.  
  51.    Traits uses the default_value_type to determine if a trait is a List, Dict,
  52.    etc. through a dict lookup for deciding if it is going to add the *_items
  53.    events. List subclasses need to use a different default_value_type, though,
  54.    so we'll pretend that we look like a list (default_value_type=5). The other
  55.    place where Traits uses the default_value_type is in the C code, where it
  56.    converts it to a C int, so it will get the real value of "8" there.
  57.  
  58.    Horrible, horrible hack. I am not proud.
  59.    """
  60.     def __hash__(self):
  61.         return hash(5)
  62.     def __eq__(self, other):
  63.         if other == 5:
  64.             return True
  65.         else:
  66.             return int(self) == other
  67.  
  68. class DBList(List):
  69.     """ Subclass of List traits to use SQLAlchemy mapped lists.
  70.    """
  71.  
  72.     default_value_type = WeirdInt(8)
  73.  
  74.     def __init__(self, *args, **kwds):
  75.         kwds['db_storage'] = True
  76.         List.__init__(self, *args, **kwds)
  77.  
  78.         # Set up the Type-8 initializer.
  79.         self.real_default_value = self.default_value
  80.         def type8_init(obj):
  81.             # Handle the conversion to a MappedTraitListObject in the validator.
  82.             return self.real_default_value
  83.         self.default_value = type8_init
  84.  
  85.     def validate(self, object, name, value):
  86.         """ Validates that the values is a valid list.
  87.        """
  88.         if (isinstance(value, list) and
  89.            (self.minlen <= len(value) <= self.maxlen)):
  90.             if object is None:
  91.                 return value
  92.  
  93.             if hasattr(object, '_state'):
  94.                 # Object has been mapped.
  95.                 attr = getattr(object.__class__, name)
  96.                 _, list_obj = attr.impl._build_collection(object._state)
  97.                 # Add back the Traits-specified information.
  98.                 list_obj.__init__(self, object, name, value)
  99.             else:
  100.                 # Object has not been mapped, yet.
  101.                 list_obj = MappedTraitListObject(self, object, name, value)
  102.             return list_obj
  103.  
  104.         self.error(object, name, value)
  105.  
  106. class DBAny(Any):
  107.     def __init__(self, *args, **kwds):
  108.         kwds['db_storage'] = True
  109.         super(DBAny, self).__init__(*args, **kwds)
  110.  
  111. class DBInstance(Instance):
  112.     def __init__(self, *args, **kwds):
  113.         kwds['db_storage'] = True
  114.         super(DBInstance, self).__init__(*args, **kwds)
  115.  
  116. class DBArray(Array):
  117.     def __init__(self, *args, **kwds):
  118.         kwds['db_storage'] = True
  119.         super(DBArray, self).__init__(*args, **kwds)
  120.  
  121. class DBInt(Either):
  122.     def __init__(self, **kwds):
  123.         kwds['db_storage'] = True
  124.         kwds['default'] = 0
  125.         super(DBInt, self).__init__(Int, None, **kwds)
  126.  
  127. class DBIntKey(Either):
  128.     def __init__(self, **kwds):
  129.         kwds['db_storage'] = True
  130.         super(DBIntKey, self).__init__(None, Int, **kwds)
  131.  
  132. class DBUUID(Any):
  133.     def __init__(self, *args, **kwds):
  134.         kwds['db_storage'] = True
  135.         super(DBUUID, self).__init__(*args, **kwds)
  136.  
  137. class DBFloat(Either):
  138.     def __init__(self, **kwds):
  139.         kwds['db_storage'] = True
  140.         kwds['default'] = 0.0
  141.         super(DBFloat, self).__init__(Float, None, **kwds)
  142.  
  143. class DBStr(Either):
  144.     def __init__(self, **kwds):
  145.         kwds['db_storage'] = True
  146.         kwds['default'] = ''
  147.         super(DBStr, self).__init__(Str, None, **kwds)
  148.  
  149. def _fix_dblist(object, value, trait_name, trait):
  150.     """ Fix MappedTraitListObject values for DBList traits that do not have the
  151.    appropriate metadata.
  152.  
  153.    No-op for non-DBList traits, so it may be used indiscriminantly.
  154.    """
  155.     if isinstance(trait.handler, DBList):
  156.         if value.object() is HAS_TRAITS_SENTINEL:
  157.             value.object = weakref.ref(object)
  158.             value.name = trait_name
  159.             value.name_items = trait_name + '_items'
  160.             value.trait = trait.handler
  161.  
  162. class TraitMapperExtension(MapperExtension):
  163.     """ Create ORMapped instances correctly.
  164.    """
  165.  
  166.     def create_instance(self, mapper, selectcontext, row, class_):
  167.         """ Create ORMapped instances correctly.
  168.  
  169.        This will make sure that the HasTraits machinery is hooked up so that
  170.        things like @on_trait_change() will work.
  171.        """
  172.         if issubclass(class_, HasTraits):
  173.             obj = attributes.new_instance(class_)
  174.             HasTraits.__init__(obj)
  175.             return obj
  176.         else:
  177.             return EXT_CONTINUE
  178.  
  179.     def populate_instance(self, mapper, selectcontext, row, instance, **flags):
  180.         """ Receive a newly-created instance before that instance has
  181.        its attributes populated.
  182.  
  183.        This will fix up any MappedTraitListObject values which were created
  184.        without the appropriate metadata.
  185.        """
  186.         if isinstance(instance, HasTraits):
  187.             mapper.populate_instance(selectcontext, instance, row, **flags)
  188.             # Check for bad DBList traits.
  189.             for trait_name, trait in instance.traits(db_storage=True).items():
  190.                 value = instance.trait_get(trait_name)[trait_name]
  191.                 _fix_dblist(instance, value, trait_name, trait)
  192.         else:
  193.             return EXT_CONTINUE
  194.  
  195.  
class ORMapped(HasTraits):
    """ Base class providing the necessary connection to the SQLAlchemy mapper.

    Changes to traits carrying db_storage=True metadata are forwarded to
    SQLAlchemy by _tell_sqlalchemy().
    """

    # The SQLAlchemy Session this object belongs to (computed by
    # _get__session via sqlalchemy.orm.session.object_session()).
    _session = Property()

    # Any implicit traits added by SQLAlchemy are transient and should not be
    # copied through .clone_traits(), copy.copy(), or pickling.
    _ = Python(transient=True)

    def _get__session(self):
        """ Getter for the _session property: the Session returned by
        session.object_session(self) (presumably None when detached).
        """
        return session.object_session(self)

    @on_trait_change('+db_storage')
    def _tell_sqlalchemy(self, object, trait_name, old, new):
        """ Forward a change of a db_storage trait to SQLAlchemy.

        Parameters
        ----------
        object : HasTraits
            The object whose trait changed, as supplied by the trait-change
            dispatcher (presumably always self).
        trait_name : str
            Name of the changed trait.
        old, new
            The previous and new values of the trait.

        Returns
        -------
        None. Only acts when self has a ``_state`` attribute (i.e. it has been
        mapped): the assignment is then replayed through the class's
        SQLAlchemy instrumented attribute so SQLAlchemy notices the change.

        Description
        -----------
        HasTraits bypasses the default class attribute getter and setter, which
        in turn causes SQLAlchemy to fail to detect that an object has data
        to be flushed. As a work-around we replay the assignment through the
        SQLAlchemy descriptor when one of our db_storage traits has changed.
        """
        if hasattr(self, '_state'):
            trait = self.trait(trait_name)
            # Use the InstrumentedAttribute descriptor on this class inform
            # SQLAlchemy of the changes.
            # NOTE(review): assumes every db_storage trait is also a mapped
            # attribute on the class; otherwise this getattr presumably raises
            # AttributeError -- confirm.
            instr = getattr(self.__class__, trait_name)
            # SQLAlchemy looks at the __dict__ for information. Fool it:
            # expose the *old* value there before the descriptor's set runs.
            self.__dict__[trait_name] = old
            # Rebind a placeholder DBList value before handing it over.
            _fix_dblist(self, new, trait_name, trait)
            instr.__set__(self, new)
            # The value may have been replaced. Fix it again.
            new = self.trait_get(trait_name)[trait_name]
            _fix_dblist(self, new, trait_name, trait)
            # Leave __dict__ holding the final value.
            self.__dict__[trait_name] = new
        return
  240.  
  241. def trait_list_relation(argument, secondary=None,
  242.     collection_class=MappedTraitListObject, **kwargs):
  243.     """ An eager relation mapped to a List trait.
  244.  
  245.    The arguments are the same as sqlalchemy.orm.relation().
  246.    """
  247.     kwargs['lazy'] = False
  248.     return relation(argument, secondary=secondary,
  249.         collection_class=collection_class, **kwargs)
  250.  
  251. def trait_mapper(class_, local_table=None, *args, **kwds):
  252.     """ Return a new Mapper object suitably extended to handle HasTraits
  253.    classes.
  254.    """
  255.     # Add our MapperExtension.
  256.     extensions = kwds.setdefault('extension', [])
  257.     if isinstance(extensions, MapperExtension):
  258.         # Scalar. Turn into a list.
  259.         extensions = [extensions]
  260.     extensions.insert(0, TraitMapperExtension())
  261.     kwds['extension'] = extensions
  262.     return mapper(class_, local_table, *args, **kwds)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement