93 lines
2.4 KiB
Python
93 lines
2.4 KiB
Python
|
from collections.abc import AsyncIterable, Callable
|
||
|
from typing import Annotated
|
||
|
|
||
|
from fastapi import Depends
|
||
|
from sqlalchemy.ext.asyncio import (
|
||
|
AsyncEngine,
|
||
|
AsyncSession,
|
||
|
async_sessionmaker,
|
||
|
create_async_engine,
|
||
|
)
|
||
|
|
||
|
from api.config import get_settings
|
||
|
|
||
|
|
||
|
class Stub:
    """Placeholder dependency meant to be replaced via dependency overrides.

    This class is used to prevent FastAPI from digging into the attributes
    of real dependencies and detecting them as request data.

    So instead of
        `interactor: Annotated[Interactor, Depends()]`
    write
        `interactor: Annotated[Interactor, Depends(Stub(Interactor))]`

    And then you can declare how to create it:
        `app.dependency_overrides[Interactor] = some_real_factory`

    A stub created without keyword arguments compares equal to, and hashes
    the same as, the dependency it wraps, so a dict keyed by the real
    dependency is matched by the stub. A stub created with keyword
    arguments only equals another stub with the same dependency and the
    same keyword arguments.
    """

    def __init__(self, dependency: Callable, **kwargs):
        """Remember the wrapped dependency and any distinguishing kwargs.

        Keyword argument values must be hashable for the stub itself to be
        hashable.
        """
        self._dependency = dependency
        self._kwargs = kwargs

    def __call__(self):
        """Always raise: a stub has to be overridden, never invoked.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError

    def __eq__(self, other) -> bool:
        """Compare against another stub or against a bare dependency."""
        if isinstance(other, Stub):
            return (
                self._dependency == other._dependency
                and self._kwargs == other._kwargs
            )
        # Only a kwargs-free stub may stand in for the raw dependency.
        if not self._kwargs:
            return self._dependency == other
        return False

    def __hash__(self) -> int:
        """Return a hash consistent with ``__eq__``."""
        if not self._kwargs:
            # Must match hash(dependency) so lookups keyed by the real
            # dependency find this stub.
            return hash(self._dependency)
        # Dict equality ignores insertion order, so the hash must too:
        # a frozenset of the items is order-independent, unlike a tuple.
        return hash((self._dependency, frozenset(self._kwargs.items())))
|
||
|
|
||
|
|
||
|
# def new_user_repository(
|
||
|
# session: Annotated[AsyncSession, Depends(Stub(AsyncSession))],
|
||
|
# ) -> UserRepository:
|
||
|
# return SqlalchemyUserRepository(session)
|
||
|
#
|
||
|
#
|
||
|
# def new_unit_of_work(
|
||
|
# session: Annotated[AsyncSession, Depends(Stub(AsyncSession))],
|
||
|
# ) -> UnitOfWork:
|
||
|
# return SqlalchemyUnitOfWork(session)
|
||
|
#
|
||
|
|
||
|
|
||
|
def create_engine() -> AsyncEngine:
    """Create an async engine for the database URL found in the settings."""
    db_url = get_settings().db.db_url
    return create_async_engine(url=db_url)
|
||
|
|
||
|
|
||
|
def create_session_maker(
    engine: Annotated[AsyncEngine, Depends(Stub(AsyncEngine))],
) -> async_sessionmaker[AsyncSession]:
    """Build a session factory on top of the injected engine.

    The factory is configured with ``expire_on_commit=False``.
    """
    session_factory = async_sessionmaker(engine, expire_on_commit=False)
    return session_factory
|
||
|
|
||
|
|
||
|
async def new_session(
    session_maker: Annotated[
        async_sessionmaker[AsyncSession],
        Depends(Stub(async_sessionmaker[AsyncSession])),
    ],
) -> AsyncIterable[AsyncSession]:
    """Yield a single session produced by the injected session factory.

    The session is used as an async context manager, so its context is
    exited once the consumer of this generator is finished with it.
    """
    session_context = session_maker()
    async with session_context as db_session:
        yield db_session
|
||
|
|
||
|
|
||
|
# def new_user_service(
|
||
|
# uow: Annotated[UnitOfWork, Depends()],
|
||
|
# user_repository: Annotated[UserRepository, Depends()],
|
||
|
# ) -> UserService:
|
||
|
# return UserService(uow, user_repository)
|