Advertisement
treyhunner

contract.py

Apr 7th, 2018
359
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.21 KB | None | 0 0
"""
Code based on David Beazley's keynote at PyCon Israel 2017.

Watch the talk at https://www.youtube.com/watch?v=Je8TcRQcUgA

Usage::

   from contract import Base, NonemptyString, PositiveInteger

   # Module-level annotation: any ``dx`` parameter of a method defined in
   # this module on a Base subclass is checked against PositiveInteger.
   dx: PositiveInteger

   class Player(Base):
       name: NonemptyString
       x: PositiveInteger
       y: PositiveInteger

       def left(self, dx):
           self.x -= dx

       def right(self, dx):
           self.x += dx

   p = Player('Guido', 5, 6)
   p.x = 23
   p.left(5)
   p.left(-5)  # Raises AssertionError: -5 is not a PositiveInteger

"""
  29. from collections import ChainMap
  30. from functools import wraps
  31. from inspect import signature
  32.  
  33.  
  34. _contracts = {}
  35.  
  36. class Contract:
  37.  
  38.     def __init_subclass__(cls):
  39.         _contracts[cls.__name__] = cls
  40.  
  41.     def __set__(self, instance, value):
  42.         self.check(value)
  43.         instance.__dict__[self.name] = value
  44.  
  45.     def __set_name__(self, cls, name):
  46.         self.name = name
  47.  
  48.     @classmethod
  49.     def check(cls, value):
  50.         pass
  51.  
  52.  
  53. class Typed(Contract):
  54.     type = None
  55.  
  56.     @classmethod
  57.     def check(cls, value):
  58.         assert isinstance(value, cls.type), f'Expected {cls.type}'
  59.         super().check(value)
  60.  
  61.  
  62. class Positive(Contract):
  63.     @classmethod
  64.     def check(cls, value):
  65.         assert value > 0, 'Must be > 0'
  66.         super().check(value)
  67.  
  68.  
  69. class Nonempty(Contract):
  70.     @classmethod
  71.     def check(cls, value):
  72.         assert len(value) > 0, 'Must be nonempty'
  73.         super().check(value)
  74.  
  75.  
  76. class Integer(Typed):
  77.     type = int
  78.  
  79.  
  80. class String(Typed):
  81.     type = str
  82.  
  83.  
  84. class NonemptyString(String, Nonempty):
  85.     pass
  86.  
  87.  
  88. class PositiveInteger(Integer, Positive):
  89.     pass
  90.  
  91.  
  92. def checked(func):
  93.     sig = signature(func)
  94.     ann = ChainMap(
  95.         getattr(func, '__annotations__', {}),
  96.         func.__globals__.get('__annotations__', {}),
  97.     )
  98.     @wraps(func)
  99.     def wrapper(*args, **kwargs):
  100.         bound = sig.bind(*args, **kwargs)
  101.         for name, val in bound.arguments.items():
  102.             if name in ann:
  103.                 ann[name].check(val)
  104.         return func(*args, **kwargs)
  105.     return wrapper
  106.  
  107.  
  108. class BaseMeta(type):
  109.     @classmethod
  110.     def __prepare__(cls, *args):
  111.         return ChainMap({}, _contracts, {'George': 4})
  112.  
  113.     def __new__(meta, name, bases, methods):
  114.         methods = methods.maps[0]
  115.         return super().__new__(meta, name, bases, methods)
  116.  
  117.  
  118. class Base(metaclass=BaseMeta):
  119.  
  120.     @classmethod
  121.     def __init_subclass__(cls):
  122.         # Instantiate the contracts
  123.         for name, val in cls.__dict__.items():
  124.             if callable(val):
  125.                 setattr(cls, name, checked(val))
  126.         for name, val in cls.__annotations__.items():
  127.             contract = val()
  128.             contract.__set_name__(cls, name)
  129.             setattr(cls, name, contract)
  130.  
  131.     def __init__(self, *args):
  132.         ann = self.__annotations__
  133.         assert len(args) == len(ann), f'Expected {len(ann)} arguments'
  134.         for name, val in zip(ann, args):
  135.             setattr(self, name, val)
  136.  
  137.     def __repr__(self):
  138.         args = ','.join(
  139.             repr(getattr(self, name))
  140.             for name in self.__annotations__
  141.         )
  142.         return f'{type(self).__name__}({args})'
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement