uow and di basic implementation

main
Сергей Ванюшкин 2024-03-06 03:59:16 +03:00
parent 8d93c964e1
commit f9631a712b
11 changed files with 50 additions and 19 deletions

View File

@ -3,7 +3,7 @@ from logging.config import fileConfig
from alembic import context
from sqlalchemy import engine_from_config, pool
import api.model.user # type: ignore
import api.models as models
# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
@ -18,7 +18,7 @@ if config.config_file_name is not None:
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = api.model.user.Base.metadata
target_metadata = models.Base.metadata
# other values from the config, defined by the needs of env.py,
# can be acquired:

View File

@ -1,8 +1,8 @@
"""initial
Revision ID: ec1380cb4f18
Revision ID: 3ba730985688
Revises:
Create Date: 2024-03-04 03:11:36.206211
Create Date: 2024-03-06 03:10:09.050166
"""
@ -12,7 +12,7 @@ import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "ec1380cb4f18"
revision: str = "3ba730985688"
down_revision: str | None = None
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
@ -22,10 +22,11 @@ def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"users",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("email", sa.String(), nullable=True),
sa.Column("hashed_password", sa.String(), nullable=True),
sa.Column("is_active", sa.Boolean(), nullable=True),
sa.Column("name", sa.String(), nullable=False),
sa.Column("id", sa.UUID(), nullable=False),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("email"),
)

View File

7
api/models/__init__.py Normal file
View File

@ -0,0 +1,7 @@
from .base import Base
from .user import User
__all__ = (
"Base",
"User",
)

View File

@ -1,4 +1,7 @@
from sqlalchemy import Boolean, Column, String
from sqlalchemy.orm import Mapped
from api.schemas.user_schema import UserSchema
from .base import Base
@ -6,6 +9,7 @@ from .base import Base
class User(Base):
__tablename__ = "users"
name: Mapped[str]
email = Column(String, unique=True)
hashed_password = Column(String)
is_active = Column(Boolean, default=True)
@ -17,3 +21,9 @@ class User(Base):
f'hashed_password="{self.hashed_password}", '
f"is_active={self.is_active})>"
)
def to_read_model(self) -> UserSchema:
return UserSchema(
id=self.id,
name=self.name,
)

View File

@ -1,6 +1,6 @@
from api.model.user import User
import api.models as models
from api.uow.repository import SQLAlchemyRepository
class UserRepository(SQLAlchemyRepository):
model = User
model = models.User

View File

@ -2,16 +2,17 @@ from dependency_injector.wiring import Provide, inject
from fastapi import APIRouter, Depends
from api.di import Container
from api.schemas.user_schema import UserSchema
from api.service.user import UserService
router = APIRouter()
@router.get("/users")
@router.get("/users", response_model=list[UserSchema])
@inject
async def get_user_list(
user_service: UserService = Depends(Provide[Container.user_service]),
):
) -> list[UserSchema]:
return await user_service.get_all_users()

View File

@ -0,0 +1,11 @@
from uuid import UUID
from pydantic import BaseModel
class UserSchema(BaseModel):
id: UUID
name: str
class Config:
from_attributes = True

View File

@ -7,4 +7,6 @@ class UserService:
async def get_all_users(self):
async with self.uow:
await self.uow.users.find_all()
res = await self.uow.users.find_all()
return res

View File

@ -1,23 +1,22 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from uuid import UUID
from sqlalchemy import insert, select
from sqlalchemy.ext.asyncio import AsyncSession
from api.model.base import Base
import api.models as models
ModelType = TypeVar("ModelType", bound=Base)
ModelType = TypeVar("ModelType", bound=models.Base)
class AbstractRepository(ABC):
@abstractmethod
async def add_one(self, data: dict):
raise NotImplementedError
raise NotImplementedError()
@abstractmethod
async def find_all(self):
raise NotImplementedError
raise NotImplementedError()
class SQLAlchemyRepository(AbstractRepository, Generic[ModelType]):
@ -26,12 +25,12 @@ class SQLAlchemyRepository(AbstractRepository, Generic[ModelType]):
def __init__(self, session: AsyncSession):
self.session = session
async def add_one(self, data: dict) -> UUID:
stmt = insert(self.model).values(**data).returning(self.model.id)
async def add_one(self, data: dict) -> ModelType:
stmt = insert(self.model).values(**data)
res = await self.session.execute(stmt)
return res.scalar_one()
async def find_all(self):
async def find_all(self) -> list[ModelType]:
stmt = select(self.model)
res = await self.session.execute(stmt)
res = [row[0].to_read_model() for row in res.all()]