From 0d9d69d206cbc403305a85649750906360905bcb Mon Sep 17 00:00:00 2001 From: pi3c Date: Fri, 22 Mar 2024 07:34:33 +0300 Subject: [PATCH] sync --- test_api/app.py | 2 +- test_api/models/refer_model.py | 14 +++++++ test_api/repositories/user.py | 23 +----------- test_api/routers/auth.py | 22 +++++++++-- test_api/routers/user.py | 67 ++++++---------------------------- test_api/schemas/__init__.py | 2 + test_api/schemas/ref.py | 5 +++ test_api/services/auth.py | 14 +++---- test_api/services/user.py | 28 +++----------- 9 files changed, 64 insertions(+), 113 deletions(-) create mode 100644 test_api/models/refer_model.py create mode 100644 test_api/schemas/ref.py diff --git a/test_api/app.py b/test_api/app.py index 11abba3..e03c132 100644 --- a/test_api/app.py +++ b/test_api/app.py @@ -15,6 +15,6 @@ async def lifespan(app: FastAPI): def create_app() -> FastAPI: app = FastAPI(lifespan=lifespan) - app.include_router(user_router) app.include_router(auth_router) + app.include_router(user_router) return app diff --git a/test_api/models/refer_model.py b/test_api/models/refer_model.py new file mode 100644 index 0000000..482f864 --- /dev/null +++ b/test_api/models/refer_model.py @@ -0,0 +1,14 @@ +from uuid import UUID + +from sqlalchemy import ForeignKey +from sqlalchemy.orm import Mapped, mapped_column + +from . import Base + + +class ReferModel(Base): + __tablename__ = "referals" + + owner: Mapped[UUID] = mapped_column(ForeignKey("user.id")) + email: Mapped[str] = mapped_column(unique=True) + hashed_password: Mapped[str] diff --git a/test_api/repositories/user.py b/test_api/repositories/user.py index c888e90..af7f259 100644 --- a/test_api/repositories/user.py +++ b/test_api/repositories/user.py @@ -1,4 +1,4 @@ -from sqlalchemy import delete, insert, select, update +from sqlalchemy import insert, select from sqlalchemy.ext.asyncio.session import AsyncSession from ..models import UserModel @@ -17,32 +17,11 @@ class UserRepository: res = await self.session.execute(stmt) return UserReadDTO.model_validate(res.scalar_one()) - async def find_all(self) -> list[UserReadDTO]: - stmt = select(UserModel) - res = await self.session.execute(stmt) - res = [UserReadDTO.model_validate(row) for row in res.scalars().all()] - return res - async def find_one(self, filter: dict) -> UserReadDTO | None: stmt = select(UserModel).filter_by(**filter) res = await self.session.execute(stmt) res = res.scalar_one_or_none() - print(res) if res is not None: return UserReadDTO.model_validate(res) return None - - async def patch_one(self, filter: dict, data: UserDBDTO) -> UserReadDTO | None: - stmt = update(UserModel).where(UserModel.id == filter["id"]).values(**data.model_dump()).returning(UserModel) - res = await self.session.execute(stmt) - res = res.scalar_one_or_none() - - if res is not None: - return UserReadDTO.model_validate(res) - return None - - async def delete_one(self, filter: dict) -> None: - stmt = delete(UserModel).filter_by(**filter) - await self.session.execute(stmt) - return None diff --git a/test_api/routers/auth.py b/test_api/routers/auth.py index 9481950..9885cfe 100644 --- a/test_api/routers/auth.py +++ b/test_api/routers/auth.py @@ -1,17 +1,31 @@ -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, HTTPException from fastapi.security import OAuth2PasswordRequestForm from test_api.schemas.token import TokenSchema +from test_api.schemas.user_schema import UserReadDTO, UserWriteDTO from test_api.services.auth import AuthService +from test_api.services.user import UserService -from ..di import get_auth_service +from ..di import get_auth_service, get_user_service -router = APIRouter(prefix="/token") +router = APIRouter() -@router.post("", response_model=TokenSchema) +@router.post("/token", response_model=TokenSchema, tags=["auth"]) async def authenticate( login: OAuth2PasswordRequestForm = Depends(), auth: AuthService = Depends(get_auth_service), ) -> TokenSchema | None: return await auth.authenticate(login) + + +@router.post("/register", response_model=UserReadDTO, tags=["auth"], status_code=201) +async def register( + data: UserWriteDTO, + user_service: UserService = Depends(get_user_service), +) -> UserReadDTO: + res = await user_service.add_one(data) + if not isinstance(res, UserReadDTO): + raise HTTPException(status_code=400, detail=res) + + return res diff --git a/test_api/routers/user.py b/test_api/routers/user.py index cd70462..c58f229 100644 --- a/test_api/routers/user.py +++ b/test_api/routers/user.py @@ -1,72 +1,27 @@ -from uuid import UUID - -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, Depends from test_api.services.auth import get_current_user from ..di import get_user_service -from ..schemas import UserDTO, UserReadDTO -from ..schemas.user_schema import UserWriteDTO +from ..schemas import RefDTO, UserDTO, UserReadDTO from ..services import UserService router = APIRouter( - prefix="/api/v1", + prefix="/user", ) -@router.get( - "/users", - response_model=list[UserReadDTO], - tags=["users"], -) -async def get_user_list( - user_service: UserService = Depends(get_user_service), - user: UserDTO = Depends(get_current_user), -) -> list[UserReadDTO] | str: - return await user_service.get_all_users() - - -@router.post("/users", response_model=UserReadDTO, tags=["user"], status_code=201) -async def add_user( - data: UserWriteDTO, - user_service: UserService = Depends(get_user_service), -) -> UserReadDTO: - res = await user_service.add_one(data) - if not isinstance(res, UserReadDTO): - raise HTTPException(status_code=400, detail=res) - return res - - -@router.get("/user/{user_uuid}", response_model=UserReadDTO, tags=["user"]) +@router.get("/me", response_model=UserReadDTO, tags=["user"]) async def get_user( - user_uuid: UUID, - user_service: UserService = Depends(get_user_service), -) -> UserReadDTO | dict: - res = await user_service.get_user(id=user_uuid) - if not isinstance(res, UserReadDTO): - raise HTTPException(status_code=400, detail=res) - return res - - -@router.patch("/user/{user_uuid}", response_model=UserReadDTO, tags=["user"]) -async def patch_user( - user_uuid: UUID, - data: UserWriteDTO, user_service: UserService = Depends(get_user_service), user: UserDTO = Depends(get_current_user), -) -> UserReadDTO | dict: - if user_uuid != user.id: - raise HTTPException(401, "No premission") - res = await user_service.patch_one(id=user_uuid, data=data) - if not isinstance(res, UserReadDTO): - raise HTTPException(status_code=400, detail=res) - return res +) -> UserReadDTO | None: + return await user_service.get_user(email=user.email) -@router.delete("/user/{user_uuid}", status_code=200, tags=["user"]) -async def delete_user( - user_uuid: UUID, +@router.get("/me/create_ref", status_code=201, response_model=RefDTO) +async def create_refer_code( user_service: UserService = Depends(get_user_service), -) -> None: - await user_service.delete_user(id=user_uuid) - return None + user: UserDTO = Depends(get_current_user), +) -> RefDTO | None: + return await user_service.create_refer_code(email=user.email) diff --git a/test_api/schemas/__init__.py b/test_api/schemas/__init__.py index 1eb998c..59fa2f7 100644 --- a/test_api/schemas/__init__.py +++ b/test_api/schemas/__init__.py @@ -1,4 +1,5 @@ from .base_schema import BaseDTO, ReadDTO, WriteDTO +from .ref import RefDTO from .user_schema import UserDTO, UserReadDTO, UserWriteDTO __all__ = ( @@ -8,4 +9,5 @@ __all__ = ( "UserDTO", "UserWriteDTO", "UserReadDTO", + "RefDTO", ) diff --git a/test_api/schemas/ref.py b/test_api/schemas/ref.py new file mode 100644 index 0000000..c509071 --- /dev/null +++ b/test_api/schemas/ref.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel + + +class RefDTO(BaseModel): + refer_code: str diff --git a/test_api/services/auth.py b/test_api/services/auth.py index 9594bc0..1abc6a4 100644 --- a/test_api/services/auth.py +++ b/test_api/services/auth.py @@ -21,14 +21,14 @@ async def get_current_user(token: str = Depends(oauth2_schema)): try: payload = jwt.decode(token, "fsgddfsgdfgs", algorithms=["HS256"]) - name: str | None = payload.get("name") - sub: str | None = payload.get("sub") - expires_at: str | None = payload.get("expires_at") + name: str = payload.get("name", "") + sub: str = payload.get("sub", "") + expires_at: str = payload.get("expires_at", "") - if sub is None: + if not sub: raise HTTPException(401, "Invalid credentials") - if expires_at is not None: + if expires_at: if is_expired(expires_at): raise HTTPException(401, "Invalid credentials") @@ -49,12 +49,10 @@ class AuthService: self.crypto_context = CryptContext(schemes="bcrypt") async def authenticate(self, login: OAuth2PasswordRequestForm = Depends()) -> TokenSchema | None: - print(login) - async with self.uow: user = await self.uow.users.find_one(filter={"email": login.username}) - if user.hashed_password is None: + if user is None or user.hashed_password is None: raise HTTPException(401, "Incorrect password") else: if not self.crypto_context.verify(login.password, user.hashed_password): diff --git a/test_api/services/user.py b/test_api/services/user.py index 792987d..ff957c0 100644 --- a/test_api/services/user.py +++ b/test_api/services/user.py @@ -1,5 +1,3 @@ -from uuid import UUID - from passlib.context import CryptContext from ..schemas.user_schema import UserDBDTO, UserWriteDTO @@ -11,11 +9,6 @@ class UserService: self.uow = uow self.crypto_context = CryptContext(schemes="bcrypt") - async def get_all_users(self): - async with self.uow: - res = await self.uow.users.find_all() - return res - async def add_one(self, data: UserWriteDTO): new_data = data.model_dump() new_data["hashed_password"] = self.crypto_context.hash(new_data.pop("password")) @@ -24,21 +17,12 @@ class UserService: res = await self.uow.users.add_one(data=dataf) return res - async def get_user(self, id: UUID): + async def get_user(self, email: str): async with self.uow: - res = await self.uow.users.find_one(filter={"id": id}) - return res - - async def patch_one(self, id: UUID, data: UserWriteDTO): - new_data = data.model_dump() - new_data["hashed_password"] = self.crypto_context.hash(new_data.pop("password")) - dataf = UserDBDTO(**new_data) + user = await self.uow.users.find_one(filter={"email": email}) + return user + async def create_refer_code(self, email: str): async with self.uow: - res = await self.uow.users.patch_one(filter={"id": id}, data=dataf) - return res - - async def delete_user(self, id: UUID) -> None | str: - async with self.uow: - res = await self.uow.users.delete_one(filter={"id": id}) - return res + user = await self.uow.users.find_one(filter={"email": email}) + print(user)