Initial FastAPI admin auth scaffold

This commit is contained in:
2026-06-05 17:10:30 +08:00
commit 5635da9ea5
65 changed files with 1407 additions and 0 deletions

151
app/lib/jwt/jwt.py Normal file
View File

@@ -0,0 +1,151 @@
import base64
import hashlib
import hmac
import json
import secrets
import time
from dataclasses import dataclass
from typing import Any
from app.lib.jwt.blacklist import InMemoryTokenBlacklist
from app.lib.jwt.exceptions import JwtExpiredError, JwtInvalidError
from app.lib.jwt.token import JwtToken
@dataclass(frozen=True, slots=True)
class JwtSceneConfig:
secret: str
issuer: str
ttl: int
refresh_ttl: int
blacklist_ttl: int
class Jwt:
def __init__(self, config: JwtSceneConfig, blacklist: InMemoryTokenBlacklist) -> None:
self.config = config
self.blacklist = blacklist
@property
def issuer(self) -> str:
return self.config.issuer
def builder_access_token(self, admin_id: str) -> str:
return self._build_token(admin_id, token_type="access", ttl=self.config.ttl)
def builder_refresh_token(self, admin_id: str) -> str:
return self._build_token(admin_id, token_type="refresh", ttl=self.config.refresh_ttl)
async def parser_access_token(self, raw_token: str) -> JwtToken:
token = self._parse(raw_token)
if token.is_refresh:
raise JwtInvalidError("Token is a refresh token")
self._validate_issuer(token)
return token
async def parser_refresh_token(self, raw_token: str) -> JwtToken:
token = self._parse(raw_token)
if not token.is_refresh:
raise JwtInvalidError("Token is not a refresh token")
self._validate_issuer(token)
return token
async def add_blacklist(self, token: JwtToken) -> None:
await self.blacklist.add(token.raw, self.config.blacklist_ttl)
async def has_blacklist(self, token: JwtToken) -> bool:
return await self.blacklist.has(token.raw)
def get_config(self, key: str, default: Any = None) -> Any:
return getattr(self.config, key, default)
def _build_token(self, admin_id: str, token_type: str, ttl: int) -> str:
now = int(time.time())
payload: dict[str, Any] = {
"iss": self.config.issuer,
"jti": str(admin_id),
"iat": now,
"nbf": now,
"exp": now + ttl,
"token_type": token_type,
"nonce": secrets.token_urlsafe(16),
}
if token_type == "refresh":
payload["sub"] = "refresh"
header = {"typ": "JWT", "alg": "HS256"}
signing_input = ".".join(
[
self._b64url_json(header),
self._b64url_json(payload),
]
)
signature = hmac.new(
self.config.secret.encode("utf-8"),
signing_input.encode("ascii"),
hashlib.sha256,
).digest()
return f"{signing_input}.{self._b64url_encode(signature)}"
def _parse(self, raw_token: str) -> JwtToken:
try:
header_b64, payload_b64, signature_b64 = raw_token.split(".", 2)
except ValueError as exc:
raise JwtInvalidError("Malformed token") from exc
signing_input = f"{header_b64}.{payload_b64}".encode("ascii")
expected = hmac.new(
self.config.secret.encode("utf-8"),
signing_input,
hashlib.sha256,
).digest()
if not hmac.compare_digest(self._b64url_encode(expected), signature_b64):
raise JwtInvalidError("Invalid signature")
try:
header = json.loads(self._b64url_decode(header_b64))
payload = json.loads(self._b64url_decode(payload_b64))
except (json.JSONDecodeError, UnicodeDecodeError) as exc:
raise JwtInvalidError("Invalid token payload") from exc
if header.get("alg") != "HS256":
raise JwtInvalidError("Unsupported algorithm")
if not isinstance(payload, dict):
raise JwtInvalidError("Invalid claims")
token = JwtToken(raw=raw_token, claims=payload)
self._validate_time(token)
return token
def _validate_time(self, token: JwtToken) -> None:
now = int(time.time())
exp = token.claims.get("exp")
nbf = token.claims.get("nbf")
iat = token.claims.get("iat")
if not isinstance(exp, int):
raise JwtInvalidError("Missing exp")
if exp <= now:
raise JwtExpiredError("Token expired")
if isinstance(nbf, int) and nbf > now:
raise JwtInvalidError("Token not active")
if isinstance(iat, int) and iat > now + 60:
raise JwtInvalidError("Token issued in future")
def _validate_issuer(self, token: JwtToken) -> None:
if token.claims.get("iss") != self.config.issuer:
raise JwtInvalidError("Invalid issuer")
@staticmethod
def _b64url_encode(raw: bytes) -> str:
return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=")
@classmethod
def _b64url_json(cls, value: dict[str, Any]) -> str:
raw = json.dumps(value, separators=(",", ":"), sort_keys=True).encode("utf-8")
return cls._b64url_encode(raw)
@staticmethod
def _b64url_decode(value: str) -> str:
padding = "=" * (-len(value) % 4)
return base64.urlsafe_b64decode(value + padding).decode("utf-8")