bearer auth
parent
30de12fba7
commit
665d4802f6
|
@ -896,6 +896,20 @@ cryptography = ["cryptography (>=3.4.0)"]
|
||||||
pycrypto = ["pyasn1", "pycrypto (>=2.6.0,<2.7.0)"]
|
pycrypto = ["pyasn1", "pycrypto (>=2.6.0,<2.7.0)"]
|
||||||
pycryptodome = ["pyasn1", "pycryptodome (>=3.3.1,<4.0.0)"]
|
pycryptodome = ["pyasn1", "pycryptodome (>=3.3.1,<4.0.0)"]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "python-multipart"
|
||||||
|
version = "0.0.9"
|
||||||
|
description = "A streaming multipart parser for Python"
|
||||||
|
optional = false
|
||||||
|
python-versions = ">=3.8"
|
||||||
|
files = [
|
||||||
|
{file = "python_multipart-0.0.9-py3-none-any.whl", hash = "sha256:97ca7b8ea7b05f977dc3849c3ba99d51689822fab725c3703af7c866a0c2b215"},
|
||||||
|
{file = "python_multipart-0.0.9.tar.gz", hash = "sha256:03f54688c663f1b7977105f021043b0793151e4cb1c1a9d4a11fc13d622c4026"},
|
||||||
|
]
|
||||||
|
|
||||||
|
[package.extras]
|
||||||
|
dev = ["atomicwrites (==1.4.1)", "attrs (==23.2.0)", "coverage (==7.4.1)", "hatch", "invoke (==2.2.0)", "more-itertools (==10.2.0)", "pbr (==6.0.0)", "pluggy (==1.4.0)", "py (==1.11.0)", "pytest (==8.0.0)", "pytest-cov (==4.1.0)", "pytest-timeout (==2.2.0)", "pyyaml (==6.0.1)", "ruff (==0.2.1)"]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pyyaml"
|
name = "pyyaml"
|
||||||
version = "6.0.1"
|
version = "6.0.1"
|
||||||
|
@ -1164,4 +1178,4 @@ test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess
|
||||||
[metadata]
|
[metadata]
|
||||||
lock-version = "2.0"
|
lock-version = "2.0"
|
||||||
python-versions = "^3.10"
|
python-versions = "^3.10"
|
||||||
content-hash = "c02193dcfb120c410c23aaea101e98c092a1071e833c2196540f86797e6f3baa"
|
content-hash = "96e2371cfbd3ed8dab04ad2090bf982861dad2c1916834d22c835f7468c4c8c1"
|
||||||
|
|
|
@ -12,6 +12,7 @@ asyncpg = "^0.29.0"
|
||||||
passlib = "^1.7.4"
|
passlib = "^1.7.4"
|
||||||
pyjwt = "^2.8.0"
|
pyjwt = "^2.8.0"
|
||||||
python-jose = {extras = ["cryptography"], version = "^3.3.0"}
|
python-jose = {extras = ["cryptography"], version = "^3.3.0"}
|
||||||
|
python-multipart = "^0.0.9"
|
||||||
|
|
||||||
[tool.poetry.scripts]
|
[tool.poetry.scripts]
|
||||||
api = "test_api.run:run_server"
|
api = "test_api.run:run_server"
|
||||||
|
|
|
@ -2,6 +2,7 @@ from contextlib import asynccontextmanager
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
from .routers.auth import router as auth_router
|
||||||
from .routers.user import router as user_router
|
from .routers.user import router as user_router
|
||||||
|
|
||||||
|
|
||||||
|
@ -15,4 +16,5 @@ async def lifespan(app: FastAPI):
|
||||||
def create_app() -> FastAPI:
|
def create_app() -> FastAPI:
|
||||||
app = FastAPI(lifespan=lifespan)
|
app = FastAPI(lifespan=lifespan)
|
||||||
app.include_router(user_router)
|
app.include_router(user_router)
|
||||||
|
app.include_router(auth_router)
|
||||||
return app
|
return app
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from test_api.services.auth import AuthService
|
||||||
|
|
||||||
from .config import get_settings
|
from .config import get_settings
|
||||||
from .services.user import UserService
|
from .services.user import UserService
|
||||||
from .uow.uow_base import UnitOfWork
|
from .uow.uow_base import UnitOfWork
|
||||||
|
@ -26,6 +28,14 @@ user_service = UserService(
|
||||||
uow=uow,
|
uow=uow,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
auth_service = AuthService(
|
||||||
|
uow=uow,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_user_service():
|
def get_user_service():
|
||||||
return user_service
|
return user_service
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_service():
|
||||||
|
return auth_service
|
||||||
|
|
|
@ -2,7 +2,7 @@ from sqlalchemy import delete, insert, select, update
|
||||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
from ..models import UserModel
|
from ..models import UserModel
|
||||||
from ..schemas.user_schema import UserDBDTO, UserReadDTO, UserWriteDTO
|
from ..schemas.user_schema import UserDBDTO, UserReadDTO
|
||||||
|
|
||||||
|
|
||||||
class UserRepository:
|
class UserRepository:
|
||||||
|
@ -33,7 +33,7 @@ class UserRepository:
|
||||||
return UserReadDTO.model_validate(res)
|
return UserReadDTO.model_validate(res)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def patch_one(self, filter: dict, data: UserWriteDTO) -> UserReadDTO | None:
|
async def patch_one(self, filter: dict, data: UserDBDTO) -> UserReadDTO | None:
|
||||||
stmt = update(UserModel).where(UserModel.id == filter["id"]).values(**data.model_dump()).returning(UserModel)
|
stmt = update(UserModel).where(UserModel.id == filter["id"]).values(**data.model_dump()).returning(UserModel)
|
||||||
res = await self.session.execute(stmt)
|
res = await self.session.execute(stmt)
|
||||||
res = res.scalar_one_or_none()
|
res = res.scalar_one_or_none()
|
||||||
|
|
|
@ -0,0 +1,17 @@
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
|
||||||
|
from test_api.schemas.token import TokenSchema
|
||||||
|
from test_api.services.auth import AuthService
|
||||||
|
|
||||||
|
from ..di import get_auth_service
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/token")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("", response_model=TokenSchema)
|
||||||
|
async def authenticate(
|
||||||
|
login: OAuth2PasswordRequestForm = Depends(),
|
||||||
|
auth: AuthService = Depends(get_auth_service),
|
||||||
|
) -> TokenSchema | None:
|
||||||
|
return await auth.authenticate(login)
|
|
@ -2,8 +2,10 @@ from uuid import UUID
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
|
||||||
|
from test_api.services.auth import get_current_user
|
||||||
|
|
||||||
from ..di import get_user_service
|
from ..di import get_user_service
|
||||||
from ..schemas import UserReadDTO
|
from ..schemas import UserDTO, UserReadDTO
|
||||||
from ..schemas.user_schema import UserWriteDTO
|
from ..schemas.user_schema import UserWriteDTO
|
||||||
from ..services import UserService
|
from ..services import UserService
|
||||||
|
|
||||||
|
@ -19,7 +21,9 @@ router = APIRouter(
|
||||||
)
|
)
|
||||||
async def get_user_list(
|
async def get_user_list(
|
||||||
user_service: UserService = Depends(get_user_service),
|
user_service: UserService = Depends(get_user_service),
|
||||||
|
user: UserDTO = Depends(get_current_user),
|
||||||
) -> list[UserReadDTO] | str:
|
) -> list[UserReadDTO] | str:
|
||||||
|
print(user)
|
||||||
return await user_service.get_all_users()
|
return await user_service.get_all_users()
|
||||||
|
|
||||||
|
|
||||||
|
@ -50,6 +54,7 @@ async def patch_user(
|
||||||
user_uuid: UUID,
|
user_uuid: UUID,
|
||||||
data: UserWriteDTO,
|
data: UserWriteDTO,
|
||||||
user_service: UserService = Depends(get_user_service),
|
user_service: UserService = Depends(get_user_service),
|
||||||
|
user: UserDTO = Depends(get_current_user),
|
||||||
) -> UserReadDTO | dict:
|
) -> UserReadDTO | dict:
|
||||||
res = await user_service.patch_one(id=user_uuid, data=data)
|
res = await user_service.patch_one(id=user_uuid, data=data)
|
||||||
if not isinstance(res, UserReadDTO):
|
if not isinstance(res, UserReadDTO):
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
from .base_schema import BaseDTO, ReadDTO, WriteDTO
|
from .base_schema import BaseDTO, ReadDTO, WriteDTO
|
||||||
from .user_schema import UserReadDTO, UserWriteDTO
|
from .user_schema import UserDTO, UserReadDTO, UserWriteDTO
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
"BaseDTO",
|
"BaseDTO",
|
||||||
"WriteDTO",
|
"WriteDTO",
|
||||||
"ReadDTO",
|
"ReadDTO",
|
||||||
|
"UserDTO",
|
||||||
"UserWriteDTO",
|
"UserWriteDTO",
|
||||||
"UserReadDTO",
|
"UserReadDTO",
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,6 @@
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class TokenSchema(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
token_type: str
|
|
@ -0,0 +1,82 @@
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException
|
||||||
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from passlib.context import CryptContext
|
||||||
|
|
||||||
|
from test_api.schemas.token import TokenSchema
|
||||||
|
from test_api.schemas.user_schema import UserDTO
|
||||||
|
from test_api.uow.uow_base import UnitOfWork
|
||||||
|
|
||||||
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
|
oauth2_schema = OAuth2PasswordBearer(tokenUrl="/token")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user(token: str = Depends(oauth2_schema)):
|
||||||
|
if token is None:
|
||||||
|
raise HTTPException(401, "Invalid credentials")
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, "fsgddfsgdfgs", algorithms=["HS256"])
|
||||||
|
|
||||||
|
name: str | None = payload.get("name")
|
||||||
|
sub: str | None = payload.get("sub")
|
||||||
|
expires_at: str | None = payload.get("expires_at")
|
||||||
|
|
||||||
|
if sub is None:
|
||||||
|
raise HTTPException(401, "Invalid credentials")
|
||||||
|
|
||||||
|
if expires_at is not None:
|
||||||
|
if is_expired(expires_at):
|
||||||
|
raise HTTPException(401, "Invalid credentials")
|
||||||
|
|
||||||
|
return UserDTO(name=name, email=sub)
|
||||||
|
except JWTError:
|
||||||
|
raise HTTPException(401, "Invalid credentials")
|
||||||
|
|
||||||
|
|
||||||
|
def is_expired(expires_at: str) -> bool:
|
||||||
|
"""Return :obj:`True` if token has expired."""
|
||||||
|
|
||||||
|
return datetime.strptime(expires_at, "%Y-%m-%d %H:%M:%S") < datetime.utcnow()
|
||||||
|
|
||||||
|
|
||||||
|
class AuthService:
|
||||||
|
def __init__(self, uow: UnitOfWork) -> None:
|
||||||
|
self.uow = uow
|
||||||
|
self.crypto_context = CryptContext(schemes="bcrypt")
|
||||||
|
|
||||||
|
async def authenticate(self, login: OAuth2PasswordRequestForm = Depends()) -> TokenSchema | None:
|
||||||
|
print(login)
|
||||||
|
|
||||||
|
async with self.uow:
|
||||||
|
user = await self.uow.users.find_one(filter={"email": login.username})
|
||||||
|
|
||||||
|
if user.hashed_password is None:
|
||||||
|
raise HTTPException(401, "Incorrect password")
|
||||||
|
else:
|
||||||
|
if not self.crypto_context.verify(login.password, user.hashed_password):
|
||||||
|
raise HTTPException(401, "Incorrect password")
|
||||||
|
else:
|
||||||
|
access_token = self._create_access_token(user.name, user.email)
|
||||||
|
return TokenSchema(access_token=access_token, token_type="bearer")
|
||||||
|
|
||||||
|
def _create_access_token(self, name: str, email: str) -> str:
|
||||||
|
"""Encode user information and expiration time."""
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"name": name,
|
||||||
|
"sub": email,
|
||||||
|
"expires_at": self._expiration_time(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return jwt.encode(payload, "fsgddfsgdfgs", algorithm="HS256")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _expiration_time() -> str:
|
||||||
|
"""Get token expiration time."""
|
||||||
|
|
||||||
|
expires_at = datetime.utcnow() + timedelta(minutes=30)
|
||||||
|
return expires_at.strftime("%Y-%m-%d %H:%M:%S")
|
|
@ -26,9 +26,7 @@ class UserService:
|
||||||
|
|
||||||
async def add_one(self, data: UserWriteDTO):
|
async def add_one(self, data: UserWriteDTO):
|
||||||
new_data = data.model_dump()
|
new_data = data.model_dump()
|
||||||
print(new_data)
|
|
||||||
new_data["hashed_password"] = self.crypto_context.hash(new_data.pop("password"))
|
new_data["hashed_password"] = self.crypto_context.hash(new_data.pop("password"))
|
||||||
print(new_data)
|
|
||||||
dataf: UserDBDTO = UserDBDTO(**new_data)
|
dataf: UserDBDTO = UserDBDTO(**new_data)
|
||||||
async with self.uow:
|
async with self.uow:
|
||||||
try:
|
try:
|
||||||
|
@ -40,8 +38,6 @@ class UserService:
|
||||||
else:
|
else:
|
||||||
await self.uow.commit()
|
await self.uow.commit()
|
||||||
finally:
|
finally:
|
||||||
print(res)
|
|
||||||
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
async def get_user(self, id: UUID):
|
async def get_user(self, id: UUID):
|
||||||
|
@ -51,15 +47,17 @@ class UserService:
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
await self.uow.rollback()
|
await self.uow.rollback()
|
||||||
res = e._message()
|
res = e._message()
|
||||||
else:
|
|
||||||
await self.uow.commit()
|
|
||||||
finally:
|
finally:
|
||||||
return res
|
return res
|
||||||
|
|
||||||
async def patch_one(self, id: UUID, data: UserWriteDTO):
|
async def patch_one(self, id: UUID, data: UserWriteDTO):
|
||||||
|
new_data = data.model_dump()
|
||||||
|
new_data["hashed_password"] = self.crypto_context.hash(new_data.pop("password"))
|
||||||
|
dataf = UserDBDTO(**new_data)
|
||||||
|
|
||||||
async with self.uow:
|
async with self.uow:
|
||||||
try:
|
try:
|
||||||
res = await self.uow.users.patch_one(filter={"id": id}, data=data)
|
res = await self.uow.users.patch_one(filter={"id": id}, data=dataf)
|
||||||
except IntegrityError as e:
|
except IntegrityError as e:
|
||||||
await self.uow.rollback()
|
await self.uow.rollback()
|
||||||
res = e._message()
|
res = e._message()
|
||||||
|
|
Loading…
Reference in New Issue