From 665d4802f630c20a72da9a2f07ef9c8e81577a6b Mon Sep 17 00:00:00 2001 From: pi3c Date: Thu, 21 Mar 2024 01:52:55 +0300 Subject: [PATCH] bearer auth --- poetry.lock | 16 ++++++- pyproject.toml | 1 + test_api/app.py | 2 + test_api/di.py | 10 +++++ test_api/repositories/user.py | 4 +- test_api/routers/auth.py | 17 ++++++++ test_api/routers/user.py | 7 ++- test_api/schemas/__init__.py | 3 +- test_api/schemas/token.py | 6 +++ test_api/services/auth.py | 82 +++++++++++++++++++++++++++++++++++ test_api/services/user.py | 12 +++-- 11 files changed, 148 insertions(+), 12 deletions(-) create mode 100644 test_api/routers/auth.py create mode 100644 test_api/schemas/token.py create mode 100644 test_api/services/auth.py diff --git a/poetry.lock b/poetry.lock index 6114747..5eae5ac 100644 --- a/poetry.lock +++ b/poetry.lock @@ -896,6 +896,20 @@ cryptography = ["cryptography (>=3.4.0)"] pycrypto = ["pyasn1", "pycrypto (>=2.6.0,<2.7.0)"] pycryptodome = ["pyasn1", "pycryptodome (>=3.3.1,<4.0.0)"] +[[package]] +name = "python-multipart" +version = "0.0.9" +description = "A streaming multipart parser for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "python_multipart-0.0.9-py3-none-any.whl", hash = "sha256:97ca7b8ea7b05f977dc3849c3ba99d51689822fab725c3703af7c866a0c2b215"}, + {file = "python_multipart-0.0.9.tar.gz", hash = "sha256:03f54688c663f1b7977105f021043b0793151e4cb1c1a9d4a11fc13d622c4026"}, +] + +[package.extras] +dev = ["atomicwrites (==1.4.1)", "attrs (==23.2.0)", "coverage (==7.4.1)", "hatch", "invoke (==2.2.0)", "more-itertools (==10.2.0)", "pbr (==6.0.0)", "pluggy (==1.4.0)", "py (==1.11.0)", "pytest (==8.0.0)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.2.0)", "pyyaml (==6.0.1)", "ruff (==0.2.1)"] + [[package]] name = "pyyaml" version = "6.0.1" @@ -1164,4 +1178,4 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "c02193dcfb120c410c23aaea101e98c092a1071e833c2196540f86797e6f3baa" +content-hash = "96e2371cfbd3ed8dab04ad2090bf982861dad2c1916834d22c835f7468c4c8c1" diff --git a/pyproject.toml b/pyproject.toml index e3b2dee..e019c78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ asyncpg = "^0.29.0" passlib = "^1.7.4" pyjwt = "^2.8.0" python-jose = {extras = ["cryptography"], version = "^3.3.0"} +python-multipart = "^0.0.9" [tool.poetry.scripts] api = "test_api.run:run_server" diff --git a/test_api/app.py b/test_api/app.py index 3d1ae36..11abba3 100644 --- a/test_api/app.py +++ b/test_api/app.py @@ -2,6 +2,7 @@ from contextlib import asynccontextmanager from fastapi import FastAPI +from .routers.auth import router as auth_router from .routers.user import router as user_router @@ -15,4 +16,5 @@ async def lifespan(app: FastAPI): def create_app() -> FastAPI: app = FastAPI(lifespan=lifespan) app.include_router(user_router) + app.include_router(auth_router) return app diff --git a/test_api/di.py b/test_api/di.py index 3e2edbc..f6c3ef7 100644 --- a/test_api/di.py +++ b/test_api/di.py @@ -1,6 +1,8 @@ from fastapi.security import OAuth2PasswordBearer from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from test_api.services.auth import AuthService + from .config import get_settings from .services.user import UserService from .uow.uow_base import UnitOfWork @@ -26,6 +28,14 @@ user_service = UserService( uow=uow, ) +auth_service = AuthService( + uow=uow, +) + def get_user_service(): return user_service + + +def get_auth_service(): + return auth_service diff --git a/test_api/repositories/user.py b/test_api/repositories/user.py index e5cbb3a..c888e90 100644 --- a/test_api/repositories/user.py +++ b/test_api/repositories/user.py @@ -2,7 +2,7 @@ from sqlalchemy import delete, insert, select, update from sqlalchemy.ext.asyncio.session import AsyncSession from ..models import UserModel -from ..schemas.user_schema import UserDBDTO, UserReadDTO, UserWriteDTO +from ..schemas.user_schema import UserDBDTO, UserReadDTO class UserRepository: @@ -33,7 +33,7 @@ class UserRepository: return UserReadDTO.model_validate(res) return None - async def patch_one(self, filter: dict, data: UserWriteDTO) -> UserReadDTO | None: + async def patch_one(self, filter: dict, data: UserDBDTO) -> UserReadDTO | None: stmt = update(UserModel).where(UserModel.id == filter["id"]).values(**data.model_dump()).returning(UserModel) res = await self.session.execute(stmt) res = res.scalar_one_or_none() diff --git a/test_api/routers/auth.py b/test_api/routers/auth.py new file mode 100644 index 0000000..9481950 --- /dev/null +++ b/test_api/routers/auth.py @@ -0,0 +1,17 @@ +from fastapi import APIRouter, Depends +from fastapi.security import OAuth2PasswordRequestForm + +from test_api.schemas.token import TokenSchema +from test_api.services.auth import AuthService + +from ..di import get_auth_service + +router = APIRouter(prefix="/token") + + +@router.post("", response_model=TokenSchema) +async def authenticate( + login: OAuth2PasswordRequestForm = Depends(), + auth: AuthService = Depends(get_auth_service), +) -> TokenSchema | None: + return await auth.authenticate(login) diff --git a/test_api/routers/user.py b/test_api/routers/user.py index bded8f8..b10e713 100644 --- a/test_api/routers/user.py +++ b/test_api/routers/user.py @@ -2,8 +2,10 @@ from uuid import UUID from fastapi import APIRouter, Depends, HTTPException +from test_api.services.auth import get_current_user + from ..di import get_user_service -from ..schemas import UserReadDTO +from ..schemas import UserDTO, UserReadDTO from ..schemas.user_schema import UserWriteDTO from ..services import UserService @@ -19,7 +21,9 @@ router = APIRouter( ) async def get_user_list( user_service: UserService = Depends(get_user_service), + user: UserDTO = Depends(get_current_user), ) -> list[UserReadDTO] | str: + print(user) return await user_service.get_all_users() @@ -50,6 +54,7 @@ async def patch_user( user_uuid: UUID, data: UserWriteDTO, user_service: UserService = Depends(get_user_service), + user: UserDTO = Depends(get_current_user), ) -> UserReadDTO | dict: res = await user_service.patch_one(id=user_uuid, data=data) if not isinstance(res, UserReadDTO): diff --git a/test_api/schemas/__init__.py b/test_api/schemas/__init__.py index 4b99d40..1eb998c 100644 --- a/test_api/schemas/__init__.py +++ b/test_api/schemas/__init__.py @@ -1,10 +1,11 @@ from .base_schema import BaseDTO, ReadDTO, WriteDTO -from .user_schema import UserReadDTO, UserWriteDTO +from .user_schema import UserDTO, UserReadDTO, UserWriteDTO __all__ = ( "BaseDTO", "WriteDTO", "ReadDTO", + "UserDTO", "UserWriteDTO", "UserReadDTO", ) diff --git a/test_api/schemas/token.py b/test_api/schemas/token.py new file mode 100644 index 0000000..7ea6f71 --- /dev/null +++ b/test_api/schemas/token.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class TokenSchema(BaseModel): + access_token: str + token_type: str diff --git a/test_api/services/auth.py b/test_api/services/auth.py new file mode 100644 index 0000000..9594bc0 --- /dev/null +++ b/test_api/services/auth.py @@ -0,0 +1,82 @@ +from datetime import datetime, timedelta + +from fastapi import Depends, HTTPException +from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm +from jose import JWTError, jwt +from passlib.context import CryptContext + +from test_api.schemas.token import TokenSchema +from test_api.schemas.user_schema import UserDTO +from test_api.uow.uow_base import UnitOfWork + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") + +oauth2_schema = OAuth2PasswordBearer(tokenUrl="/token") + + +async def get_current_user(token: str = Depends(oauth2_schema)): + if token is None: + raise HTTPException(401, "Invalid credentials") + + try: + payload = jwt.decode(token, "fsgddfsgdfgs", algorithms=["HS256"]) + + name: str | None = payload.get("name") + sub: str | None = payload.get("sub") + expires_at: str | None = payload.get("expires_at") + + if sub is None: + raise HTTPException(401, "Invalid credentials") + + if expires_at is not None: + if is_expired(expires_at): + raise HTTPException(401, "Invalid credentials") + + return UserDTO(name=name, email=sub) + except JWTError: + raise HTTPException(401, "Invalid credentials") + + +def is_expired(expires_at: str) -> bool: + """Return :obj:`True` if token has expired.""" + + return datetime.strptime(expires_at, "%Y-%m-%d %H:%M:%S") < datetime.utcnow() + + +class AuthService: + def __init__(self, uow: UnitOfWork) -> None: + self.uow = uow + self.crypto_context = CryptContext(schemes="bcrypt") + + async def authenticate(self, login: OAuth2PasswordRequestForm = Depends()) -> TokenSchema | None: + print(login) + + async with self.uow: + user = await self.uow.users.find_one(filter={"email": login.username}) + + if user.hashed_password is None: + raise HTTPException(401, "Incorrect password") + else: + if not self.crypto_context.verify(login.password, user.hashed_password): + raise HTTPException(401, "Incorrect password") + else: + access_token = self._create_access_token(user.name, user.email) + return TokenSchema(access_token=access_token, token_type="bearer") + + def _create_access_token(self, name: str, email: str) -> str: + """Encode user information and expiration time.""" + + payload = { + "name": name, + "sub": email, + "expires_at": self._expiration_time(), + } + + return jwt.encode(payload, "fsgddfsgdfgs", algorithm="HS256") + + @staticmethod + def _expiration_time() -> str: + """Get token expiration time.""" + + expires_at = datetime.utcnow() + timedelta(minutes=30) + return expires_at.strftime("%Y-%m-%d %H:%M:%S") diff --git a/test_api/services/user.py b/test_api/services/user.py index 5161d02..df1c025 100644 --- a/test_api/services/user.py +++ b/test_api/services/user.py @@ -26,9 +26,7 @@ class UserService: async def add_one(self, data: UserWriteDTO): new_data = data.model_dump() - print(new_data) new_data["hashed_password"] = self.crypto_context.hash(new_data.pop("password")) - print(new_data) dataf: UserDBDTO = UserDBDTO(**new_data) async with self.uow: try: @@ -40,8 +38,6 @@ class UserService: else: await self.uow.commit() finally: - print(res) - return res async def get_user(self, id: UUID): @@ -51,15 +47,17 @@ class UserService: except IntegrityError as e: await self.uow.rollback() res = e._message() - else: - await self.uow.commit() finally: return res async def patch_one(self, id: UUID, data: UserWriteDTO): + new_data = data.model_dump() + new_data["hashed_password"] = self.crypto_context.hash(new_data.pop("password")) + dataf = UserDBDTO(**new_data) + async with self.uow: try: - res = await self.uow.users.patch_one(filter={"id": id}, data=data) + res = await self.uow.users.patch_one(filter={"id": id}, data=dataf) except IntegrityError as e: await self.uow.rollback() res = e._message()