main
Сергей Ванюшкин 2024-03-23 05:00:03 +03:00
parent 0d9d69d206
commit 35bb1c5fb5
18 changed files with 144 additions and 47 deletions

View File

@ -1,9 +1,7 @@
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from test_api.services.auth import AuthService
from .config import get_settings from .config import get_settings
from .services.user import UserService from .services import AuthService, UserService
from .uow.uow_base import UnitOfWork from .uow.uow_base import UnitOfWork
async_engine = create_async_engine( async_engine = create_async_engine(

View File

@ -0,0 +1,40 @@
"""empty message
Revision ID: 9d8cc4d3d62f
Revises: 41b339082609
Create Date: 2024-03-23 03:08:05.657592
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "9d8cc4d3d62f"
down_revision: str | None = "41b339082609"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"refers",
sa.Column("referer", sa.UUID(), nullable=False),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("id", sa.UUID(), nullable=False),
sa.ForeignKeyConstraint(
["referer"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("refers")
# ### end Alembic commands ###

View File

@ -1,7 +1,9 @@
from .base_model import Base from .base_model import Base
from .ref_model import RefModel
from .user_model import UserModel from .user_model import UserModel
__all__ = ( __all__ = (
"Base", "Base",
"UserModel", "UserModel",
"RefModel",
) )

View File

@ -0,0 +1,18 @@
from uuid import UUID
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from . import Base
class RefModel(Base):
__tablename__ = "refers"
referer: Mapped[UUID] = mapped_column(ForeignKey("user.id"))
referals: Mapped[list["UserModel"]] = relationship(
"UserModel",
backref="refers",
lazy="selectin",
)
is_active: Mapped[bool]

View File

@ -1,14 +0,0 @@
from uuid import UUID
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column
from . import Base
class ReferModel(Base):
__tablename__ = "referals"
owner: Mapped[UUID] = mapped_column(ForeignKey("user.id"))
email: Mapped[str] = mapped_column(unique=True)
hashed_password: Mapped[str]

View File

@ -1,3 +1,7 @@
from .ref_repo import RefRepository
from .user import UserRepository from .user import UserRepository
__all__ = ("UserRepository",) __all__ = (
"UserRepository",
"RefRepository",
)

View File

@ -0,0 +1,26 @@
from uuid import UUID
from sqlalchemy import insert
from sqlalchemy.ext.asyncio.session import AsyncSession
from ..models import RefModel
from ..schemas import RefReadDTO
class RefRepository:
def __init__(self, session: AsyncSession):
self.session = session
async def add_ref(self, referer: UUID) -> RefReadDTO:
stmt = insert(RefModel).values(referer=referer, is_active=True).returning(RefModel)
res = await self.session.execute(stmt)
return RefReadDTO.model_validate(res.scalar_one())
# async def delete_ref(self, filter: dict) -> UserReadDTO | None:
# stmt = select(UserModel).filter_by(**filter)
# res = await self.session.execute(stmt)
# res = res.scalar_one_or_none()
#
# if res is not None:
# return UserReadDTO.model_validate(res)
# return None

View File

@ -1,12 +1,9 @@
from fastapi import APIRouter, Depends, HTTPException from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm
from test_api.schemas.token import TokenSchema
from test_api.schemas.user_schema import UserReadDTO, UserWriteDTO
from test_api.services.auth import AuthService
from test_api.services.user import UserService
from ..di import get_auth_service, get_user_service from ..di import get_auth_service, get_user_service
from ..schemas import TokenSchema, UserReadDTO, UserWriteDTO
from ..services import AuthService, UserService
router = APIRouter() router = APIRouter()

View File

@ -1,10 +1,8 @@
from fastapi import APIRouter, Depends from fastapi import APIRouter, Depends
from test_api.services.auth import get_current_user
from ..di import get_user_service from ..di import get_user_service
from ..schemas import RefDTO, UserDTO, UserReadDTO from ..schemas import RefReadDTO, UserAuthDTO, UserDTO, UserReadDTO
from ..services import UserService from ..services import UserService, get_current_user
router = APIRouter( router = APIRouter(
prefix="/user", prefix="/user",
@ -19,9 +17,12 @@ async def get_user(
return await user_service.get_user(email=user.email) return await user_service.get_user(email=user.email)
@router.get("/me/create_ref", status_code=201, response_model=RefDTO) @router.get("/me/create_ref", status_code=201, response_model=RefReadDTO)
async def create_refer_code( async def create_refer_code(
user_service: UserService = Depends(get_user_service), user_service: UserService = Depends(get_user_service),
user: UserDTO = Depends(get_current_user), user: UserAuthDTO = Depends(get_current_user),
) -> RefDTO | None: ):
return await user_service.create_refer_code(email=user.email) print(user)
res = await user_service.create_ref(email=user.email)
return res

View File

@ -1,13 +1,18 @@
from .base_schema import BaseDTO, ReadDTO, WriteDTO from .base_schema import BaseDTO, ReadDTO, WriteDTO
from .ref import RefDTO from .ref_schema import RefDTO, RefReadDTO
from .user_schema import UserDTO, UserReadDTO, UserWriteDTO from .token_schema import TokenSchema
from .user_schema import UserAuthDTO, UserDBDTO, UserDTO, UserReadDTO, UserWriteDTO
__all__ = ( __all__ = (
"BaseDTO", "BaseDTO",
"WriteDTO", "WriteDTO",
"ReadDTO", "ReadDTO",
"UserDTO", "UserDTO",
"UserAuthDTO",
"UserWriteDTO", "UserWriteDTO",
"UserReadDTO", "UserReadDTO",
"UserDBDTO",
"RefDTO", "RefDTO",
"RefReadDTO",
"TokenSchema",
) )

View File

@ -1,5 +0,0 @@
from pydantic import BaseModel
class RefDTO(BaseModel):
refer_code: str

View File

@ -0,0 +1,11 @@
from uuid import UUID
from . import BaseDTO
class RefDTO(BaseDTO):
referer: str
class RefReadDTO(RefDTO):
refer: UUID

View File

@ -6,6 +6,10 @@ class UserDTO(WriteDTO):
email: str email: str
class UserAuthDTO(UserDTO):
id: str
class UserWriteDTO(UserDTO): class UserWriteDTO(UserDTO):
password: str password: str

View File

@ -1,3 +1,8 @@
from .auth_service import AuthService, get_current_user
from .user import UserService from .user import UserService
__all__ = ("UserService",) __all__ = (
"UserService",
"AuthService",
"get_current_user",
)

View File

@ -1,14 +1,15 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from uuid import UUID
from fastapi import Depends, HTTPException from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt from jose import JWTError, jwt
from passlib.context import CryptContext from passlib.context import CryptContext
from test_api.schemas.token import TokenSchema
from test_api.schemas.user_schema import UserDTO
from test_api.uow.uow_base import UnitOfWork from test_api.uow.uow_base import UnitOfWork
from ..schemas import TokenSchema, UserAuthDTO
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_schema = OAuth2PasswordBearer(tokenUrl="/token") oauth2_schema = OAuth2PasswordBearer(tokenUrl="/token")
@ -21,6 +22,7 @@ async def get_current_user(token: str = Depends(oauth2_schema)):
try: try:
payload = jwt.decode(token, "fsgddfsgdfgs", algorithms=["HS256"]) payload = jwt.decode(token, "fsgddfsgdfgs", algorithms=["HS256"])
id: str = payload.get("id", "")
name: str = payload.get("name", "") name: str = payload.get("name", "")
sub: str = payload.get("sub", "") sub: str = payload.get("sub", "")
expires_at: str = payload.get("expires_at", "") expires_at: str = payload.get("expires_at", "")
@ -32,7 +34,7 @@ async def get_current_user(token: str = Depends(oauth2_schema)):
if is_expired(expires_at): if is_expired(expires_at):
raise HTTPException(401, "Invalid credentials") raise HTTPException(401, "Invalid credentials")
return UserDTO(name=name, email=sub) return UserAuthDTO(id=id, name=name, email=sub)
except JWTError: except JWTError:
raise HTTPException(401, "Invalid credentials") raise HTTPException(401, "Invalid credentials")
@ -58,13 +60,14 @@ class AuthService:
if not self.crypto_context.verify(login.password, user.hashed_password): if not self.crypto_context.verify(login.password, user.hashed_password):
raise HTTPException(401, "Incorrect password") raise HTTPException(401, "Incorrect password")
else: else:
access_token = self._create_access_token(user.name, user.email) access_token = self._create_access_token(user.id, user.name, user.email)
return TokenSchema(access_token=access_token, token_type="bearer") return TokenSchema(access_token=access_token, token_type="bearer")
def _create_access_token(self, name: str, email: str) -> str: def _create_access_token(self, id: UUID, name: str, email: str) -> str:
"""Encode user information and expiration time.""" """Encode user information and expiration time."""
payload = { payload = {
"id": str(id),
"name": name, "name": name,
"sub": email, "sub": email,
"expires_at": self._expiration_time(), "expires_at": self._expiration_time(),

View File

@ -22,7 +22,8 @@ class UserService:
user = await self.uow.users.find_one(filter={"email": email}) user = await self.uow.users.find_one(filter={"email": email})
return user return user
async def create_refer_code(self, email: str): async def create_ref(self, email: str):
async with self.uow: async with self.uow:
user = await self.uow.users.find_one(filter={"email": email}) user = await self.uow.users.find_one(filter={"email": email})
print(user) res = await self.uow.ref.add_ref(referer=user.id)
return res

View File

@ -2,13 +2,14 @@ from types import TracebackType
from fastapi import HTTPException from fastapi import HTTPException
from ..repositories import UserRepository from ..repositories import RefRepository, UserRepository
class UnitOfWork: class UnitOfWork:
def __init__(self, session_factory): def __init__(self, session_factory):
self.session = session_factory() self.session = session_factory()
self.users = UserRepository(self.session) self.users = UserRepository(self.session)
self.ref = RefRepository(self.session)
async def __aenter__(self): async def __aenter__(self):
return self return self