import base64
import hashlib
import hmac
import json
import secrets
import time
from dataclasses import dataclass
from typing import Any

from app.lib.jwt.blacklist import InMemoryTokenBlacklist
from app.lib.jwt.exceptions import JwtExpiredError, JwtInvalidError
from app.lib.jwt.token import JwtToken


@dataclass(frozen=True, slots=True)
class JwtSceneConfig:
    """Settings for one JWT "scene": signing secret, issuer and lifetimes."""

    secret: str  # HMAC-SHA256 key; UTF-8 encoded before use.
    issuer: str  # Written to, and required in, the ``iss`` claim.
    ttl: int  # Access-token lifetime in seconds (added to the current Unix time).
    refresh_ttl: int  # Refresh-token lifetime in seconds (added to the current Unix time).
    blacklist_ttl: int  # Passed as the TTL argument to ``blacklist.add`` on revocation.


class Jwt:
    """Issue and verify HS256-signed JWTs for a single configured scene.

    Issued tokens carry ``iss``, ``jti`` (the admin id), ``iat``/``nbf``/``exp``,
    ``token_type`` (``"access"`` or ``"refresh"``) and a random ``nonce``;
    refresh tokens additionally carry ``sub="refresh"``.
    """

    def __init__(self, config: JwtSceneConfig, blacklist: InMemoryTokenBlacklist) -> None:
        self.config = config
        self.blacklist = blacklist

    @property
    def issuer(self) -> str:
        """Issuer string written to and required in the ``iss`` claim."""
        return self.config.issuer

    def builder_access_token(self, admin_id: str) -> str:
        """Return a signed access token for *admin_id*, valid for ``config.ttl`` seconds."""
        return self._build_token(admin_id, token_type="access", ttl=self.config.ttl)

    def builder_refresh_token(self, admin_id: str) -> str:
        """Return a signed refresh token for *admin_id*, valid for ``config.refresh_ttl`` seconds."""
        return self._build_token(admin_id, token_type="refresh", ttl=self.config.refresh_ttl)

    async def parser_access_token(self, raw_token: str) -> JwtToken:
        """Verify *raw_token* as an access token and return the parsed token.

        NOTE: the blacklist is not consulted here; callers must check
        :meth:`has_blacklist` themselves if revocation matters.

        Raises:
            JwtInvalidError: malformed token, bad signature, not yet active,
                wrong issuer, or the token is a refresh token.
            JwtExpiredError: the ``exp`` claim is in the past.
        """
        token = self._parse(raw_token)
        if token.is_refresh:
            raise JwtInvalidError("Token is a refresh token")
        self._validate_issuer(token)
        return token

    async def parser_refresh_token(self, raw_token: str) -> JwtToken:
        """Verify *raw_token* as a refresh token and return the parsed token.

        NOTE: the blacklist is not consulted here; callers must check
        :meth:`has_blacklist` themselves if revocation matters.

        Raises:
            JwtInvalidError: malformed token, bad signature, not yet active,
                wrong issuer, or the token is not a refresh token.
            JwtExpiredError: the ``exp`` claim is in the past.
        """
        token = self._parse(raw_token)
        if not token.is_refresh:
            raise JwtInvalidError("Token is not a refresh token")
        self._validate_issuer(token)
        return token

    async def add_blacklist(self, token: JwtToken) -> None:
        """Record *token*'s raw string in the blacklist for ``config.blacklist_ttl``."""
        await self.blacklist.add(token.raw, self.config.blacklist_ttl)

    async def has_blacklist(self, token: JwtToken) -> bool:
        """Return whether *token*'s raw string is currently in the blacklist."""
        return await self.blacklist.has(token.raw)

    def get_config(self, key: str, default: Any = None) -> Any:
        """Return the config attribute named *key*, or *default* if it does not exist."""
        return getattr(self.config, key, default)

    def _sign(self, signing_input: bytes) -> bytes:
        """Return the raw HMAC-SHA256 digest of *signing_input* under the configured secret."""
        return hmac.new(
            self.config.secret.encode("utf-8"),
            signing_input,
            hashlib.sha256,
        ).digest()

    def _build_token(self, admin_id: str, token_type: str, ttl: int) -> str:
        """Build, sign and serialize a compact ``header.payload.signature`` token."""
        now = int(time.time())
        payload: dict[str, Any] = {
            "iss": self.config.issuer,
            # NOTE(review): ``jti`` carries the admin id rather than a per-token
            # id; the random ``nonce`` is what makes each issued token distinct.
            "jti": str(admin_id),
            "iat": now,
            "nbf": now,
            "exp": now + ttl,
            "token_type": token_type,
            "nonce": secrets.token_urlsafe(16),
        }
        if token_type == "refresh":
            payload["sub"] = "refresh"
        header = {"typ": "JWT", "alg": "HS256"}
        signing_input = f"{self._b64url_json(header)}.{self._b64url_json(payload)}"
        signature = self._sign(signing_input.encode("ascii"))
        return f"{signing_input}.{self._b64url_encode(signature)}"

    def _parse(self, raw_token: str) -> JwtToken:
        """Verify the signature and time claims of *raw_token* and wrap it.

        The input is untrusted, so every malformation is reported as
        ``JwtInvalidError`` (or ``JwtExpiredError`` for an expired token)
        rather than leaking a built-in exception.

        Raises:
            JwtInvalidError: wrong shape, bad signature, undecodable segments,
                unsupported algorithm, non-object claims, or bad time claims.
            JwtExpiredError: the ``exp`` claim is in the past.
        """
        # A compact JWS is base64url text joined by dots, so it is pure ASCII.
        # Rejecting anything else up front keeps the ASCII encode below and
        # ``hmac.compare_digest`` (which refuses non-ASCII str) from raising.
        if not isinstance(raw_token, str) or not raw_token.isascii():
            raise JwtInvalidError("Malformed token")
        segments = raw_token.split(".")
        if len(segments) != 3:
            raise JwtInvalidError("Malformed token")
        header_b64, payload_b64, signature_b64 = segments

        signing_input = f"{header_b64}.{payload_b64}".encode("ascii")
        expected = self._b64url_encode(self._sign(signing_input))
        # Constant-time comparison of the canonical (unpadded) encodings.
        if not hmac.compare_digest(expected, signature_b64):
            raise JwtInvalidError("Invalid signature")

        try:
            header = json.loads(self._b64url_decode(header_b64))
            payload = json.loads(self._b64url_decode(payload_b64))
        except ValueError as exc:
            # Covers binascii.Error (bad base64), UnicodeDecodeError and
            # json.JSONDecodeError, all of which subclass ValueError.
            raise JwtInvalidError("Invalid token payload") from exc

        if not isinstance(header, dict):
            raise JwtInvalidError("Invalid token header")
        if header.get("alg") != "HS256":
            raise JwtInvalidError("Unsupported algorithm")
        if not isinstance(payload, dict):
            raise JwtInvalidError("Invalid claims")

        token = JwtToken(raw=raw_token, claims=payload)
        self._validate_time(token)
        return token

    def _validate_time(self, token: JwtToken) -> None:
        """Check ``exp`` (required), ``nbf`` and ``iat`` against the current time.

        ``exp`` and ``nbf`` are checked with no clock-skew allowance; ``iat``
        may be up to 60 seconds in the future. Non-integer ``nbf``/``iat``
        values are ignored.
        """
        now = int(time.time())
        exp = token.claims.get("exp")
        nbf = token.claims.get("nbf")
        iat = token.claims.get("iat")
        if not isinstance(exp, int):
            raise JwtInvalidError("Missing exp")
        if exp <= now:
            raise JwtExpiredError("Token expired")
        if isinstance(nbf, int) and nbf > now:
            raise JwtInvalidError("Token not active")
        if isinstance(iat, int) and iat > now + 60:
            raise JwtInvalidError("Token issued in future")

    def _validate_issuer(self, token: JwtToken) -> None:
        """Raise ``JwtInvalidError`` unless the ``iss`` claim equals the configured issuer."""
        if token.claims.get("iss") != self.config.issuer:
            raise JwtInvalidError("Invalid issuer")

    @staticmethod
    def _b64url_encode(raw: bytes) -> str:
        """Base64url-encode *raw* without ``=`` padding."""
        return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=")

    @classmethod
    def _b64url_json(cls, value: dict[str, Any]) -> str:
        """Serialize *value* as compact, key-sorted JSON and base64url-encode it."""
        raw = json.dumps(value, separators=(",", ":"), sort_keys=True).encode("utf-8")
        return cls._b64url_encode(raw)

    @staticmethod
    def _b64url_decode(value: str) -> str:
        """Decode unpadded base64url *value* and return the UTF-8 text.

        Raises ``binascii.Error`` / ``UnicodeDecodeError`` (both ``ValueError``)
        on bad input.
        """
        padding = "=" * (-len(value) % 4)
        return base64.urlsafe_b64decode(value + padding).decode("utf-8")