73 lines
2.4 KiB
Python
73 lines
2.4 KiB
Python
from fastapi import Depends, Request
|
|
|
|
from app.constants.result_code import ResultCode
|
|
from app.core.dependencies import get_token_service
|
|
from app.exception.err_exception import ErrException
|
|
from app.lib.jwt.exceptions import JwtError, JwtExpiredError
|
|
from app.lib.jwt.jwt import Jwt
|
|
from app.lib.jwt.token import JwtToken
|
|
from app.service.base_token_service import BaseTokenService
|
|
|
|
|
|
class AbstractTokenMiddleware:
    """Base callable dependency that authenticates a request via a JWT.

    ``__call__`` extracts the raw token from the request (``get_token``),
    parses it with the subclass-provided ``parser_token``, validates it with
    ``token_service.check_jwt`` and ``check_issuer``, then lets the subclass
    record context (``set_context``), stores the token on
    ``request.state.token`` and returns it.

    Subclasses must implement ``parser_token`` and ``set_context``; they may
    also override ``scene``.
    """

    # Scene name handed to ``token_service.get_jwt`` to obtain the Jwt used
    # by this middleware; presumably overridden by subclasses for other scenes.
    scene = "admin"

    async def __call__(
        self,
        request: Request,
        token_service: BaseTokenService = Depends(get_token_service),
    ) -> JwtToken:
        """Authenticate *request* and return its parsed token.

        Args:
            request: The incoming request.
            token_service: Token service resolved through ``get_token_service``.

        Returns:
            The parsed, validated token (also stored on ``request.state.token``).

        Raises:
            ErrException: ``ResultCode.JWT_EXPIRED`` when parsing or
                ``check_jwt`` raises ``JwtExpiredError``;
                ``ResultCode.JWT_ERROR`` when they raise ``JwtError``, when
                no token is present, or when the issuer does not match.
        """
        raw_token = self.get_token(request)
        jwt = token_service.get_jwt(self.scene)

        try:
            token = await self.parser_token(jwt, raw_token)
            await token_service.check_jwt(jwt, token)
            self.check_issuer(jwt, token)
        # The expired handler must come before the generic JwtError handler
        # (JwtExpiredError is presumably a subclass of JwtError).
        except JwtExpiredError as exc:
            raise ErrException(
                "token过期",
                ResultCode.JWT_EXPIRED,
                {"err_msg": str(exc)},
            ) from exc
        except JwtError as exc:
            raise ErrException(
                "token错误",
                ResultCode.JWT_ERROR,
                {"err_msg": str(exc)},
            ) from exc

        self.set_context(request, token)
        request.state.token = token
        return token

    async def parser_token(self, jwt: Jwt, raw_token: str) -> JwtToken:
        """Parse *raw_token* using *jwt* and return the resulting token.

        Must be implemented by subclasses. ``__call__`` maps ``JwtExpiredError``
        and ``JwtError`` raised here to ``ErrException``.
        """
        raise NotImplementedError

    def set_context(self, request: Request, token: JwtToken) -> None:
        """Record token-derived context for *request*.

        Must be implemented by subclasses; called only after the token has
        been parsed and validated.
        """
        raise NotImplementedError

    @staticmethod
    def check_issuer(jwt: Jwt, token: JwtToken) -> None:
        """Reject *token* unless its ``iss`` claim equals ``jwt.issuer``.

        A missing ``iss`` claim compares as ``None``.

        Raises:
            ErrException: ``ResultCode.JWT_ERROR`` on an issuer mismatch.
        """
        if token.claims.get("iss") != jwt.issuer:
            raise ErrException("token错误", ResultCode.JWT_ERROR)

    @staticmethod
    def get_token(request: Request) -> str:
        """Extract the raw token string from *request*.

        Sources are tried in order:

        1. ``Authorization`` header: ``Bearer <token>`` (scheme matched
           case-insensitively), or a value without a Bearer scheme, which is
           taken as the raw token itself.
        2. ``token`` header.
        3. ``token`` query parameter.

        Blank values are treated as absent, and so is a ``Bearer`` scheme
        with nothing after it, so the next source is tried.

        Raises:
            ErrException: ``ResultCode.JWT_ERROR`` ("token缺失") when no
                source yields a non-blank token.
        """
        authorization = (request.headers.get("authorization") or "").strip()
        if authorization:
            scheme, _, credentials = authorization.partition(" ")
            if scheme.lower() != "bearer":
                # No Bearer scheme: the whole header value is the raw token.
                return authorization
            credentials = credentials.strip()
            if credentials:
                return credentials
            # "Bearer" with no credentials: fall through rather than returning
            # the scheme name (or an empty string) as if it were a token.

        token_header = (request.headers.get("token") or "").strip()
        if token_header:
            return token_header

        # NOTE(review): tokens in the query string can leak into access logs
        # and browser history; presumably kept for clients that cannot set
        # headers — confirm it is still needed.
        token_query = (request.query_params.get("token") or "").strip()
        if token_query:
            return token_query

        raise ErrException("token缺失", ResultCode.JWT_ERROR)
|