From 15549f53d51ff55957607a215c42ee6d3b4bb1b1 Mon Sep 17 00:00:00 2001 From: Sergey Vanyushkin Date: Sat, 23 Mar 2024 08:28:07 +0000 Subject: [PATCH] add user id in jwt token --- test_api/schemas/user_schema.py | 4 +++- test_api/services/auth_service.py | 8 +++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/test_api/schemas/user_schema.py b/test_api/schemas/user_schema.py index 0c6a45f..fb03bc9 100644 --- a/test_api/schemas/user_schema.py +++ b/test_api/schemas/user_schema.py @@ -1,3 +1,5 @@ +from uuid import UUID + from . import ReadDTO, WriteDTO @@ -7,7 +9,7 @@ class UserDTO(WriteDTO): class UserAuthDTO(UserDTO): - id: str + id: UUID class UserWriteDTO(UserDTO): diff --git a/test_api/services/auth_service.py b/test_api/services/auth_service.py index d01e96d..a2b23d7 100644 --- a/test_api/services/auth_service.py +++ b/test_api/services/auth_service.py @@ -22,7 +22,7 @@ async def get_current_user(token: str = Depends(oauth2_schema)): try: payload = jwt.decode(token, "fsgddfsgdfgs", algorithms=["HS256"]) - id: str = payload.get("id", "") + id: UUID = UUID(payload.get("id")) name: str = payload.get("name", "") sub: str = payload.get("sub", "") expires_at: str = payload.get("expires_at", "") @@ -33,7 +33,6 @@ async def get_current_user(token: str = Depends(oauth2_schema)): if expires_at: if is_expired(expires_at): raise HTTPException(401, "Invalid credentials") - return UserAuthDTO(id=id, name=name, email=sub) except JWTError: raise HTTPException(401, "Invalid credentials") @@ -50,7 +49,9 @@ class AuthService: self.uow = uow self.crypto_context = CryptContext(schemes="bcrypt") - async def authenticate(self, login: OAuth2PasswordRequestForm = Depends()) -> TokenSchema | None: + async def authenticate( + self, login: OAuth2PasswordRequestForm = Depends() + ) -> TokenSchema | None: async with self.uow: user = await self.uow.users.find_one(filter={"email": login.username}) @@ -73,6 +74,7 @@ class AuthService: "expires_at": self._expiration_time(), } + print(payload) return jwt.encode(payload, "fsgddfsgdfgs", algorithm="HS256") @staticmethod