# service_man/api/uow/repository.py
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from sqlalchemy import insert, select
from sqlalchemy.ext.asyncio import AsyncSession
import api.models as models
# Type variable for the mapped class a repository manages. It is bound to
# ``models.Base``, so type checkers only accept subclasses of that base as
# the concrete ``ModelType`` of a ``SQLAlchemyRepository``.
ModelType = TypeVar("ModelType", bound=models.Base)
class AbstractRepository(ABC):
    """Storage-agnostic repository interface.

    Both coroutines are abstract: ``ABC`` refuses to instantiate this class,
    or any subclass that leaves either of them un-overridden.
    """

    @abstractmethod
    async def add_one(self, data: dict):
        """Store a single record described by ``data``.

        Raises:
            NotImplementedError: if reached through ``super()`` — concrete
                repositories must supply their own implementation.
        """
        raise NotImplementedError

    @abstractmethod
    async def find_all(self):
        """Fetch every stored record.

        Raises:
            NotImplementedError: if reached through ``super()`` — concrete
                repositories must supply their own implementation.
        """
        raise NotImplementedError
class SQLAlchemyRepository(AbstractRepository, Generic[ModelType]):
    """Generic repository backed by an SQLAlchemy ``AsyncSession``.

    Concrete subclasses must set the ``model`` class attribute to the mapped
    class they manage; every statement below is built against it.
    """

    # Mapped class this repository operates on; each subclass sets it.
    model: type[ModelType]

    def __init__(self, session: AsyncSession):
        # The session is supplied by the caller. Nothing in this class
        # commits or closes it, so transaction control stays with the caller.
        self.session = session

    async def add_one(self, data: dict) -> ModelType:
        """Insert one row built from ``data`` and return it as a ``model`` instance.

        ``data`` is passed to ``Insert.values()``, so its keys must name
        columns of ``model``. The row is not committed here.

        ``RETURNING`` is required: a plain INSERT produces a result with no
        rows, so ``scalar_one()`` would always raise instead of returning the
        new object. This needs a backend that supports INSERT ... RETURNING
        (e.g. PostgreSQL, SQLite >= 3.35).

        Raises:
            Any error raised by ``session.execute`` (e.g. integrity errors)
            propagates unchanged.
        """
        stmt = insert(self.model).values(**data).returning(self.model)
        res = await self.session.execute(stmt)
        # Exactly one row comes back for a single-row INSERT ... RETURNING.
        return res.scalar_one()

    async def find_all(self) -> list[ModelType]:
        """Select every ``model`` row and convert each via ``to_read_model()``.

        NOTE(review): the annotation says ``list[ModelType]``, but the list
        holds whatever ``to_read_model()`` returns — confirm against
        ``models.Base`` and adjust the annotation if it is a separate schema.
        """
        stmt = select(self.model)
        res = await self.session.execute(stmt)
        # ``scalars()`` yields the first column of each row — the ``model``
        # entity for ``select(self.model)`` — replacing manual ``row[0]``.
        return [entity.to_read_model() for entity in res.scalars().all()]