bearer auth

main
Сергей Ванюшкин 2024-03-21 01:52:55 +03:00
parent 30de12fba7
commit 665d4802f6
11 changed files with 148 additions and 12 deletions

16
poetry.lock generated
View File

@ -896,6 +896,20 @@ cryptography = ["cryptography (>=3.4.0)"]
pycrypto = ["pyasn1", "pycrypto (>=2.6.0,<2.7.0)"] pycrypto = ["pyasn1", "pycrypto (>=2.6.0,<2.7.0)"]
pycryptodome = ["pyasn1", "pycryptodome (>=3.3.1,<4.0.0)"] pycryptodome = ["pyasn1", "pycryptodome (>=3.3.1,<4.0.0)"]
[[package]]
name = "python-multipart"
version = "0.0.9"
description = "A streaming multipart parser for Python"
optional = false
python-versions = ">=3.8"
files = [
{file = "python_multipart-0.0.9-py3-none-any.whl", hash = "sha256:97ca7b8ea7b05f977dc3849c3ba99d51689822fab725c3703af7c866a0c2b215"},
{file = "python_multipart-0.0.9.tar.gz", hash = "sha256:03f54688c663f1b7977105f021043b0793151e4cb1c1a9d4a11fc13d622c4026"},
]
[package.extras]
dev = ["atomicwrites (==1.4.1)", "attrs (==23.2.0)", "coverage (==7.4.1)", "hatch", "invoke (==2.2.0)", "more-itertools (==10.2.0)", "pbr (==6.0.0)", "pluggy (==1.4.0)", "py (==1.11.0)", "pytest (==8.0.0)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.2.0)", "pyyaml (==6.0.1)", "ruff (==0.2.1)"]
[[package]] [[package]]
name = "pyyaml" name = "pyyaml"
version = "6.0.1" version = "6.0.1"
@ -1164,4 +1178,4 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess
[metadata] [metadata]
lock-version = "2.0" lock-version = "2.0"
python-versions = "^3.10" python-versions = "^3.10"
content-hash = "c02193dcfb120c410c23aaea101e98c092a1071e833c2196540f86797e6f3baa" content-hash = "96e2371cfbd3ed8dab04ad2090bf982861dad2c1916834d22c835f7468c4c8c1"

View File

@ -12,6 +12,7 @@ asyncpg = "^0.29.0"
passlib = "^1.7.4" passlib = "^1.7.4"
pyjwt = "^2.8.0" pyjwt = "^2.8.0"
python-jose = {extras = ["cryptography"], version = "^3.3.0"} python-jose = {extras = ["cryptography"], version = "^3.3.0"}
python-multipart = "^0.0.9"
[tool.poetry.scripts] [tool.poetry.scripts]
api = "test_api.run:run_server" api = "test_api.run:run_server"

View File

@ -2,6 +2,7 @@ from contextlib import asynccontextmanager
from fastapi import FastAPI from fastapi import FastAPI
from .routers.auth import router as auth_router
from .routers.user import router as user_router from .routers.user import router as user_router
@ -15,4 +16,5 @@ async def lifespan(app: FastAPI):
def create_app() -> FastAPI: def create_app() -> FastAPI:
app = FastAPI(lifespan=lifespan) app = FastAPI(lifespan=lifespan)
app.include_router(user_router) app.include_router(user_router)
app.include_router(auth_router)
return app return app

View File

@ -1,6 +1,8 @@
from fastapi.security import OAuth2PasswordBearer from fastapi.security import OAuth2PasswordBearer
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from test_api.services.auth import AuthService
from .config import get_settings from .config import get_settings
from .services.user import UserService from .services.user import UserService
from .uow.uow_base import UnitOfWork from .uow.uow_base import UnitOfWork
@ -26,6 +28,14 @@ user_service = UserService(
uow=uow, uow=uow,
) )
auth_service = AuthService(
uow=uow,
)
def get_user_service(): def get_user_service():
return user_service return user_service
def get_auth_service():
return auth_service

View File

@ -2,7 +2,7 @@ from sqlalchemy import delete, insert, select, update
from sqlalchemy.ext.asyncio.session import AsyncSession from sqlalchemy.ext.asyncio.session import AsyncSession
from ..models import UserModel from ..models import UserModel
from ..schemas.user_schema import UserDBDTO, UserReadDTO, UserWriteDTO from ..schemas.user_schema import UserDBDTO, UserReadDTO
class UserRepository: class UserRepository:
@ -33,7 +33,7 @@ class UserRepository:
return UserReadDTO.model_validate(res) return UserReadDTO.model_validate(res)
return None return None
async def patch_one(self, filter: dict, data: UserWriteDTO) -> UserReadDTO | None: async def patch_one(self, filter: dict, data: UserDBDTO) -> UserReadDTO | None:
stmt = update(UserModel).where(UserModel.id == filter["id"]).values(**data.model_dump()).returning(UserModel) stmt = update(UserModel).where(UserModel.id == filter["id"]).values(**data.model_dump()).returning(UserModel)
res = await self.session.execute(stmt) res = await self.session.execute(stmt)
res = res.scalar_one_or_none() res = res.scalar_one_or_none()

17
test_api/routers/auth.py Normal file
View File

@ -0,0 +1,17 @@
from fastapi import APIRouter, Depends
from fastapi.security import OAuth2PasswordRequestForm
from test_api.schemas.token import TokenSchema
from test_api.services.auth import AuthService
from ..di import get_auth_service
router = APIRouter(prefix="/token")
@router.post("", response_model=TokenSchema)
async def authenticate(
login: OAuth2PasswordRequestForm = Depends(),
auth: AuthService = Depends(get_auth_service),
) -> TokenSchema | None:
return await auth.authenticate(login)

View File

@ -2,8 +2,10 @@ from uuid import UUID
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from test_api.services.auth import get_current_user
from ..di import get_user_service from ..di import get_user_service
from ..schemas import UserReadDTO from ..schemas import UserDTO, UserReadDTO
from ..schemas.user_schema import UserWriteDTO from ..schemas.user_schema import UserWriteDTO
from ..services import UserService from ..services import UserService
@ -19,7 +21,9 @@ router = APIRouter(
) )
async def get_user_list( async def get_user_list(
user_service: UserService = Depends(get_user_service), user_service: UserService = Depends(get_user_service),
user: UserDTO = Depends(get_current_user),
) -> list[UserReadDTO] | str: ) -> list[UserReadDTO] | str:
print(user)
return await user_service.get_all_users() return await user_service.get_all_users()
@ -50,6 +54,7 @@ async def patch_user(
user_uuid: UUID, user_uuid: UUID,
data: UserWriteDTO, data: UserWriteDTO,
user_service: UserService = Depends(get_user_service), user_service: UserService = Depends(get_user_service),
user: UserDTO = Depends(get_current_user),
) -> UserReadDTO | dict: ) -> UserReadDTO | dict:
res = await user_service.patch_one(id=user_uuid, data=data) res = await user_service.patch_one(id=user_uuid, data=data)
if not isinstance(res, UserReadDTO): if not isinstance(res, UserReadDTO):

View File

@ -1,10 +1,11 @@
from .base_schema import BaseDTO, ReadDTO, WriteDTO from .base_schema import BaseDTO, ReadDTO, WriteDTO
from .user_schema import UserReadDTO, UserWriteDTO from .user_schema import UserDTO, UserReadDTO, UserWriteDTO
__all__ = ( __all__ = (
"BaseDTO", "BaseDTO",
"WriteDTO", "WriteDTO",
"ReadDTO", "ReadDTO",
"UserDTO",
"UserWriteDTO", "UserWriteDTO",
"UserReadDTO", "UserReadDTO",
) )

View File

@ -0,0 +1,6 @@
from pydantic import BaseModel
class TokenSchema(BaseModel):
access_token: str
token_type: str

82
test_api/services/auth.py Normal file
View File

@ -0,0 +1,82 @@
from datetime import datetime, timedelta
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
from test_api.schemas.token import TokenSchema
from test_api.schemas.user_schema import UserDTO
from test_api.uow.uow_base import UnitOfWork
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_schema = OAuth2PasswordBearer(tokenUrl="/token")
async def get_current_user(token: str = Depends(oauth2_schema)):
if token is None:
raise HTTPException(401, "Invalid credentials")
try:
payload = jwt.decode(token, "fsgddfsgdfgs", algorithms=["HS256"])
name: str | None = payload.get("name")
sub: str | None = payload.get("sub")
expires_at: str | None = payload.get("expires_at")
if sub is None:
raise HTTPException(401, "Invalid credentials")
if expires_at is not None:
if is_expired(expires_at):
raise HTTPException(401, "Invalid credentials")
return UserDTO(name=name, email=sub)
except JWTError:
raise HTTPException(401, "Invalid credentials")
def is_expired(expires_at: str) -> bool:
"""Return :obj:`True` if token has expired."""
return datetime.strptime(expires_at, "%Y-%m-%d %H:%M:%S") < datetime.utcnow()
class AuthService:
def __init__(self, uow: UnitOfWork) -> None:
self.uow = uow
self.crypto_context = CryptContext(schemes="bcrypt")
async def authenticate(self, login: OAuth2PasswordRequestForm = Depends()) -> TokenSchema | None:
print(login)
async with self.uow:
user = await self.uow.users.find_one(filter={"email": login.username})
if user.hashed_password is None:
raise HTTPException(401, "Incorrect password")
else:
if not self.crypto_context.verify(login.password, user.hashed_password):
raise HTTPException(401, "Incorrect password")
else:
access_token = self._create_access_token(user.name, user.email)
return TokenSchema(access_token=access_token, token_type="bearer")
def _create_access_token(self, name: str, email: str) -> str:
"""Encode user information and expiration time."""
payload = {
"name": name,
"sub": email,
"expires_at": self._expiration_time(),
}
return jwt.encode(payload, "fsgddfsgdfgs", algorithm="HS256")
@staticmethod
def _expiration_time() -> str:
"""Get token expiration time."""
expires_at = datetime.utcnow() + timedelta(minutes=30)
return expires_at.strftime("%Y-%m-%d %H:%M:%S")

View File

@ -26,9 +26,7 @@ class UserService:
async def add_one(self, data: UserWriteDTO): async def add_one(self, data: UserWriteDTO):
new_data = data.model_dump() new_data = data.model_dump()
print(new_data)
new_data["hashed_password"] = self.crypto_context.hash(new_data.pop("password")) new_data["hashed_password"] = self.crypto_context.hash(new_data.pop("password"))
print(new_data)
dataf: UserDBDTO = UserDBDTO(**new_data) dataf: UserDBDTO = UserDBDTO(**new_data)
async with self.uow: async with self.uow:
try: try:
@ -40,8 +38,6 @@ class UserService:
else: else:
await self.uow.commit() await self.uow.commit()
finally: finally:
print(res)
return res return res
async def get_user(self, id: UUID): async def get_user(self, id: UUID):
@ -51,15 +47,17 @@ class UserService:
except IntegrityError as e: except IntegrityError as e:
await self.uow.rollback() await self.uow.rollback()
res = e._message() res = e._message()
else:
await self.uow.commit()
finally: finally:
return res return res
async def patch_one(self, id: UUID, data: UserWriteDTO): async def patch_one(self, id: UUID, data: UserWriteDTO):
new_data = data.model_dump()
new_data["hashed_password"] = self.crypto_context.hash(new_data.pop("password"))
dataf = UserDBDTO(**new_data)
async with self.uow: async with self.uow:
try: try:
res = await self.uow.users.patch_one(filter={"id": id}, data=data) res = await self.uow.users.patch_one(filter={"id": id}, data=dataf)
except IntegrityError as e: except IntegrityError as e:
await self.uow.rollback() await self.uow.rollback()
res = e._message() res = e._message()