depends.py 1.7 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
import json
from typing import AsyncGenerator, Generator, Optional, Union

import jwt
# NOTE: ``security`` imported here from fastapi is shadowed by
# ``from core import security`` further down; the later import wins, so
# ``security.ALGORITHM`` in this module refers to ``core.security``.
from fastapi import Depends, HTTPException, security, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import ValidationError
from sqlalchemy.ext.asyncio import AsyncSession

from core import security
from core.config import settings
from crud.user import crud_admin, crud_student, crud_teacher
from db.asyncsession import LocalAsyncSession
from models.user import SysUser, Student, Teacher
from schemas.auth import TokenPayload
  16. reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"/{settings.API_V1_STR}/login")
  17. async def get_async_db() -> Generator:
  18. async with LocalAsyncSession() as db:
  19. yield db
  20. def check_access_token(token: str):
  21. try:
  22. payload = jwt.decode(token,
  23. settings.SECRET_KEY,
  24. algorithms=[security.ALGORITHM])
  25. payload["sub"] = json.loads(payload["sub"])
  26. token_payload = TokenPayload(**payload)
  27. except (jwt.PyJWTError, ValidationError):
  28. raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,
  29. detail="Invalid Access Token!!")
  30. return token_payload
  31. async def get_current_user(token: str = Depends(reusable_oauth2),
  32. db: AsyncSession = Depends(get_async_db)
  33. ) -> Optional[Union[SysUser, Teacher, Student]]:
  34. token_payload = check_access_token(token)
  35. if token_payload.sub["utype"] == 0:
  36. crud = crud_admin
  37. elif token_payload.sub["utype"] == 1:
  38. crud = crud_teacher
  39. else:
  40. crud = crud_student
  41. user = await crud.find_one(db, {"username": token_payload.sub["sub"]})
  42. return user