Advertisement
mrmamongo

session.py

Jun 8th, 2023
787
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.05 KB | None | 0 0
  1. import contextlib
  2. from typing import AsyncIterator
  3.  
  4. from sqlalchemy.ext.asyncio import (
  5.     AsyncConnection,
  6.     AsyncEngine,
  7.     AsyncSession,
  8.     async_sessionmaker,
  9.     create_async_engine,
  10. )
  11.  
  12. from src.db.models import Base
  13.  
  14.  
  15. class DatabaseSessionManager:
  16.     def __init__(self):
  17.         self._engine: AsyncEngine | None = None
  18.         self._sessionmaker: async_sessionmaker | None = None
  19.  
  20.     def init(self, host: str):
  21.         self._engine = create_async_engine(host)
  22.         self._sessionmaker = async_sessionmaker(autocommit=False, bind=self._engine)
  23.  
  24.     async def close(self):
  25.         if self._engine is None:
  26.             raise Exception("DatabaseSessionManager is not initialized")
  27.         await self._engine.dispose()
  28.         self._engine = None
  29.         self._sessionmaker = None
  30.  
  31.     @contextlib.asynccontextmanager
  32.     async def connect(self) -> AsyncIterator[AsyncConnection]:
  33.         if self._engine is None:
  34.             raise Exception("DatabaseSessionManager is not initialized")
  35.  
  36.         async with self._engine.begin() as connection:
  37.             try:
  38.                 yield connection
  39.             except Exception:
  40.                 await connection.rollback()
  41.                 raise
  42.  
  43.     @contextlib.asynccontextmanager
  44.     async def session(self) -> AsyncIterator[AsyncSession]:
  45.         if self._sessionmaker is None:
  46.             raise Exception("DatabaseSessionManager is not initialized")
  47.  
  48.         session = self._sessionmaker()
  49.         try:
  50.             yield session
  51.         except Exception:
  52.             await session.rollback()
  53.             raise
  54.         finally:
  55.             await session.close()
  56.  
  57.     # Used for testing
  58.     async def create_all(self, connection: AsyncConnection):
  59.         await connection.run_sync(Base.metadata.create_all)
  60.  
  61.     async def drop_all(self, connection: AsyncConnection):
  62.         await connection.run_sync(Base.metadata.drop_all)
  63.  
  64.  
  65. sessionmanager = DatabaseSessionManager()
  66.  
  67.  
  68. async def get_session():
  69.     async with sessionmanager.session() as session:
  70.         yield session
  71.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement