Initial FastAPI admin auth scaffold
This commit is contained in:
151
app/lib/jwt/jwt.py
Normal file
151
app/lib/jwt/jwt.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import secrets
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from app.lib.jwt.blacklist import InMemoryTokenBlacklist
|
||||
from app.lib.jwt.exceptions import JwtExpiredError, JwtInvalidError
|
||||
from app.lib.jwt.token import JwtToken
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class JwtSceneConfig:
|
||||
secret: str
|
||||
issuer: str
|
||||
ttl: int
|
||||
refresh_ttl: int
|
||||
blacklist_ttl: int
|
||||
|
||||
|
||||
class Jwt:
|
||||
def __init__(self, config: JwtSceneConfig, blacklist: InMemoryTokenBlacklist) -> None:
|
||||
self.config = config
|
||||
self.blacklist = blacklist
|
||||
|
||||
@property
|
||||
def issuer(self) -> str:
|
||||
return self.config.issuer
|
||||
|
||||
def builder_access_token(self, admin_id: str) -> str:
|
||||
return self._build_token(admin_id, token_type="access", ttl=self.config.ttl)
|
||||
|
||||
def builder_refresh_token(self, admin_id: str) -> str:
|
||||
return self._build_token(admin_id, token_type="refresh", ttl=self.config.refresh_ttl)
|
||||
|
||||
async def parser_access_token(self, raw_token: str) -> JwtToken:
|
||||
token = self._parse(raw_token)
|
||||
if token.is_refresh:
|
||||
raise JwtInvalidError("Token is a refresh token")
|
||||
self._validate_issuer(token)
|
||||
return token
|
||||
|
||||
async def parser_refresh_token(self, raw_token: str) -> JwtToken:
|
||||
token = self._parse(raw_token)
|
||||
if not token.is_refresh:
|
||||
raise JwtInvalidError("Token is not a refresh token")
|
||||
self._validate_issuer(token)
|
||||
return token
|
||||
|
||||
async def add_blacklist(self, token: JwtToken) -> None:
|
||||
await self.blacklist.add(token.raw, self.config.blacklist_ttl)
|
||||
|
||||
async def has_blacklist(self, token: JwtToken) -> bool:
|
||||
return await self.blacklist.has(token.raw)
|
||||
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
return getattr(self.config, key, default)
|
||||
|
||||
def _build_token(self, admin_id: str, token_type: str, ttl: int) -> str:
|
||||
now = int(time.time())
|
||||
payload: dict[str, Any] = {
|
||||
"iss": self.config.issuer,
|
||||
"jti": str(admin_id),
|
||||
"iat": now,
|
||||
"nbf": now,
|
||||
"exp": now + ttl,
|
||||
"token_type": token_type,
|
||||
"nonce": secrets.token_urlsafe(16),
|
||||
}
|
||||
if token_type == "refresh":
|
||||
payload["sub"] = "refresh"
|
||||
|
||||
header = {"typ": "JWT", "alg": "HS256"}
|
||||
signing_input = ".".join(
|
||||
[
|
||||
self._b64url_json(header),
|
||||
self._b64url_json(payload),
|
||||
]
|
||||
)
|
||||
signature = hmac.new(
|
||||
self.config.secret.encode("utf-8"),
|
||||
signing_input.encode("ascii"),
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
return f"{signing_input}.{self._b64url_encode(signature)}"
|
||||
|
||||
def _parse(self, raw_token: str) -> JwtToken:
|
||||
try:
|
||||
header_b64, payload_b64, signature_b64 = raw_token.split(".", 2)
|
||||
except ValueError as exc:
|
||||
raise JwtInvalidError("Malformed token") from exc
|
||||
|
||||
signing_input = f"{header_b64}.{payload_b64}".encode("ascii")
|
||||
expected = hmac.new(
|
||||
self.config.secret.encode("utf-8"),
|
||||
signing_input,
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
if not hmac.compare_digest(self._b64url_encode(expected), signature_b64):
|
||||
raise JwtInvalidError("Invalid signature")
|
||||
|
||||
try:
|
||||
header = json.loads(self._b64url_decode(header_b64))
|
||||
payload = json.loads(self._b64url_decode(payload_b64))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as exc:
|
||||
raise JwtInvalidError("Invalid token payload") from exc
|
||||
|
||||
if header.get("alg") != "HS256":
|
||||
raise JwtInvalidError("Unsupported algorithm")
|
||||
if not isinstance(payload, dict):
|
||||
raise JwtInvalidError("Invalid claims")
|
||||
|
||||
token = JwtToken(raw=raw_token, claims=payload)
|
||||
self._validate_time(token)
|
||||
return token
|
||||
|
||||
def _validate_time(self, token: JwtToken) -> None:
|
||||
now = int(time.time())
|
||||
exp = token.claims.get("exp")
|
||||
nbf = token.claims.get("nbf")
|
||||
iat = token.claims.get("iat")
|
||||
|
||||
if not isinstance(exp, int):
|
||||
raise JwtInvalidError("Missing exp")
|
||||
if exp <= now:
|
||||
raise JwtExpiredError("Token expired")
|
||||
if isinstance(nbf, int) and nbf > now:
|
||||
raise JwtInvalidError("Token not active")
|
||||
if isinstance(iat, int) and iat > now + 60:
|
||||
raise JwtInvalidError("Token issued in future")
|
||||
|
||||
def _validate_issuer(self, token: JwtToken) -> None:
|
||||
if token.claims.get("iss") != self.config.issuer:
|
||||
raise JwtInvalidError("Invalid issuer")
|
||||
|
||||
@staticmethod
|
||||
def _b64url_encode(raw: bytes) -> str:
|
||||
return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=")
|
||||
|
||||
@classmethod
|
||||
def _b64url_json(cls, value: dict[str, Any]) -> str:
|
||||
raw = json.dumps(value, separators=(",", ":"), sort_keys=True).encode("utf-8")
|
||||
return cls._b64url_encode(raw)
|
||||
|
||||
@staticmethod
|
||||
def _b64url_decode(value: str) -> str:
|
||||
padding = "=" * (-len(value) % 4)
|
||||
return base64.urlsafe_b64decode(value + padding).decode("utf-8")
|
||||
Reference in New Issue
Block a user