main
Сергей Ванюшкин 2024-03-23 05:00:03 +03:00
parent 0d9d69d206
commit 35bb1c5fb5
18 changed files with 144 additions and 47 deletions

View File

@ -1,9 +1,7 @@
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from test_api.services.auth import AuthService
from .config import get_settings
from .services.user import UserService
from .services import AuthService, UserService
from .uow.uow_base import UnitOfWork
async_engine = create_async_engine(

View File

@ -0,0 +1,40 @@
"""empty message
Revision ID: 9d8cc4d3d62f
Revises: 41b339082609
Create Date: 2024-03-23 03:08:05.657592
"""
from collections.abc import Sequence
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "9d8cc4d3d62f"
down_revision: str | None = "41b339082609"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"refers",
sa.Column("referer", sa.UUID(), nullable=False),
sa.Column("is_active", sa.Boolean(), nullable=False),
sa.Column("id", sa.UUID(), nullable=False),
sa.ForeignKeyConstraint(
["referer"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("refers")
# ### end Alembic commands ###

View File

@ -1,7 +1,9 @@
from .base_model import Base
from .ref_model import RefModel
from .user_model import UserModel
__all__ = (
"Base",
"UserModel",
"RefModel",
)

View File

@ -0,0 +1,18 @@
from uuid import UUID
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column, relationship
from . import Base
class RefModel(Base):
__tablename__ = "refers"
referer: Mapped[UUID] = mapped_column(ForeignKey("user.id"))
referals: Mapped[list["UserModel"]] = relationship(
"UserModel",
backref="refers",
lazy="selectin",
)
is_active: Mapped[bool]

View File

@ -1,14 +0,0 @@
from uuid import UUID
from sqlalchemy import ForeignKey
from sqlalchemy.orm import Mapped, mapped_column
from . import Base
class ReferModel(Base):
__tablename__ = "referals"
owner: Mapped[UUID] = mapped_column(ForeignKey("user.id"))
email: Mapped[str] = mapped_column(unique=True)
hashed_password: Mapped[str]

View File

@ -1,3 +1,7 @@
from .ref_repo import RefRepository
from .user import UserRepository
__all__ = ("UserRepository",)
__all__ = (
"UserRepository",
"RefRepository",
)

View File

@ -0,0 +1,26 @@
from uuid import UUID
from sqlalchemy import insert
from sqlalchemy.ext.asyncio.session import AsyncSession
from ..models import RefModel
from ..schemas import RefReadDTO
class RefRepository:
def __init__(self, session: AsyncSession):
self.session = session
async def add_ref(self, referer: UUID) -> RefReadDTO:
stmt = insert(RefModel).values(referer=referer, is_active=True).returning(RefModel)
res = await self.session.execute(stmt)
return RefReadDTO.model_validate(res.scalar_one())
# async def delete_ref(self, filter: dict) -> UserReadDTO | None:
# stmt = select(UserModel).filter_by(**filter)
# res = await self.session.execute(stmt)
# res = res.scalar_one_or_none()
#
# if res is not None:
# return UserReadDTO.model_validate(res)
# return None

View File

@ -1,12 +1,9 @@
from fastapi import APIRouter, Depends, HTTPException
from fastapi.security import OAuth2PasswordRequestForm
from test_api.schemas.token import TokenSchema
from test_api.schemas.user_schema import UserReadDTO, UserWriteDTO
from test_api.services.auth import AuthService
from test_api.services.user import UserService
from ..di import get_auth_service, get_user_service
from ..schemas import TokenSchema, UserReadDTO, UserWriteDTO
from ..services import AuthService, UserService
router = APIRouter()

View File

@ -1,10 +1,8 @@
from fastapi import APIRouter, Depends
from test_api.services.auth import get_current_user
from ..di import get_user_service
from ..schemas import RefDTO, UserDTO, UserReadDTO
from ..services import UserService
from ..schemas import RefReadDTO, UserAuthDTO, UserDTO, UserReadDTO
from ..services import UserService, get_current_user
router = APIRouter(
prefix="/user",
@ -19,9 +17,12 @@ async def get_user(
return await user_service.get_user(email=user.email)
@router.get("/me/create_ref", status_code=201, response_model=RefDTO)
@router.get("/me/create_ref", status_code=201, response_model=RefReadDTO)
async def create_refer_code(
user_service: UserService = Depends(get_user_service),
user: UserDTO = Depends(get_current_user),
) -> RefDTO | None:
return await user_service.create_refer_code(email=user.email)
user: UserAuthDTO = Depends(get_current_user),
):
print(user)
res = await user_service.create_ref(email=user.email)
return res

View File

@ -1,13 +1,18 @@
from .base_schema import BaseDTO, ReadDTO, WriteDTO
from .ref import RefDTO
from .user_schema import UserDTO, UserReadDTO, UserWriteDTO
from .ref_schema import RefDTO, RefReadDTO
from .token_schema import TokenSchema
from .user_schema import UserAuthDTO, UserDBDTO, UserDTO, UserReadDTO, UserWriteDTO
__all__ = (
"BaseDTO",
"WriteDTO",
"ReadDTO",
"UserDTO",
"UserAuthDTO",
"UserWriteDTO",
"UserReadDTO",
"UserDBDTO",
"RefDTO",
"RefReadDTO",
"TokenSchema",
)

View File

@ -1,5 +0,0 @@
from pydantic import BaseModel
class RefDTO(BaseModel):
refer_code: str

View File

@ -0,0 +1,11 @@
from uuid import UUID
from . import BaseDTO
class RefDTO(BaseDTO):
referer: str
class RefReadDTO(RefDTO):
refer: UUID

View File

@ -6,6 +6,10 @@ class UserDTO(WriteDTO):
email: str
class UserAuthDTO(UserDTO):
id: str
class UserWriteDTO(UserDTO):
password: str

View File

@ -1,3 +1,8 @@
from .auth_service import AuthService, get_current_user
from .user import UserService
__all__ = ("UserService",)
__all__ = (
"UserService",
"AuthService",
"get_current_user",
)

View File

@ -1,14 +1,15 @@
from datetime import datetime, timedelta
from uuid import UUID
from fastapi import Depends, HTTPException
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
from jose import JWTError, jwt
from passlib.context import CryptContext
from test_api.schemas.token import TokenSchema
from test_api.schemas.user_schema import UserDTO
from test_api.uow.uow_base import UnitOfWork
from ..schemas import TokenSchema, UserAuthDTO
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_schema = OAuth2PasswordBearer(tokenUrl="/token")
@ -21,6 +22,7 @@ async def get_current_user(token: str = Depends(oauth2_schema)):
try:
payload = jwt.decode(token, "fsgddfsgdfgs", algorithms=["HS256"])
id: str = payload.get("id", "")
name: str = payload.get("name", "")
sub: str = payload.get("sub", "")
expires_at: str = payload.get("expires_at", "")
@ -32,7 +34,7 @@ async def get_current_user(token: str = Depends(oauth2_schema)):
if is_expired(expires_at):
raise HTTPException(401, "Invalid credentials")
return UserDTO(name=name, email=sub)
return UserAuthDTO(id=id, name=name, email=sub)
except JWTError:
raise HTTPException(401, "Invalid credentials")
@ -58,13 +60,14 @@ class AuthService:
if not self.crypto_context.verify(login.password, user.hashed_password):
raise HTTPException(401, "Incorrect password")
else:
access_token = self._create_access_token(user.name, user.email)
access_token = self._create_access_token(user.id, user.name, user.email)
return TokenSchema(access_token=access_token, token_type="bearer")
def _create_access_token(self, name: str, email: str) -> str:
def _create_access_token(self, id: UUID, name: str, email: str) -> str:
"""Encode user information and expiration time."""
payload = {
"id": str(id),
"name": name,
"sub": email,
"expires_at": self._expiration_time(),

View File

@ -22,7 +22,8 @@ class UserService:
user = await self.uow.users.find_one(filter={"email": email})
return user
async def create_refer_code(self, email: str):
async def create_ref(self, email: str):
async with self.uow:
user = await self.uow.users.find_one(filter={"email": email})
print(user)
res = await self.uow.ref.add_ref(referer=user.id)
return res

View File

@ -2,13 +2,14 @@ from types import TracebackType
from fastapi import HTTPException
from ..repositories import UserRepository
from ..repositories import RefRepository, UserRepository
class UnitOfWork:
def __init__(self, session_factory):
self.session = session_factory()
self.users = UserRepository(self.session)
self.ref = RefRepository(self.session)
async def __aenter__(self):
return self