uow and di basic implementation

This commit is contained in:
2024-03-06 02:28:59 +03:00
parent 402daf63d1
commit 8d93c964e1
21 changed files with 217 additions and 100 deletions

View File

@@ -1,37 +0,0 @@
from contextlib import (AbstractContextManager, asynccontextmanager,
contextmanager)
from typing import Callable
from sqlalchemy.ext.asyncio import (AsyncSession, async_sessionmaker,
create_async_engine)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session
Base = declarative_base()
class Database:
def __init__(self, db_url: str) -> None:
self._engine = create_async_engine(db_url, echo=True)
self._session_factory = async_sessionmaker(
self._engine,
class_=AsyncSession,
expire_on_commit=False,
)
@property
def session(self):
return self._session_factory()
async def __aenter__(self):
return self
async def __aexit__(self, *args):
await self.session.rollback()
await self.session.close()
async def commit(self):
await self.session.commit()
async def rollback(self):
await self.session.rollback()

38
api/uow/repository.py Normal file
View File

@@ -0,0 +1,38 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from uuid import UUID
from sqlalchemy import insert, select
from sqlalchemy.ext.asyncio import AsyncSession
from api.model.base import Base
ModelType = TypeVar("ModelType", bound=Base)
class AbstractRepository(ABC):
@abstractmethod
async def add_one(self, data: dict):
raise NotImplementedError
@abstractmethod
async def find_all(self):
raise NotImplementedError
class SQLAlchemyRepository(AbstractRepository, Generic[ModelType]):
model: type[ModelType]
def __init__(self, session: AsyncSession):
self.session = session
async def add_one(self, data: dict) -> UUID:
stmt = insert(self.model).values(**data).returning(self.model.id)
res = await self.session.execute(stmt)
return res.scalar_one()
async def find_all(self):
stmt = select(self.model)
res = await self.session.execute(stmt)
res = [row[0].to_read_model() for row in res.all()]
return res

View File

@@ -1,23 +1,47 @@
from contextlib import AbstractContextManager
from typing import Iterable
from abc import ABC, abstractmethod
from dependency_injector.providers import Callable
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlalchemy.orm import Session
from api.model.user import User
from api.repository.user import UserRepository
class UowBase:
def __init__(
self,
session_factory,
) -> None:
self.session = session_factory
class IUnitOfWork(ABC):
users: type[UserRepository]
async def get_all_users(self):
async with self.session as s:
query = select(User)
rr = await s.execute(query)
return rr.scalars().all()
@abstractmethod
def __init__(self):
...
@abstractmethod
async def __aenter__(self):
...
@abstractmethod
async def __aexit__(self):
...
@abstractmethod
async def commit(self):
...
@abstractmethod
async def rollback(self):
...
class UnitOfWork:
def __init__(self, session_factory):
self.session_factory = session_factory
async def __aenter__(self):
self.session = self.session_factory()
self.users = UserRepository(self.session)
async def __aexit__(self, *args):
await self.session.rollback()
await self.session.close()
async def commit(self):
await self.session.commit()
async def rollback(self):
await self.session.rollback()