diff --git a/test_api/di.py b/test_api/di.py index f6c3ef7..380450d 100644 --- a/test_api/di.py +++ b/test_api/di.py @@ -1,4 +1,3 @@ -from fastapi.security import OAuth2PasswordBearer from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from test_api.services.auth import AuthService @@ -7,8 +6,6 @@ from .config import get_settings from .services.user import UserService from .uow.uow_base import UnitOfWork -Oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") - async_engine = create_async_engine( url=get_settings().db.get_db_url, echo=False, diff --git a/test_api/routers/user.py b/test_api/routers/user.py index b10e713..cd70462 100644 --- a/test_api/routers/user.py +++ b/test_api/routers/user.py @@ -23,7 +23,6 @@ async def get_user_list( user_service: UserService = Depends(get_user_service), user: UserDTO = Depends(get_current_user), ) -> list[UserReadDTO] | str: - print(user) return await user_service.get_all_users() @@ -56,6 +55,8 @@ async def patch_user( user_service: UserService = Depends(get_user_service), user: UserDTO = Depends(get_current_user), ) -> UserReadDTO | dict: + if user_uuid != user.id: + raise HTTPException(401, "No premission") res = await user_service.patch_one(id=user_uuid, data=data) if not isinstance(res, UserReadDTO): raise HTTPException(status_code=400, detail=res) diff --git a/test_api/services/user.py b/test_api/services/user.py index df1c025..792987d 100644 --- a/test_api/services/user.py +++ b/test_api/services/user.py @@ -1,7 +1,6 @@ from uuid import UUID from passlib.context import CryptContext -from sqlalchemy.exc import IntegrityError from ..schemas.user_schema import UserDBDTO, UserWriteDTO from ..uow.uow_base import UnitOfWork @@ -14,41 +13,21 @@ class UserService: async def get_all_users(self): async with self.uow: - try: - res = await self.uow.users.find_all() - except IntegrityError as e: - await self.uow.rollback() - res = e._message() - else: - await self.uow.commit() - finally: - return res + res = await self.uow.users.find_all() + return res async def add_one(self, data: UserWriteDTO): new_data = data.model_dump() new_data["hashed_password"] = self.crypto_context.hash(new_data.pop("password")) dataf: UserDBDTO = UserDBDTO(**new_data) async with self.uow: - try: - res = await self.uow.users.add_one(data=dataf) - - except IntegrityError as e: - await self.uow.rollback() - res = e._message() - else: - await self.uow.commit() - finally: - return res + res = await self.uow.users.add_one(data=dataf) + return res async def get_user(self, id: UUID): async with self.uow: - try: - res = await self.uow.users.find_one(filter={"id": id}) - except IntegrityError as e: - await self.uow.rollback() - res = e._message() - finally: - return res + res = await self.uow.users.find_one(filter={"id": id}) + return res async def patch_one(self, id: UUID, data: UserWriteDTO): new_data = data.model_dump() @@ -56,24 +35,10 @@ class UserService: dataf = UserDBDTO(**new_data) async with self.uow: - try: - res = await self.uow.users.patch_one(filter={"id": id}, data=dataf) - except IntegrityError as e: - await self.uow.rollback() - res = e._message() - else: - await self.uow.commit() - finally: - return res + res = await self.uow.users.patch_one(filter={"id": id}, data=dataf) + return res async def delete_user(self, id: UUID) -> None | str: async with self.uow: - try: - res = await self.uow.users.delete_one(filter={"id": id}) - except IntegrityError as e: - await self.uow.rollback() - res = e._message() - else: - await self.uow.commit() - finally: - return res + res = await self.uow.users.delete_one(filter={"id": id}) + return res diff --git a/test_api/uow/uow_base.py b/test_api/uow/uow_base.py index 32d9418..72b24ee 100644 --- a/test_api/uow/uow_base.py +++ b/test_api/uow/uow_base.py @@ -1,16 +1,29 @@ +from types import TracebackType + +from fastapi import HTTPException + from ..repositories import UserRepository class UnitOfWork: def __init__(self, session_factory): - self.session_factory = session_factory - - async def __aenter__(self): - self.session = self.session_factory() + self.session = session_factory() self.users = UserRepository(self.session) - async def __aexit__(self, *args): - await self.session.close() + async def __aenter__(self): + return self + + async def __aexit__( + self, + exc_type: type[BaseException], + exc_val: BaseException, + exc_tb: TracebackType, + ) -> None: + if exc_type: + await self.rollback() + raise HTTPException(400, "Very Bad Request") + else: + await self.commit() async def commit(self): await self.session.commit()