from types import TracebackType from sqlalchemy.ext.asyncio import AsyncSession from api.infrastructure.persistence.error import TransactionContextManagerError class SqlalchemyTransactionContextManager: def __init__(self, session: AsyncSession): self._session = session async def __aenter__(self): return self async def __aexit__( self, exc_type: type[BaseException], exc_val: BaseException, exc_tb: TracebackType, ) -> None: if exc_type: await self.rollback() raise TransactionContextManagerError(message="Transaction Error") async def commit(self): await self._session.commit() async def rollback(self): await self._session.rollback()