Advertisement
Guest User

Untitled

a guest
Aug 18th, 2016
79
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 15.00 KB | None | 0 0
  1. import sys
  2. import os
  3. import logging
  4.  
  5. from pony.py23compat import PY2
  6. from ponytest import with_cli_args, pony_fixtures, ValidationError, Fixture, provider_validators, provider
  7.  
  8. from functools import wraps, partial
  9. import click
  10. from contextlib import contextmanager, closing
  11.  
  12.  
  13. from pony.orm.dbproviders.mysql import mysql_module
  14. from pony.utils import cached_property, class_property
  15.  
  16. if not PY2:
  17. from contextlib import contextmanager, ContextDecorator, ExitStack
  18. else:
  19. from contextlib2 import contextmanager, ContextDecorator, ExitStack
  20.  
  21. import unittest
  22.  
  23. from pony.orm import db_session, Database, rollback, delete
  24.  
  25. if not PY2:
  26. from io import StringIO
  27. else:
  28. from StringIO import StringIO
  29.  
  30. from multiprocessing import Process
  31.  
  32. import threading
  33.  
  34.  
  35. class DBContext(ContextDecorator):
  36.  
  37. __fixture__ = 'db'
  38.  
  39. enabled = False
  40.  
  41. def __init__(self, Test):
  42. if not isinstance(Test, type):
  43. TestCls = type(Test)
  44. NewClass = type(TestCls.__name__, (TestCls,), {})
  45. NewClass.__module__ = TestCls.__module__
  46. NewClass.db = property(lambda t: self.db)
  47. Test.__class__ = NewClass
  48. else:
  49. Test.db = class_property(lambda cls: self.db)
  50. Test.db_provider = self.provider
  51. self.Test = Test
  52.  
  53. @class_property
  54. def fixture_name(cls):
  55. return cls.provider
  56.  
  57. @class_property
  58. def provider(cls):
  59. # is used in tests
  60. return cls.PROVIDER
  61.  
  62. def init_db(self):
  63. raise NotImplementedError
  64.  
  65. @cached_property
  66. def db(self):
  67. raise NotImplementedError
  68.  
  69. def __enter__(self):
  70. self.init_db()
  71. try:
  72. self.Test.make_entities()
  73. except (AttributeError, TypeError):
  74. # No method make_entities with due signature
  75. pass
  76. else:
  77. self.db.generate_mapping(check_tables=True, create_tables=True)
  78. return self.db
  79.  
  80. def __exit__(self, *exc_info):
  81. self.db.provider.disconnect()
  82.  
  83. # @classmethod
  84. # @with_cli_args
  85. # @click.option('--db', '-d', 'database', multiple=True)
  86. # @click.option('--exclude-db', '-e', multiple=True)
  87. # def cli(cls, database, exclude_db):
  88. # fixture = [
  89. # MySqlContext, OracleContext, SqliteContext, PostgresContext,
  90. # SqlServerContext,
  91. # ]
  92. # all_db = [ctx.provider for ctx in fixture]
  93. # for db in database:
  94. # if db == 'all':
  95. # continue
  96. # assert db in all_db, (
  97. # "Unknown provider: %s. Use one of %s." % (db, ', '.join(all_db))
  98. # )
  99. # if 'all' in database:
  100. # database = all_db
  101. # elif exclude_db and not database:
  102. # database = set(all_db) - set(exclude_db)
  103. # elif not database:
  104. # database = ['sqlite']
  105. # for Ctx in fixture:
  106. # if Ctx.provider in database:
  107. # yield Ctx
  108.  
  109. db_name = 'testdb'
  110.  
  111. # class DbFixture(Fixture):
  112. # __key__ = 'db'
  113.  
  114.  
  115. # class GenerateMapping(Fixture):
  116. # __key__ = 'generate_mapping'
  117.  
  118.  
  119. @provider()
  120. class GenerateMapping(ContextDecorator):
  121.  
  122. weight = 200
  123. scope = 'class'
  124. __fixture__ = 'generate_mapping'
  125.  
  126. def __init__(self, Test):
  127. self.Test = Test
  128.  
  129. def __enter__(self):
  130. db = getattr(self.Test, 'db', None)
  131. if not db or not db.entities:
  132. return
  133. for entity in db.entities.values():
  134. if entity._database_.schema is None:
  135. db.generate_mapping(check_tables=True, create_tables=True)
  136. break
  137.  
  138. def __exit__(self, *exc_info):
  139. pass
  140.  
  141. @provider()
  142. class MySqlContext(DBContext):
  143. PROVIDER = 'mysql'
  144.  
  145. def drop_db(self, cursor):
  146. cursor.execute('use sys')
  147. cursor.execute('drop database %s' % self.db_name)
  148.  
  149.  
  150. def init_db(self):
  151. with closing(mysql_module.connect(**self.CONN).cursor()) as c:
  152. try:
  153. self.drop_db(c)
  154. except mysql_module.DatabaseError as exc:
  155. print('Failed to drop db: %s' % exc)
  156. c.execute('create database %s' % self.db_name)
  157. c.execute('use %s' % self.db_name)
  158.  
  159. CONN = {
  160. 'host': "localhost",
  161. 'user': "ponytest",
  162. 'passwd': "ponytest",
  163. }
  164.  
  165. @cached_property
  166. def db(self):
  167. CONN = dict(self.CONN, db=self.db_name)
  168. return Database('mysql', **CONN)
  169.  
  170. @provider()
  171. class SqlServerContext(DBContext):
  172.  
  173. PROVIDER = 'sqlserver'
  174.  
  175. def get_conn_string(self, db=None):
  176. s = (
  177. 'DSN=MSSQLdb;'
  178. 'SERVER=mssql;'
  179. 'UID=sa;'
  180. 'PWD=pass;'
  181. )
  182. if db:
  183. s += 'DATABASE=%s' % db
  184. return s
  185.  
  186. @cached_property
  187. def db(self):
  188. CONN = self.get_conn_string(self.db_name)
  189. return Database('mssqlserver', CONN)
  190.  
  191. def init_db(self):
  192. import pyodbc
  193. cursor = pyodbc.connect(self.get_conn_string(), autocommit=True).cursor()
  194. with closing(cursor) as c:
  195. try:
  196. self.drop_db(c)
  197. except pyodbc.DatabaseError as exc:
  198. print('Failed to drop db: %s' % exc)
  199. c.execute('create database %s' % self.db_name)
  200. c.execute('use %s' % self.db_name)
  201.  
  202. def drop_db(self, cursor):
  203. cursor.execute('use master')
  204. cursor.execute('drop database %s' % self.db_name)
  205.  
  206.  
  207. @provider()
  208. class SqliteContext(DBContext):
  209. PROVIDER = 'sqlite'
  210. enabled = True
  211.  
  212. def init_db(self):
  213. try:
  214. os.remove(self.db_path)
  215. except OSError as exc:
  216. print('Failed to drop db: %s' % exc)
  217.  
  218.  
  219. @cached_property
  220. def db_path(self):
  221. p = os.path.dirname(__file__)
  222. p = os.path.join(p, '%s.sqlite' % self.db_name)
  223. return os.path.abspath(p)
  224.  
  225. @cached_property
  226. def db(self):
  227. return Database('sqlite', self.db_path, create_db=True)
  228.  
  229.  
  230. @provider()
  231. class PostgresContext(DBContext):
  232. PROVIDER = 'postgresql'
  233.  
  234. def get_conn_dict(self, no_db=False):
  235. d = dict(
  236. user='ponytest', password='ponytest',
  237. host='localhost', database='postgres',
  238. )
  239. if not no_db:
  240. d.update(database=self.db_name)
  241. return d
  242.  
  243. def init_db(self):
  244. import psycopg2
  245. conn = psycopg2.connect(
  246. **self.get_conn_dict(no_db=True)
  247. )
  248. conn.set_isolation_level(0)
  249. with closing(conn.cursor()) as cursor:
  250. try:
  251. self.drop_db(cursor)
  252. except psycopg2.DatabaseError as exc:
  253. print('Failed to drop db: %s' % exc)
  254. cursor.execute('create database %s' % self.db_name)
  255.  
  256. def drop_db(self, cursor):
  257. cursor.execute('drop database %s' % self.db_name)
  258.  
  259.  
  260. @cached_property
  261. def db(self):
  262. return Database('postgres', **self.get_conn_dict())
  263.  
  264.  
  265. @provider()
  266. class OracleContext(DBContext):
  267. PROVIDER = 'oracle'
  268.  
  269. def __enter__(self):
  270. os.environ.update(dict(
  271. ORACLE_BASE='/u01/app/oracle',
  272. ORACLE_HOME='/u01/app/oracle/product/12.1.0/dbhome_1',
  273. ORACLE_OWNR='oracle',
  274. ORACLE_SID='orcl',
  275. ))
  276. return super(OracleContext, self).__enter__()
  277.  
  278. def init_db(self):
  279.  
  280. import cx_Oracle
  281. with closing(self.connect_sys()) as conn:
  282. with closing(conn.cursor()) as cursor:
  283. try:
  284. self._destroy_test_user(cursor)
  285. except cx_Oracle.DatabaseError as exc:
  286. print('Failed to drop user: %s' % exc)
  287. try:
  288. self._drop_tablespace(cursor)
  289. except cx_Oracle.DatabaseError as exc:
  290. print('Failed to drop db: %s' % exc)
  291. cursor.execute(
  292. """CREATE TABLESPACE %(tblspace)s
  293. DATAFILE '%(datafile)s' SIZE 20M
  294. REUSE AUTOEXTEND ON NEXT 10M MAXSIZE %(maxsize)s
  295. """ % self.parameters)
  296. cursor.execute(
  297. """CREATE TEMPORARY TABLESPACE %(tblspace_temp)s
  298. TEMPFILE '%(datafile_tmp)s' SIZE 20M
  299. REUSE AUTOEXTEND ON NEXT 10M MAXSIZE %(maxsize_tmp)s
  300. """ % self.parameters)
  301. self._create_test_user(cursor)
  302.  
  303.  
  304. def _drop_tablespace(self, cursor):
  305. cursor.execute(
  306. 'DROP TABLESPACE %(tblspace)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS'
  307. % self.parameters)
  308. cursor.execute(
  309. 'DROP TABLESPACE %(tblspace_temp)s INCLUDING CONTENTS AND DATAFILES CASCADE CONSTRAINTS'
  310. % self.parameters)
  311.  
  312.  
  313. parameters = {
  314. 'tblspace': 'test_tblspace',
  315. 'tblspace_temp': 'test_tblspace_temp',
  316. 'datafile': 'test_datafile.dbf',
  317. 'datafile_tmp': 'test_datafile_tmp.dbf',
  318. 'user': 'ponytest',
  319. 'password': 'ponytest',
  320. 'maxsize': '100M',
  321. 'maxsize_tmp': '100M',
  322. }
  323.  
  324. def connect_sys(self):
  325. import cx_Oracle
  326. return cx_Oracle.connect('sys/the@localhost/ORCL', mode=cx_Oracle.SYSDBA)
  327.  
  328. def connect_test(self):
  329. import cx_Oracle
  330. return cx_Oracle.connect('ponytest/ponytest@localhost/ORCL')
  331.  
  332.  
  333. @cached_property
  334. def db(self):
  335. return Database('oracle', 'ponytest/ponytest@localhost/ORCL')
  336.  
  337. def _create_test_user(self, cursor):
  338. cursor.execute(
  339. """CREATE USER %(user)s
  340. IDENTIFIED BY %(password)s
  341. DEFAULT TABLESPACE %(tblspace)s
  342. TEMPORARY TABLESPACE %(tblspace_temp)s
  343. QUOTA UNLIMITED ON %(tblspace)s
  344. """ % self.parameters
  345. )
  346. cursor.execute(
  347. """GRANT CREATE SESSION,
  348. CREATE TABLE,
  349. CREATE SEQUENCE,
  350. CREATE PROCEDURE,
  351. CREATE TRIGGER
  352. TO %(user)s
  353. """ % self.parameters
  354. )
  355.  
  356. def _destroy_test_user(self, cursor):
  357. cursor.execute('''
  358. DROP USER %(user)s CASCADE
  359. ''' % self.parameters)
  360.  
  361.  
  362. @provider(__fixture__='log', weight=100, enabled=False)
  363. @contextmanager
  364. def logging_context(test):
  365. level = logging.getLogger().level
  366. from pony.orm.core import debug, sql_debug
  367. logging.getLogger().setLevel(logging.INFO)
  368. sql_debug(True)
  369. yield
  370. logging.getLogger().setLevel(level)
  371. sql_debug(debug)
  372.  
  373. # @provider('log_all', scope='class', weight=-100, enabled=False)
  374. # def log_all(Test):
  375. # return logging_context(Test)
  376.  
  377.  
  378.  
  379. # @with_cli_args
  380. # @click.option('--log', 'scope', flag_value='test')
  381. # @click.option('--log-all', 'scope', flag_value='all')
  382. # def use_logging(scope):
  383. # if scope == 'test':
  384. # yield logging_context
  385. # elif scope =='all':
  386. # yield log_all
  387.  
  388.  
  389.  
  390.  
  391. @provider()
  392. class DBSessionProvider(object):
  393.  
  394. __fixture__= 'db_session'
  395.  
  396. weight = 30
  397.  
  398. def __new__(cls, test):
  399. return db_session
  400.  
  401.  
  402. @provider(__fixture__='rollback', weight=40)
  403. @contextmanager
  404. def do_rollback(test):
  405. try:
  406. yield
  407. finally:
  408. rollback()
  409.  
  410.  
  411. @provider()
  412. class SeparateProcess(object):
  413.  
  414. # TODO read failures from sep process better
  415.  
  416. __fixture__ = 'separate_process'
  417.  
  418. enabled = False
  419.  
  420. scope = 'class'
  421.  
  422. def __init__(self, Test):
  423. self.Test = Test
  424.  
  425. def __call__(self, func):
  426. def wrapper(Test):
  427. rnr = unittest.runner.TextTestRunner()
  428. TestCls = Test if isinstance(Test, type) else type(Test)
  429. def runTest(self):
  430. try:
  431. func(Test)
  432. finally:
  433. rnr.stream = unittest.runner._WritelnDecorator(StringIO())
  434. name = getattr(func, '__name__', 'runTest')
  435. Case = type(TestCls.__name__, (TestCls,), {name: runTest})
  436. Case.__module__ = TestCls.__module__
  437. case = Case(name)
  438. suite = unittest.suite.TestSuite([case])
  439. def run():
  440. result = rnr.run(suite)
  441. if not result.wasSuccessful():
  442. sys.exit(1)
  443. p = Process(target=run, args=())
  444. p.start()
  445. p.join()
  446. case.assertEqual(p.exitcode, 0)
  447. return wrapper
  448.  
  449. @classmethod
  450. def validate_chain(cls, fixtures, klass):
  451. for f in fixtures:
  452. if f.KEY in ('ipdb', 'ipdb_all'):
  453. return False
  454. for f in fixtures:
  455. if f.KEY == 'db' and f.PROVIDER in ('sqlserver', 'oracle'):
  456. return True
  457.  
  458. @provider()
  459. class ClearTables(ContextDecorator):
  460.  
  461. __fixture__ = 'clear_tables'
  462.  
  463. def __init__(self, test):
  464. self.test = test
  465.  
  466. def __enter__(self):
  467. pass
  468.  
  469. @db_session
  470. def __exit__(self, *exc_info):
  471. db = self.test.db
  472. for entity in db.entities.values():
  473. if entity._database_.schema is None:
  474. break
  475. delete(i for i in entity)
  476.  
  477.  
  478. @provider()
  479. class NoJson1(ContextDecorator):
  480.  
  481. __fixture__ = 'no_json1'
  482.  
  483. def __init__(self, cls):
  484. self.Test = cls
  485. cls.no_json1 = True
  486.  
  487. fixture_name = 'no_json1'
  488.  
  489. def __enter__(self):
  490. self.json1_available = self.Test.db.provider.json1_available
  491. self.Test.db.provider.json1_available = False
  492.  
  493. def __exit__(self, *exc_info):
  494. self.Test.db.provider.json1_available = self.json1_available
  495.  
  496. scope = 'class'
  497.  
  498. @classmethod
  499. def validate_chain(cls, fixtures, klass):
  500. for f in fixtures:
  501. if f.KEY in ('ipdb', 'ipdb_all'):
  502. return False
  503. for f in fixtures:
  504. if f.KEY == 'db' and f.PROVIDER in ('sqlserver', 'oracle'):
  505. return True
  506.  
  507.  
  508. import signal
  509.  
  510. @provider()
  511. class Timeout(object):
  512.  
  513. __fixture__ = 'timeout'
  514.  
  515. @with_cli_args
  516. @click.option('--timeout', type=int)
  517. def __init__(self, Test, timeout):
  518. self.Test = Test
  519. self.timeout = timeout if timeout else Test.TIMEOUT
  520.  
  521. scope = 'class'
  522. enabled = False
  523.  
  524. class Exception(Exception):
  525. pass
  526.  
  527. class FailedInSubprocess(Exception):
  528. pass
  529.  
  530. def __call__(self, func):
  531. def wrapper(test):
  532. p = Process(target=func, args=(test,))
  533. p.start()
  534.  
  535. def on_expired():
  536. p.terminate()
  537.  
  538. t = threading.Timer(self.timeout, on_expired)
  539. t.start()
  540. p.join()
  541. t.cancel()
  542. if p.exitcode == -signal.SIGTERM:
  543. raise self.Exception
  544. elif p.exitcode:
  545. raise self.FailedInSubprocess
  546.  
  547. return wrapper
  548.  
  549. @classmethod
  550. @with_cli_args
  551. @click.option('--timeout', type=int)
  552. def validate_chain(cls, fixtures, klass, timeout):
  553. if not getattr(klass, 'TIMEOUT', None) and not timeout:
  554. return False
  555. for f in fixtures:
  556. if f.KEY in ('ipdb', 'ipdb_all'):
  557. return False
  558. for f in fixtures:
  559. if f.KEY == 'db' and f.PROVIDER in ('sqlserver', 'oracle'):
  560. return True
  561.  
  562.  
  563. pony_fixtures['test'].extend([
  564. 'log',
  565. 'clear_tables',
  566. 'db_session',
  567. ])
  568.  
  569. pony_fixtures['class'].extend([
  570. 'separate_process',
  571. 'timeout',
  572. 'db',
  573. 'generate_mapping',
  574. ])
  575.  
  576. def db_is_present(providers, config):
  577. return providers
  578.  
  579. provider_validators.update({
  580. 'db': db_is_present,
  581. })
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement