depends.py 1.9 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
import json
from typing import AsyncGenerator, Generator, Optional, Union

import jwt
from fastapi import Depends, HTTPException, security, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import ValidationError
from sqlalchemy.ext.asyncio import AsyncSession

# NOTE: this rebinds ``security`` (imported from fastapi above); code in this
# module that uses ``security.ALGORITHM`` relies on it being core.security.
from core import security
from core.config import settings
from crud.user import crud_admin, crud_teacher, crud_student
from crud.sysdata.role import crud_role, crud_permission
from db.asyncsession import LocalAsyncSession
from models.user import Teacher, Student, Admin
from schemas.auth import TokenPayload
  17. reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"/{settings.API_V1_STR}/login")
  18. async def get_async_db() -> Generator:
  19. async with LocalAsyncSession() as db:
  20. yield db
  21. def check_access_token(token: str):
  22. try:
  23. payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[security.ALGORITHM])
  24. payload["sub"] = json.loads(payload["sub"])
  25. token_payload = TokenPayload(**payload)
  26. except (jwt.PyJWTError, ValidationError):
  27. raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid Access Token!!")
  28. return token_payload
  29. async def get_current_user(token: str = Depends(reusable_oauth2),
  30. db: AsyncSession = Depends(get_async_db)):
  31. token_payload = check_access_token(token)
  32. if token_payload.sub["utype"] == 0:
  33. crud = crud_admin
  34. elif token_payload.sub["utype"] == 1:
  35. crud = crud_teacher
  36. else:
  37. crud = crud_student
  38. user = await crud.find_one(db, filters={"username": token_payload.sub["sub"]})
  39. user.utype = token_payload.sub["utype"]
  40. if token_payload.sub["utype"] == 0:
  41. role_id = user.role_id
  42. role = await crud_role.find_one(db, filters={"id": role_id})
  43. user.pcodes = role.permission_codes.split(",") if role.permission_codes else []
  44. return user