# test_api/test_api/uow/uow_base.py
# (34 lines, 826 B, Python)

from types import TracebackType
from typing import Optional

from fastapi import HTTPException

from ..repositories import RefRepository, UserRepository
class UnitOfWork:
    """Async unit of work that runs repository operations in one transaction.

    A single session is created from ``session_factory`` at construction time
    and handed to every repository attached to this instance. Used as an
    async context manager, the unit of work commits when the ``async with``
    body finishes normally and rolls back when it raises.

    NOTE(review): the session is never closed by this class -- confirm that
    whoever owns ``session_factory`` (or the caller) releases it.
    """

    def __init__(self, session_factory):
        """Create the session and bind the repositories to it.

        Args:
            session_factory: Zero-argument callable returning the session
                object. The session must expose awaitable ``commit()`` and
                ``rollback()`` methods.
        """
        self.session = session_factory()
        self.users = UserRepository(self.session)
        self.ref = RefRepository(self.session)

    async def __aenter__(self):
        """Enter the ``async with`` block and return this unit of work."""
        return self

    async def __aexit__(
        self,
        exc_type: Optional[type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Commit on success, roll back on failure.

        Args:
            exc_type: Class of the exception raised in the ``async with``
                body, or ``None`` when the body completed normally.
            exc_val: The exception instance, or ``None``.
            exc_tb: The traceback, or ``None``.

        Raises:
            HTTPException: 400 "Very Bad Request" when the body raised an
                ordinary exception; the original error is attached as the
                cause. An ``HTTPException`` raised by the body itself, and
                any non-``Exception`` signal (task cancellation, keyboard
                interrupt), propagate unchanged after the rollback.
        """
        if exc_type is None:
            try:
                await self.commit()
            except Exception:
                # A failed commit must not leave the transaction half-open.
                await self.rollback()
                raise
            return

        await self.rollback()

        # Returning None lets the original exception keep propagating.
        # Control-flow signals (CancelledError, KeyboardInterrupt, ...) must
        # never be turned into an HTTP error, and an HTTPException raised on
        # purpose inside the block already carries the intended status code.
        if not issubclass(exc_type, Exception) or issubclass(
            exc_type, HTTPException
        ):
            return

        # Chain explicitly so the root cause stays visible in logs.
        raise HTTPException(400, "Very Bad Request") from exc_val

    async def commit(self):
        """Commit the current transaction on the underlying session."""
        await self.session.commit()

    async def rollback(self):
        """Roll back the current transaction on the underlying session."""
        await self.session.rollback()