sync
parent
0d9d69d206
commit
35bb1c5fb5
|
@ -1,9 +1,7 @@
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
from test_api.services.auth import AuthService
|
|
||||||
|
|
||||||
from .config import get_settings
|
from .config import get_settings
|
||||||
from .services.user import UserService
|
from .services import AuthService, UserService
|
||||||
from .uow.uow_base import UnitOfWork
|
from .uow.uow_base import UnitOfWork
|
||||||
|
|
||||||
async_engine = create_async_engine(
|
async_engine = create_async_engine(
|
||||||
|
|
|
@ -0,0 +1,40 @@
|
||||||
|
"""empty message
|
||||||
|
|
||||||
|
Revision ID: 9d8cc4d3d62f
|
||||||
|
Revises: 41b339082609
|
||||||
|
Create Date: 2024-03-23 03:08:05.657592
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "9d8cc4d3d62f"
|
||||||
|
down_revision: str | None = "41b339082609"
|
||||||
|
branch_labels: str | Sequence[str] | None = None
|
||||||
|
depends_on: str | Sequence[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"refers",
|
||||||
|
sa.Column("referer", sa.UUID(), nullable=False),
|
||||||
|
sa.Column("is_active", sa.Boolean(), nullable=False),
|
||||||
|
sa.Column("id", sa.UUID(), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["referer"],
|
||||||
|
["user.id"],
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table("refers")
|
||||||
|
# ### end Alembic commands ###
|
|
@ -1,7 +1,9 @@
|
||||||
from .base_model import Base
|
from .base_model import Base
|
||||||
|
from .ref_model import RefModel
|
||||||
from .user_model import UserModel
|
from .user_model import UserModel
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
"Base",
|
"Base",
|
||||||
"UserModel",
|
"UserModel",
|
||||||
|
"RefModel",
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,18 @@
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import ForeignKey
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from . import Base
|
||||||
|
|
||||||
|
|
||||||
|
class RefModel(Base):
|
||||||
|
__tablename__ = "refers"
|
||||||
|
|
||||||
|
referer: Mapped[UUID] = mapped_column(ForeignKey("user.id"))
|
||||||
|
referals: Mapped[list["UserModel"]] = relationship(
|
||||||
|
"UserModel",
|
||||||
|
backref="refers",
|
||||||
|
lazy="selectin",
|
||||||
|
)
|
||||||
|
is_active: Mapped[bool]
|
|
@ -1,14 +0,0 @@
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from sqlalchemy import ForeignKey
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
|
||||||
|
|
||||||
from . import Base
|
|
||||||
|
|
||||||
|
|
||||||
class ReferModel(Base):
|
|
||||||
__tablename__ = "referals"
|
|
||||||
|
|
||||||
owner: Mapped[UUID] = mapped_column(ForeignKey("user.id"))
|
|
||||||
email: Mapped[str] = mapped_column(unique=True)
|
|
||||||
hashed_password: Mapped[str]
|
|
|
@ -1,3 +1,7 @@
|
||||||
|
from .ref_repo import RefRepository
|
||||||
from .user import UserRepository
|
from .user import UserRepository
|
||||||
|
|
||||||
__all__ = ("UserRepository",)
|
__all__ = (
|
||||||
|
"UserRepository",
|
||||||
|
"RefRepository",
|
||||||
|
)
|
||||||
|
|
|
@ -0,0 +1,26 @@
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import insert
|
||||||
|
from sqlalchemy.ext.asyncio.session import AsyncSession
|
||||||
|
|
||||||
|
from ..models import RefModel
|
||||||
|
from ..schemas import RefReadDTO
|
||||||
|
|
||||||
|
|
||||||
|
class RefRepository:
|
||||||
|
def __init__(self, session: AsyncSession):
|
||||||
|
self.session = session
|
||||||
|
|
||||||
|
async def add_ref(self, referer: UUID) -> RefReadDTO:
|
||||||
|
stmt = insert(RefModel).values(referer=referer, is_active=True).returning(RefModel)
|
||||||
|
res = await self.session.execute(stmt)
|
||||||
|
return RefReadDTO.model_validate(res.scalar_one())
|
||||||
|
|
||||||
|
# async def delete_ref(self, filter: dict) -> UserReadDTO | None:
|
||||||
|
# stmt = select(UserModel).filter_by(**filter)
|
||||||
|
# res = await self.session.execute(stmt)
|
||||||
|
# res = res.scalar_one_or_none()
|
||||||
|
#
|
||||||
|
# if res is not None:
|
||||||
|
# return UserReadDTO.model_validate(res)
|
||||||
|
# return None
|
|
@ -1,12 +1,9 @@
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
|
||||||
from test_api.schemas.token import TokenSchema
|
|
||||||
from test_api.schemas.user_schema import UserReadDTO, UserWriteDTO
|
|
||||||
from test_api.services.auth import AuthService
|
|
||||||
from test_api.services.user import UserService
|
|
||||||
|
|
||||||
from ..di import get_auth_service, get_user_service
|
from ..di import get_auth_service, get_user_service
|
||||||
|
from ..schemas import TokenSchema, UserReadDTO, UserWriteDTO
|
||||||
|
from ..services import AuthService, UserService
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,8 @@
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
from test_api.services.auth import get_current_user
|
|
||||||
|
|
||||||
from ..di import get_user_service
|
from ..di import get_user_service
|
||||||
from ..schemas import RefDTO, UserDTO, UserReadDTO
|
from ..schemas import RefReadDTO, UserAuthDTO, UserDTO, UserReadDTO
|
||||||
from ..services import UserService
|
from ..services import UserService, get_current_user
|
||||||
|
|
||||||
router = APIRouter(
|
router = APIRouter(
|
||||||
prefix="/user",
|
prefix="/user",
|
||||||
|
@ -19,9 +17,12 @@ async def get_user(
|
||||||
return await user_service.get_user(email=user.email)
|
return await user_service.get_user(email=user.email)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me/create_ref", status_code=201, response_model=RefDTO)
|
@router.get("/me/create_ref", status_code=201, response_model=RefReadDTO)
|
||||||
async def create_refer_code(
|
async def create_refer_code(
|
||||||
user_service: UserService = Depends(get_user_service),
|
user_service: UserService = Depends(get_user_service),
|
||||||
user: UserDTO = Depends(get_current_user),
|
user: UserAuthDTO = Depends(get_current_user),
|
||||||
) -> RefDTO | None:
|
):
|
||||||
return await user_service.create_refer_code(email=user.email)
|
print(user)
|
||||||
|
res = await user_service.create_ref(email=user.email)
|
||||||
|
|
||||||
|
return res
|
||||||
|
|
|
@ -1,13 +1,18 @@
|
||||||
from .base_schema import BaseDTO, ReadDTO, WriteDTO
|
from .base_schema import BaseDTO, ReadDTO, WriteDTO
|
||||||
from .ref import RefDTO
|
from .ref_schema import RefDTO, RefReadDTO
|
||||||
from .user_schema import UserDTO, UserReadDTO, UserWriteDTO
|
from .token_schema import TokenSchema
|
||||||
|
from .user_schema import UserAuthDTO, UserDBDTO, UserDTO, UserReadDTO, UserWriteDTO
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
"BaseDTO",
|
"BaseDTO",
|
||||||
"WriteDTO",
|
"WriteDTO",
|
||||||
"ReadDTO",
|
"ReadDTO",
|
||||||
"UserDTO",
|
"UserDTO",
|
||||||
|
"UserAuthDTO",
|
||||||
"UserWriteDTO",
|
"UserWriteDTO",
|
||||||
"UserReadDTO",
|
"UserReadDTO",
|
||||||
|
"UserDBDTO",
|
||||||
"RefDTO",
|
"RefDTO",
|
||||||
|
"RefReadDTO",
|
||||||
|
"TokenSchema",
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,5 +0,0 @@
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class RefDTO(BaseModel):
|
|
||||||
refer_code: str
|
|
|
@ -0,0 +1,11 @@
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from . import BaseDTO
|
||||||
|
|
||||||
|
|
||||||
|
class RefDTO(BaseDTO):
|
||||||
|
referer: str
|
||||||
|
|
||||||
|
|
||||||
|
class RefReadDTO(RefDTO):
|
||||||
|
refer: UUID
|
|
@ -6,6 +6,10 @@ class UserDTO(WriteDTO):
|
||||||
email: str
|
email: str
|
||||||
|
|
||||||
|
|
||||||
|
class UserAuthDTO(UserDTO):
|
||||||
|
id: str
|
||||||
|
|
||||||
|
|
||||||
class UserWriteDTO(UserDTO):
|
class UserWriteDTO(UserDTO):
|
||||||
password: str
|
password: str
|
||||||
|
|
||||||
|
|
|
@ -1,3 +1,8 @@
|
||||||
|
from .auth_service import AuthService, get_current_user
|
||||||
from .user import UserService
|
from .user import UserService
|
||||||
|
|
||||||
__all__ = ("UserService",)
|
__all__ = (
|
||||||
|
"UserService",
|
||||||
|
"AuthService",
|
||||||
|
"get_current_user",
|
||||||
|
)
|
||||||
|
|
|
@ -1,14 +1,15 @@
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException
|
from fastapi import Depends, HTTPException
|
||||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from passlib.context import CryptContext
|
from passlib.context import CryptContext
|
||||||
|
|
||||||
from test_api.schemas.token import TokenSchema
|
|
||||||
from test_api.schemas.user_schema import UserDTO
|
|
||||||
from test_api.uow.uow_base import UnitOfWork
|
from test_api.uow.uow_base import UnitOfWork
|
||||||
|
|
||||||
|
from ..schemas import TokenSchema, UserAuthDTO
|
||||||
|
|
||||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||||
|
|
||||||
oauth2_schema = OAuth2PasswordBearer(tokenUrl="/token")
|
oauth2_schema = OAuth2PasswordBearer(tokenUrl="/token")
|
||||||
|
@ -21,6 +22,7 @@ async def get_current_user(token: str = Depends(oauth2_schema)):
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, "fsgddfsgdfgs", algorithms=["HS256"])
|
payload = jwt.decode(token, "fsgddfsgdfgs", algorithms=["HS256"])
|
||||||
|
|
||||||
|
id: str = payload.get("id", "")
|
||||||
name: str = payload.get("name", "")
|
name: str = payload.get("name", "")
|
||||||
sub: str = payload.get("sub", "")
|
sub: str = payload.get("sub", "")
|
||||||
expires_at: str = payload.get("expires_at", "")
|
expires_at: str = payload.get("expires_at", "")
|
||||||
|
@ -32,7 +34,7 @@ async def get_current_user(token: str = Depends(oauth2_schema)):
|
||||||
if is_expired(expires_at):
|
if is_expired(expires_at):
|
||||||
raise HTTPException(401, "Invalid credentials")
|
raise HTTPException(401, "Invalid credentials")
|
||||||
|
|
||||||
return UserDTO(name=name, email=sub)
|
return UserAuthDTO(id=id, name=name, email=sub)
|
||||||
except JWTError:
|
except JWTError:
|
||||||
raise HTTPException(401, "Invalid credentials")
|
raise HTTPException(401, "Invalid credentials")
|
||||||
|
|
||||||
|
@ -58,13 +60,14 @@ class AuthService:
|
||||||
if not self.crypto_context.verify(login.password, user.hashed_password):
|
if not self.crypto_context.verify(login.password, user.hashed_password):
|
||||||
raise HTTPException(401, "Incorrect password")
|
raise HTTPException(401, "Incorrect password")
|
||||||
else:
|
else:
|
||||||
access_token = self._create_access_token(user.name, user.email)
|
access_token = self._create_access_token(user.id, user.name, user.email)
|
||||||
return TokenSchema(access_token=access_token, token_type="bearer")
|
return TokenSchema(access_token=access_token, token_type="bearer")
|
||||||
|
|
||||||
def _create_access_token(self, name: str, email: str) -> str:
|
def _create_access_token(self, id: UUID, name: str, email: str) -> str:
|
||||||
"""Encode user information and expiration time."""
|
"""Encode user information and expiration time."""
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
|
"id": str(id),
|
||||||
"name": name,
|
"name": name,
|
||||||
"sub": email,
|
"sub": email,
|
||||||
"expires_at": self._expiration_time(),
|
"expires_at": self._expiration_time(),
|
|
@ -22,7 +22,8 @@ class UserService:
|
||||||
user = await self.uow.users.find_one(filter={"email": email})
|
user = await self.uow.users.find_one(filter={"email": email})
|
||||||
return user
|
return user
|
||||||
|
|
||||||
async def create_refer_code(self, email: str):
|
async def create_ref(self, email: str):
|
||||||
async with self.uow:
|
async with self.uow:
|
||||||
user = await self.uow.users.find_one(filter={"email": email})
|
user = await self.uow.users.find_one(filter={"email": email})
|
||||||
print(user)
|
res = await self.uow.ref.add_ref(referer=user.id)
|
||||||
|
return res
|
||||||
|
|
|
@ -2,13 +2,14 @@ from types import TracebackType
|
||||||
|
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
|
|
||||||
from ..repositories import UserRepository
|
from ..repositories import RefRepository, UserRepository
|
||||||
|
|
||||||
|
|
||||||
class UnitOfWork:
|
class UnitOfWork:
|
||||||
def __init__(self, session_factory):
|
def __init__(self, session_factory):
|
||||||
self.session = session_factory()
|
self.session = session_factory()
|
||||||
self.users = UserRepository(self.session)
|
self.users = UserRepository(self.session)
|
||||||
|
self.ref = RefRepository(self.session)
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
return self
|
return self
|
||||||
|
|
Loading…
Reference in New Issue