uow and di basic implementation
This commit is contained in:
60
api/di.py
60
api/di.py
@@ -1,44 +1,58 @@
|
||||
import os
|
||||
|
||||
import yaml # type: ignore
|
||||
from dependency_injector import containers, providers
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from api.repository.user import UserRepository
|
||||
from api.service.user import UserService
|
||||
from api.uow.database import Database
|
||||
from api.uow.uow_base import UowBase
|
||||
from api.uow.uow_base import UnitOfWork
|
||||
|
||||
|
||||
class Container(containers.DeclarativeContainer):
|
||||
wiring_config = containers.WiringConfiguration(modules=["api.router.user"])
|
||||
|
||||
config = providers.Configuration(yaml_files=[f"{os.getenv('CONFIG_PATH')}"])
|
||||
if not os.getenv("CONFIG_PATH"):
|
||||
raise ValueError('Please set "CONFIG_PATH" variable in your environment')
|
||||
|
||||
with open(os.getenv("CONFIG_PATH", "")) as f:
|
||||
config_data = yaml.safe_load(f)
|
||||
|
||||
config = providers.Configuration()
|
||||
|
||||
if os.getenv("INDOCKER"):
|
||||
config.db.host.update("db")
|
||||
config_data["db"]["host"] = "db"
|
||||
config_data["db"]["port"] = 5432
|
||||
|
||||
db = providers.Singleton(
|
||||
Database,
|
||||
db_url="postgresql+asyncpg://{}:{}@{}:{}/{}".format(
|
||||
config.db.user,
|
||||
config.db.password,
|
||||
config.db.host,
|
||||
# config.db.port,
|
||||
"5432",
|
||||
config.db.database,
|
||||
async_engine = providers.Factory(
|
||||
create_async_engine,
|
||||
"postgresql+asyncpg://{}:{}@{}:{}/{}".format(
|
||||
config_data["db"]["user"],
|
||||
config_data["db"]["password"],
|
||||
config_data["db"]["host"],
|
||||
config_data["db"]["port"],
|
||||
config_data["db"]["database"],
|
||||
),
|
||||
echo=True,
|
||||
)
|
||||
|
||||
async_session_factory = providers.Factory(
|
||||
async_sessionmaker,
|
||||
async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
uow = providers.Factory(
|
||||
UowBase,
|
||||
session_factory=db.provided.session,
|
||||
UnitOfWork,
|
||||
session_factory=async_session_factory,
|
||||
)
|
||||
|
||||
user_repository = providers.Factory(
|
||||
UserRepository,
|
||||
uow=uow,
|
||||
)
|
||||
|
||||
#
|
||||
# user_repository = providers.Factory(
|
||||
# UserRepository,
|
||||
# uow=uow,
|
||||
# )
|
||||
#
|
||||
user_service = providers.Factory(
|
||||
UserService,
|
||||
user_repository=user_repository,
|
||||
uow=uow,
|
||||
)
|
||||
|
@@ -4,7 +4,6 @@ from alembic import context
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
|
||||
import api.model.user # type: ignore
|
||||
from api.uow.database import Base
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
@@ -19,7 +18,7 @@ if config.config_file_name is not None:
|
||||
# for 'autogenerate' support
|
||||
# from myapp import mymodel
|
||||
# target_metadata = mymodel.Base.metadata
|
||||
target_metadata = Base.metadata
|
||||
target_metadata = api.model.user.Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
# can be acquired:
|
||||
|
Binary file not shown.
Binary file not shown.
@@ -6,7 +6,7 @@ Create Date: 2024-03-04 03:11:36.206211
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
12
api/model/base.py
Normal file
12
api/model/base.py
Normal file
@@ -0,0 +1,12 @@
|
||||
import uuid
|
||||
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
id: Mapped[uuid.UUID] = mapped_column(
|
||||
UUID(as_uuid=True),
|
||||
primary_key=True,
|
||||
default=uuid.uuid4,
|
||||
)
|
@@ -1,12 +1,11 @@
|
||||
from sqlalchemy import Boolean, Column, Integer, String
|
||||
from sqlalchemy import Boolean, Column, String
|
||||
|
||||
from api.uow.database import Base
|
||||
from .base import Base
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "users"
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
email = Column(String, unique=True)
|
||||
hashed_password = Column(String)
|
||||
is_active = Column(Boolean, default=True)
|
||||
|
@@ -1,9 +1,6 @@
|
||||
from api.uow.uow_base import UowBase
|
||||
from api.model.user import User
|
||||
from api.uow.repository import SQLAlchemyRepository
|
||||
|
||||
|
||||
class UserRepository:
|
||||
def __init__(self, uow: UowBase) -> None:
|
||||
self.uow = uow
|
||||
|
||||
async def get_all_users(self):
|
||||
return await self.uow.get_all_users()
|
||||
class UserRepository(SQLAlchemyRepository):
|
||||
model = User
|
||||
|
@@ -1,10 +1,10 @@
|
||||
from api.repository.user import UserRepository
|
||||
from api.uow.uow_base import UowBase
|
||||
from api.uow.uow_base import IUnitOfWork
|
||||
|
||||
|
||||
class UserService:
|
||||
def __init__(self, user_repository: UserRepository) -> None:
|
||||
self.user_repository = user_repository
|
||||
def __init__(self, uow: IUnitOfWork):
|
||||
self.uow = uow
|
||||
|
||||
async def get_all_users(self):
|
||||
return await self.user_repository.get_all_users()
|
||||
async with self.uow:
|
||||
await self.uow.users.find_all()
|
||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,37 +0,0 @@
|
||||
from contextlib import (AbstractContextManager, asynccontextmanager,
|
||||
contextmanager)
|
||||
from typing import Callable
|
||||
|
||||
from sqlalchemy.ext.asyncio import (AsyncSession, async_sessionmaker,
|
||||
create_async_engine)
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self, db_url: str) -> None:
|
||||
self._engine = create_async_engine(db_url, echo=True)
|
||||
self._session_factory = async_sessionmaker(
|
||||
self._engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
@property
|
||||
def session(self):
|
||||
return self._session_factory()
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
await self.session.rollback()
|
||||
await self.session.close()
|
||||
|
||||
async def commit(self):
|
||||
await self.session.commit()
|
||||
|
||||
async def rollback(self):
|
||||
await self.session.rollback()
|
38
api/uow/repository.py
Normal file
38
api/uow/repository.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import insert, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from api.model.base import Base
|
||||
|
||||
ModelType = TypeVar("ModelType", bound=Base)
|
||||
|
||||
|
||||
class AbstractRepository(ABC):
|
||||
@abstractmethod
|
||||
async def add_one(self, data: dict):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
async def find_all(self):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SQLAlchemyRepository(AbstractRepository, Generic[ModelType]):
|
||||
model: type[ModelType]
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def add_one(self, data: dict) -> UUID:
|
||||
stmt = insert(self.model).values(**data).returning(self.model.id)
|
||||
res = await self.session.execute(stmt)
|
||||
return res.scalar_one()
|
||||
|
||||
async def find_all(self):
|
||||
stmt = select(self.model)
|
||||
res = await self.session.execute(stmt)
|
||||
res = [row[0].to_read_model() for row in res.all()]
|
||||
return res
|
@@ -1,23 +1,47 @@
|
||||
from contextlib import AbstractContextManager
|
||||
from typing import Iterable
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from dependency_injector.providers import Callable
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from api.model.user import User
|
||||
from api.repository.user import UserRepository
|
||||
|
||||
|
||||
class UowBase:
|
||||
def __init__(
|
||||
self,
|
||||
session_factory,
|
||||
) -> None:
|
||||
self.session = session_factory
|
||||
class IUnitOfWork(ABC):
|
||||
users: type[UserRepository]
|
||||
|
||||
async def get_all_users(self):
|
||||
async with self.session as s:
|
||||
query = select(User)
|
||||
rr = await s.execute(query)
|
||||
return rr.scalars().all()
|
||||
@abstractmethod
|
||||
def __init__(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def __aenter__(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def __aexit__(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def commit(self):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def rollback(self):
|
||||
...
|
||||
|
||||
|
||||
class UnitOfWork:
|
||||
def __init__(self, session_factory):
|
||||
self.session_factory = session_factory
|
||||
|
||||
async def __aenter__(self):
|
||||
self.session = self.session_factory()
|
||||
|
||||
self.users = UserRepository(self.session)
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
await self.session.rollback()
|
||||
await self.session.close()
|
||||
|
||||
async def commit(self):
|
||||
await self.session.commit()
|
||||
|
||||
async def rollback(self):
|
||||
await self.session.rollback()
|
||||
|
Reference in New Issue
Block a user