Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # coding: utf-8
- from __future__ import unicode_literals
- from __future__ import print_function
- from __future__ import division
- from hypothesis.extra.datetime import datetimes
- import hypothesis.strategies as st
- from sqlalchemy import (
- DateTime,
- Integer,
- String,
- )
- from sqlalchemy.dialects.postgresql import JSON
def mk_generator(mapping_class, column_generators=None):
    """Build a hypothesis strategy that produces instances of *mapping_class*.

    :param mapping_class: A class inheriting from a ``declarative_base()``.
    :param column_generators: An optional dict from column names to the
        hypothesis strategies to use for those columns. Every other column
        gets a strategy derived from its SQLAlchemy type.
    :return: A strategy that constructs instances of the mapping class,
        passing one drawn value per column as a keyword argument.
    """
    overrides = {} if column_generators is None else column_generators
    derived = derive_generators(mapping_class, omit_columns=overrides.keys())
    # Explicit overrides are merged last so they win over derived strategies.
    constructor_kwargs = merge_dicts(derived, overrides)
    return st.builds(mapping_class, **constructor_kwargs)
def derive_generators(mapping_class, omit_columns):
    """Derive a strategy for each column of *mapping_class* not being skipped.

    :param mapping_class: A class inheriting from a ``declarative_base()``.
    :param omit_columns: Container of column names to leave out (mk_generator
        passes the names the caller supplied its own strategies for).
    :return: A dict from column name to the strategy ``column_generator``
        built for that column.
    :raises ValueError: propagated from ``column_generator`` when a column's
        type has no derivation rule.
    """
    derived = {}
    for col in _columns_for_mapping_class(mapping_class):
        if col.name in omit_columns:
            continue
        derived[col.name] = column_generator(col)
    return derived
def column_generator(column):
    """Derive a hypothesis strategy for values of a single SQLAlchemy column.

    Types are checked with ``isinstance`` in this order, so subclasses of
    these types match too:

    * ``Integer``  -> ``st.integers()`` (unbounded).
    * ``String``   -> ``st.text()`` capped at the column's declared length
      (uncapped when the length is ``None``).
    * ``DateTime`` -> timezone-aware datetimes for ``DateTime(timezone=True)``
      columns, naive datetimes otherwise.

    Nullable columns additionally produce ``None``.

    :param column: A SQLAlchemy ``Column``.
    :return: A hypothesis strategy producing values for ``column``.
    :raises ValueError: if the column type is ``JSON`` or any other type
        without a derivation rule; supply an explicit strategy for that
        column (``mk_generator(..., column_generators={...})``) instead.
    """
    col_type = column.type
    if isinstance(col_type, Integer):
        gen = st.integers()
    elif isinstance(col_type, String):
        gen = st.text(max_size=col_type.length)
    elif isinstance(col_type, DateTime):
        if col_type.timezone:
            # Timezone-aware column: only aware values are meaningful.
            gen = datetimes(allow_naive=False)
        else:
            # Naive column. allow_naive=True on its own still mixes in aware
            # values drawn from every pytz zone (timezones defaults to all of
            # them); an empty timezones list restricts output to naive values.
            gen = datetimes(allow_naive=True, timezones=[])
    elif isinstance(col_type, JSON):
        message = (
            "Unable to derive generator for column {c.name} "
            "with type {c.type}: Supply one manually"
        )
        raise ValueError(message.format(c=column))
    else:
        raise ValueError(
            "Unable to derive generator for column {c.name} with type {c.type}"
            .format(c=column)
        )
    if column.nullable:
        return st.one_of(st.none(), gen)
    return gen
def merge_dicts(*dicts):
    """Return a new dict combining *dicts* from left to right.

    When a key appears in several inputs, the value from the right-most input
    wins. None of the inputs is modified.
    """
    combined = dict()
    for source in dicts:
        combined.update(source)
    return combined
def _columns_for_mapping_class(mapping_class):
    """Return the columns of the table mapped by *mapping_class*.

    :param mapping_class: A declarative mapped class exposing ``__table__``.
    :return: The values of ``mapping_class.__table__.columns``.
    """
    # Read straight from __table__ rather than re-fetching it through
    # metadata.tables[__table__.name]: MetaData.tables is keyed by the table's
    # fullname ("schema.name"), so that lookup raises KeyError for
    # schema-qualified tables, and __table__ already is the Table object.
    return mapping_class.__table__.columns.values()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement