2024-03-06 02:28:59 +03:00
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
from typing import Generic, TypeVar
|
|
|
|
|
|
|
|
from sqlalchemy import insert, select
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
2024-03-06 03:59:16 +03:00
|
|
|
import api.models as models
|
2024-03-06 02:28:59 +03:00
|
|
|
|
2024-03-06 03:59:16 +03:00
|
|
|
# Type variable for the model class a repository works with; bound to
# models.Base so only subclasses of models.Base are accepted.
ModelType = TypeVar("ModelType", bound=models.Base)
|
2024-03-06 02:28:59 +03:00
|
|
|
|
|
|
|
|
|
|
|
class AbstractRepository(ABC):
    """Interface that every repository implementation must satisfy.

    Both coroutine methods are abstract, so this class cannot be instantiated
    directly; concrete subclasses provide the storage-specific behavior.
    """

    @abstractmethod
    async def add_one(self, data: dict):
        """Add one record built from ``data``; implementations define the details.

        The base implementation only signals that it must be overridden.
        """
        raise NotImplementedError

    @abstractmethod
    async def find_all(self):
        """Return all records; implementations define the result type.

        The base implementation only signals that it must be overridden.
        """
        raise NotImplementedError
|
2024-03-06 02:28:59 +03:00
|
|
|
|
|
|
|
|
|
|
|
class SQLAlchemyRepository(AbstractRepository, Generic[ModelType]):
    """Repository that runs SQLAlchemy statements for a single mapped model.

    ``model`` is only annotated here, never assigned, so concrete subclasses
    must set it; otherwise accessing ``self.model`` raises ``AttributeError``.
    """

    # Mapped model class this repository inserts into and selects from.
    model: type[ModelType]

    def __init__(self, session: AsyncSession):
        """Keep the async session that every statement is executed on.

        Args:
            session: Session used by ``add_one``/``find_all``. This class never
                commits or closes it.
        """
        self.session = session

    async def add_one(self, data: dict) -> ModelType:
        """Insert one row built from ``data`` and return it as a model instance.

        The transaction is not committed here.

        Args:
            data: Mapping of column names to values, passed to
                ``insert(...).values(**data)``.

        Returns:
            The newly inserted row as an instance of ``self.model``.
        """
        # Without RETURNING an INSERT produces no result rows, so
        # ``scalar_one()`` would raise ResourceClosedError. Returning the
        # entity makes the result match the ``ModelType`` annotation.
        # NOTE(review): requires a backend that supports INSERT ... RETURNING
        # (e.g. PostgreSQL); confirm for the configured database.
        stmt = insert(self.model).values(**data).returning(self.model)
        res = await self.session.execute(stmt)
        return res.scalar_one()

    async def find_all(self) -> list[ModelType]:
        """Select every row of ``self.model`` and convert each via ``to_read_model()``.

        Returns:
            One ``to_read_model()`` result per row, in whatever order the
            database returns them (no ORDER BY is applied).
        """
        # NOTE(review): ``to_read_model()`` is defined on the model classes (not
        # visible here); its result may not be ``ModelType`` despite the return
        # annotation — confirm and tighten the type.
        res = await self.session.execute(select(self.model))
        # ``scalars()`` yields the first column of each row (the entity),
        # equivalent to indexing ``row[0]``.
        return [obj.to_read_model() for obj in res.scalars()]
|