Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import contextlib
- from typing import AsyncIterator
- from sqlalchemy.ext.asyncio import (
- AsyncConnection,
- AsyncEngine,
- AsyncSession,
- async_sessionmaker,
- create_async_engine,
- )
- from src.db.models import Base
- class DatabaseSessionManager:
- def __init__(self):
- self._engine: AsyncEngine | None = None
- self._sessionmaker: async_sessionmaker | None = None
- def init(self, host: str):
- self._engine = create_async_engine(host)
- self._sessionmaker = async_sessionmaker(autocommit=False, bind=self._engine)
- async def close(self):
- if self._engine is None:
- raise Exception("DatabaseSessionManager is not initialized")
- await self._engine.dispose()
- self._engine = None
- self._sessionmaker = None
- @contextlib.asynccontextmanager
- async def connect(self) -> AsyncIterator[AsyncConnection]:
- if self._engine is None:
- raise Exception("DatabaseSessionManager is not initialized")
- async with self._engine.begin() as connection:
- try:
- yield connection
- except Exception:
- await connection.rollback()
- raise
- @contextlib.asynccontextmanager
- async def session(self) -> AsyncIterator[AsyncSession]:
- if self._sessionmaker is None:
- raise Exception("DatabaseSessionManager is not initialized")
- session = self._sessionmaker()
- try:
- yield session
- except Exception:
- await session.rollback()
- raise
- finally:
- await session.close()
- # Used for testing
- async def create_all(self, connection: AsyncConnection):
- await connection.run_sync(Base.metadata.create_all)
- async def drop_all(self, connection: AsyncConnection):
- await connection.run_sync(Base.metadata.drop_all)
- sessionmanager = DatabaseSessionManager()
- async def get_session():
- async with sessionmanager.session() as session:
- yield session
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement