from fastapi import Depends, Request

from app.constants.result_code import ResultCode
from app.core.dependencies import get_token_service
from app.exception.err_exception import ErrException
from app.lib.jwt.exceptions import JwtError, JwtExpiredError
from app.lib.jwt.jwt import Jwt
from app.lib.jwt.token import JwtToken
from app.service.base_token_service import BaseTokenService


class AbstractTokenMiddleware:
    """Base FastAPI dependency that authenticates a request with a JWT.

    An instance is used as a callable dependency: FastAPI invokes
    ``__call__`` with the current request and a token service.

    Subclasses must implement :meth:`parser_token` (turn the raw token
    string into a ``JwtToken``) and :meth:`set_context` (publish the
    validated token to whatever request-scoped context the app uses).
    """

    # Passed to ``token_service.get_jwt`` to select the JWT configuration;
    # presumably overridden per audience by subclasses — confirm.
    scene = "admin"

    async def __call__(
        self,
        request: Request,
        token_service: BaseTokenService = Depends(get_token_service),
    ) -> JwtToken:
        """Extract, parse and validate the request's token.

        Args:
            request: The incoming request.
            token_service: Supplies the ``Jwt`` for ``self.scene`` and
                performs the service-level ``check_jwt`` validation.

        Returns:
            The validated token. It is also stored on
            ``request.state.token`` and passed to :meth:`set_context`.

        Raises:
            ErrException: ``ResultCode.JWT_ERROR`` when no token is supplied,
                when parsing/validation raises ``JwtError``, or when the
                issuer does not match; ``ResultCode.JWT_EXPIRED`` when
                ``JwtExpiredError`` is raised.
        """
        raw_token = self.get_token(request)
        jwt = token_service.get_jwt(self.scene)
        try:
            token = await self.parser_token(jwt, raw_token)
            await token_service.check_jwt(jwt, token)
            self.check_issuer(jwt, token)
        # The expired case is matched first so the broader JwtError clause
        # below does not absorb it (assumes JwtExpiredError subclasses
        # JwtError — verify in app.lib.jwt.exceptions).
        except JwtExpiredError as exc:
            raise ErrException(
                "token过期",
                ResultCode.JWT_EXPIRED,
                {"err_msg": str(exc)},
            ) from exc
        except JwtError as exc:
            raise ErrException(
                "token错误",
                ResultCode.JWT_ERROR,
                {"err_msg": str(exc)},
            ) from exc
        self.set_context(request, token)
        request.state.token = token
        return token

    async def parser_token(self, jwt: Jwt, raw_token: str) -> JwtToken:
        """Parse ``raw_token`` with ``jwt`` into a ``JwtToken``.

        Must be implemented by subclasses. Raising ``JwtError`` /
        ``JwtExpiredError`` here is translated into ``ErrException`` by
        :meth:`__call__`.
        """
        raise NotImplementedError

    def set_context(self, request: Request, token: JwtToken) -> None:
        """Expose the validated ``token`` for the rest of the request.

        Must be implemented by subclasses.
        """
        raise NotImplementedError

    @staticmethod
    def check_issuer(jwt: Jwt, token: JwtToken) -> None:
        """Reject tokens whose ``iss`` claim differs from ``jwt.issuer``.

        Raises:
            ErrException: ``ResultCode.JWT_ERROR`` on an issuer mismatch,
                including when the claim is absent.
        """
        if token.claims.get("iss") != jwt.issuer:
            raise ErrException("token错误", ResultCode.JWT_ERROR)

    @staticmethod
    def get_token(request: Request) -> str:
        """Return the raw token string carried by ``request``.

        Sources are tried in order, and a source that is missing or blank
        (after stripping whitespace) is skipped:

        1. ``Authorization`` header. ``Bearer <token>`` (scheme matched
           case-insensitively) yields ``<token>``; a value with any other
           leading word is returned whole, as a bare token.
        2. ``token`` header.
        3. ``token`` query parameter.

        Raises:
            ErrException: ``ResultCode.JWT_ERROR`` ("token缺失") when no
                source provides a non-empty token.
        """
        # Strip before the emptiness check so whitespace-only values count
        # as missing instead of producing an empty token.
        authorization = (request.headers.get("authorization") or "").strip()
        if authorization:
            # split() on any whitespace; the scheme is anchored at the start
            # rather than removed from an arbitrary position.
            parts = authorization.split(maxsplit=1)
            scheme = parts[0]
            credentials = parts[1] if len(parts) > 1 else ""
            if scheme.lower() != "bearer":
                # No Bearer scheme: treat the whole header as the token.
                return authorization
            if credentials:
                return credentials
            # "Bearer" with no credentials: fall through to other sources.

        token_header = (request.headers.get("token") or "").strip()
        if token_header:
            return token_header

        # NOTE(review): tokens in query strings tend to end up in access
        # logs; kept for compatibility with existing clients.
        token_query = (request.query_params.get("token") or "").strip()
        if token_query:
            return token_query

        raise ErrException("token缺失", ResultCode.JWT_ERROR)