sync
parent
0d9d69d206
commit
35bb1c5fb5
|
@ -1,9 +1,7 @@
|
|||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from test_api.services.auth import AuthService
|
||||
|
||||
from .config import get_settings
|
||||
from .services.user import UserService
|
||||
from .services import AuthService, UserService
|
||||
from .uow.uow_base import UnitOfWork
|
||||
|
||||
async_engine = create_async_engine(
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
"""empty message
|
||||
|
||||
Revision ID: 9d8cc4d3d62f
|
||||
Revises: 41b339082609
|
||||
Create Date: 2024-03-23 03:08:05.657592
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "9d8cc4d3d62f"
|
||||
down_revision: str | None = "41b339082609"
|
||||
branch_labels: str | Sequence[str] | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"refers",
|
||||
sa.Column("referer", sa.UUID(), nullable=False),
|
||||
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||
sa.Column("id", sa.UUID(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["referer"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("refers")
|
||||
# ### end Alembic commands ###
|
|
@ -1,7 +1,9 @@
|
|||
from .base_model import Base
|
||||
from .ref_model import RefModel
|
||||
from .user_model import UserModel
|
||||
|
||||
__all__ = (
|
||||
"Base",
|
||||
"UserModel",
|
||||
"RefModel",
|
||||
)
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from . import Base
|
||||
|
||||
|
||||
class RefModel(Base):
|
||||
__tablename__ = "refers"
|
||||
|
||||
referer: Mapped[UUID] = mapped_column(ForeignKey("user.id"))
|
||||
referals: Mapped[list["UserModel"]] = relationship(
|
||||
"UserModel",
|
||||
backref="refers",
|
||||
lazy="selectin",
|
||||
)
|
||||
is_active: Mapped[bool]
|
|
@ -1,14 +0,0 @@
|
|||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from . import Base
|
||||
|
||||
|
||||
class ReferModel(Base):
|
||||
__tablename__ = "referals"
|
||||
|
||||
owner: Mapped[UUID] = mapped_column(ForeignKey("user.id"))
|
||||
email: Mapped[str] = mapped_column(unique=True)
|
||||
hashed_password: Mapped[str]
|
|
@ -1,3 +1,7 @@
|
|||
from .ref_repo import RefRepository
|
||||
from .user import UserRepository
|
||||
|
||||
__all__ = ("UserRepository",)
|
||||
__all__ = (
|
||||
"UserRepository",
|
||||
"RefRepository",
|
||||
)
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import insert
|
||||
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||
|
||||
from ..models import RefModel
|
||||
from ..schemas import RefReadDTO
|
||||
|
||||
|
||||
class RefRepository:
|
||||
def __init__(self, session: AsyncSession):
|
||||
self.session = session
|
||||
|
||||
async def add_ref(self, referer: UUID) -> RefReadDTO:
|
||||
stmt = insert(RefModel).values(referer=referer, is_active=True).returning(RefModel)
|
||||
res = await self.session.execute(stmt)
|
||||
return RefReadDTO.model_validate(res.scalar_one())
|
||||
|
||||
# async def delete_ref(self, filter: dict) -> UserReadDTO | None:
|
||||
# stmt = select(UserModel).filter_by(**filter)
|
||||
# res = await self.session.execute(stmt)
|
||||
# res = res.scalar_one_or_none()
|
||||
#
|
||||
# if res is not None:
|
||||
# return UserReadDTO.model_validate(res)
|
||||
# return None
|
|
@ -1,12 +1,9 @@
|
|||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from test_api.schemas.token import TokenSchema
|
||||
from test_api.schemas.user_schema import UserReadDTO, UserWriteDTO
|
||||
from test_api.services.auth import AuthService
|
||||
from test_api.services.user import UserService
|
||||
|
||||
from ..di import get_auth_service, get_user_service
|
||||
from ..schemas import TokenSchema, UserReadDTO, UserWriteDTO
|
||||
from ..services import AuthService, UserService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
from fastapi import APIRouter, Depends
|
||||
|
||||
from test_api.services.auth import get_current_user
|
||||
|
||||
from ..di import get_user_service
|
||||
from ..schemas import RefDTO, UserDTO, UserReadDTO
|
||||
from ..services import UserService
|
||||
from ..schemas import RefReadDTO, UserAuthDTO, UserDTO, UserReadDTO
|
||||
from ..services import UserService, get_current_user
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/user",
|
||||
|
@ -19,9 +17,12 @@ async def get_user(
|
|||
return await user_service.get_user(email=user.email)
|
||||
|
||||
|
||||
@router.get("/me/create_ref", status_code=201, response_model=RefDTO)
|
||||
@router.get("/me/create_ref", status_code=201, response_model=RefReadDTO)
|
||||
async def create_refer_code(
|
||||
user_service: UserService = Depends(get_user_service),
|
||||
user: UserDTO = Depends(get_current_user),
|
||||
) -> RefDTO | None:
|
||||
return await user_service.create_refer_code(email=user.email)
|
||||
user: UserAuthDTO = Depends(get_current_user),
|
||||
):
|
||||
print(user)
|
||||
res = await user_service.create_ref(email=user.email)
|
||||
|
||||
return res
|
||||
|
|
|
@ -1,13 +1,18 @@
|
|||
from .base_schema import BaseDTO, ReadDTO, WriteDTO
|
||||
from .ref import RefDTO
|
||||
from .user_schema import UserDTO, UserReadDTO, UserWriteDTO
|
||||
from .ref_schema import RefDTO, RefReadDTO
|
||||
from .token_schema import TokenSchema
|
||||
from .user_schema import UserAuthDTO, UserDBDTO, UserDTO, UserReadDTO, UserWriteDTO
|
||||
|
||||
__all__ = (
|
||||
"BaseDTO",
|
||||
"WriteDTO",
|
||||
"ReadDTO",
|
||||
"UserDTO",
|
||||
"UserAuthDTO",
|
||||
"UserWriteDTO",
|
||||
"UserReadDTO",
|
||||
"UserDBDTO",
|
||||
"RefDTO",
|
||||
"RefReadDTO",
|
||||
"TokenSchema",
|
||||
)
|
||||
|
|
|
@ -1,5 +0,0 @@
|
|||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class RefDTO(BaseModel):
|
||||
refer_code: str
|
|
@ -0,0 +1,11 @@
|
|||
from uuid import UUID
|
||||
|
||||
from . import BaseDTO
|
||||
|
||||
|
||||
class RefDTO(BaseDTO):
|
||||
referer: str
|
||||
|
||||
|
||||
class RefReadDTO(RefDTO):
|
||||
refer: UUID
|
|
@ -6,6 +6,10 @@ class UserDTO(WriteDTO):
|
|||
email: str
|
||||
|
||||
|
||||
class UserAuthDTO(UserDTO):
|
||||
id: str
|
||||
|
||||
|
||||
class UserWriteDTO(UserDTO):
|
||||
password: str
|
||||
|
||||
|
|
|
@ -1,3 +1,8 @@
|
|||
from .auth_service import AuthService, get_current_user
|
||||
from .user import UserService
|
||||
|
||||
__all__ = ("UserService",)
|
||||
__all__ = (
|
||||
"UserService",
|
||||
"AuthService",
|
||||
"get_current_user",
|
||||
)
|
||||
|
|
|
@ -1,14 +1,15 @@
|
|||
from datetime import datetime, timedelta
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||
from jose import JWTError, jwt
|
||||
from passlib.context import CryptContext
|
||||
|
||||
from test_api.schemas.token import TokenSchema
|
||||
from test_api.schemas.user_schema import UserDTO
|
||||
from test_api.uow.uow_base import UnitOfWork
|
||||
|
||||
from ..schemas import TokenSchema, UserAuthDTO
|
||||
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
oauth2_schema = OAuth2PasswordBearer(tokenUrl="/token")
|
||||
|
@ -21,6 +22,7 @@ async def get_current_user(token: str = Depends(oauth2_schema)):
|
|||
try:
|
||||
payload = jwt.decode(token, "fsgddfsgdfgs", algorithms=["HS256"])
|
||||
|
||||
id: str = payload.get("id", "")
|
||||
name: str = payload.get("name", "")
|
||||
sub: str = payload.get("sub", "")
|
||||
expires_at: str = payload.get("expires_at", "")
|
||||
|
@ -32,7 +34,7 @@ async def get_current_user(token: str = Depends(oauth2_schema)):
|
|||
if is_expired(expires_at):
|
||||
raise HTTPException(401, "Invalid credentials")
|
||||
|
||||
return UserDTO(name=name, email=sub)
|
||||
return UserAuthDTO(id=id, name=name, email=sub)
|
||||
except JWTError:
|
||||
raise HTTPException(401, "Invalid credentials")
|
||||
|
||||
|
@ -58,13 +60,14 @@ class AuthService:
|
|||
if not self.crypto_context.verify(login.password, user.hashed_password):
|
||||
raise HTTPException(401, "Incorrect password")
|
||||
else:
|
||||
access_token = self._create_access_token(user.name, user.email)
|
||||
access_token = self._create_access_token(user.id, user.name, user.email)
|
||||
return TokenSchema(access_token=access_token, token_type="bearer")
|
||||
|
||||
def _create_access_token(self, name: str, email: str) -> str:
|
||||
def _create_access_token(self, id: UUID, name: str, email: str) -> str:
|
||||
"""Encode user information and expiration time."""
|
||||
|
||||
payload = {
|
||||
"id": str(id),
|
||||
"name": name,
|
||||
"sub": email,
|
||||
"expires_at": self._expiration_time(),
|
|
@ -22,7 +22,8 @@ class UserService:
|
|||
user = await self.uow.users.find_one(filter={"email": email})
|
||||
return user
|
||||
|
||||
async def create_refer_code(self, email: str):
|
||||
async def create_ref(self, email: str):
|
||||
async with self.uow:
|
||||
user = await self.uow.users.find_one(filter={"email": email})
|
||||
print(user)
|
||||
res = await self.uow.ref.add_ref(referer=user.id)
|
||||
return res
|
||||
|
|
|
@ -2,13 +2,14 @@ from types import TracebackType
|
|||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ..repositories import UserRepository
|
||||
from ..repositories import RefRepository, UserRepository
|
||||
|
||||
|
||||
class UnitOfWork:
|
||||
def __init__(self, session_factory):
|
||||
self.session = session_factory()
|
||||
self.users = UserRepository(self.session)
|
||||
self.ref = RefRepository(self.session)
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
|
Loading…
Reference in New Issue