Initial FastAPI admin auth scaffold
This commit is contained in:
1
app/lib/jwt/__init__.py
Normal file
1
app/lib/jwt/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""JWT helpers."""
|
||||
25
app/lib/jwt/blacklist.py
Normal file
25
app/lib/jwt/blacklist.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
|
||||
class InMemoryTokenBlacklist:
|
||||
def __init__(self) -> None:
|
||||
self._items: dict[str, float] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def add(self, token: str, ttl: int) -> None:
|
||||
now = time.time()
|
||||
async with self._lock:
|
||||
self._cleanup(now)
|
||||
self._items[token] = now + ttl
|
||||
|
||||
async def has(self, token: str) -> bool:
|
||||
now = time.time()
|
||||
async with self._lock:
|
||||
self._cleanup(now)
|
||||
return token in self._items
|
||||
|
||||
def _cleanup(self, now: float) -> None:
|
||||
expired = [token for token, expires_at in self._items.items() if expires_at <= now]
|
||||
for token in expired:
|
||||
self._items.pop(token, None)
|
||||
10
app/lib/jwt/exceptions.py
Normal file
10
app/lib/jwt/exceptions.py
Normal file
@@ -0,0 +1,10 @@
|
||||
class JwtError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class JwtExpiredError(JwtError):
|
||||
pass
|
||||
|
||||
|
||||
class JwtInvalidError(JwtError):
|
||||
pass
|
||||
34
app/lib/jwt/factory.py
Normal file
34
app/lib/jwt/factory.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from app.core.config import Settings
|
||||
from app.lib.jwt.blacklist import InMemoryTokenBlacklist
|
||||
from app.lib.jwt.jwt import Jwt, JwtSceneConfig
|
||||
|
||||
|
||||
class JwtFactory:
|
||||
def __init__(self, settings: Settings, blacklist: InMemoryTokenBlacklist) -> None:
|
||||
self.settings = settings
|
||||
self.blacklist = blacklist
|
||||
|
||||
def get(self, scene: str = "default") -> Jwt:
|
||||
if scene == "admin":
|
||||
return Jwt(
|
||||
JwtSceneConfig(
|
||||
secret=self.settings.admin_jwt_secret,
|
||||
issuer=f"{self.settings.app_name}_admin",
|
||||
ttl=self.settings.admin_jwt_ttl,
|
||||
refresh_ttl=self.settings.admin_jwt_refresh_ttl,
|
||||
blacklist_ttl=self.settings.jwt_blacklist_ttl,
|
||||
),
|
||||
self.blacklist,
|
||||
)
|
||||
|
||||
scene_name = "default" if scene == "default" else scene
|
||||
return Jwt(
|
||||
JwtSceneConfig(
|
||||
secret=self.settings.api_jwt_secret,
|
||||
issuer=f"{self.settings.app_name}_{scene_name}",
|
||||
ttl=self.settings.api_jwt_ttl,
|
||||
refresh_ttl=self.settings.api_jwt_refresh_ttl,
|
||||
blacklist_ttl=self.settings.jwt_blacklist_ttl,
|
||||
),
|
||||
self.blacklist,
|
||||
)
|
||||
151
app/lib/jwt/jwt.py
Normal file
151
app/lib/jwt/jwt.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import secrets
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from app.lib.jwt.blacklist import InMemoryTokenBlacklist
|
||||
from app.lib.jwt.exceptions import JwtExpiredError, JwtInvalidError
|
||||
from app.lib.jwt.token import JwtToken
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class JwtSceneConfig:
|
||||
secret: str
|
||||
issuer: str
|
||||
ttl: int
|
||||
refresh_ttl: int
|
||||
blacklist_ttl: int
|
||||
|
||||
|
||||
class Jwt:
|
||||
def __init__(self, config: JwtSceneConfig, blacklist: InMemoryTokenBlacklist) -> None:
|
||||
self.config = config
|
||||
self.blacklist = blacklist
|
||||
|
||||
@property
|
||||
def issuer(self) -> str:
|
||||
return self.config.issuer
|
||||
|
||||
def builder_access_token(self, admin_id: str) -> str:
|
||||
return self._build_token(admin_id, token_type="access", ttl=self.config.ttl)
|
||||
|
||||
def builder_refresh_token(self, admin_id: str) -> str:
|
||||
return self._build_token(admin_id, token_type="refresh", ttl=self.config.refresh_ttl)
|
||||
|
||||
async def parser_access_token(self, raw_token: str) -> JwtToken:
|
||||
token = self._parse(raw_token)
|
||||
if token.is_refresh:
|
||||
raise JwtInvalidError("Token is a refresh token")
|
||||
self._validate_issuer(token)
|
||||
return token
|
||||
|
||||
async def parser_refresh_token(self, raw_token: str) -> JwtToken:
|
||||
token = self._parse(raw_token)
|
||||
if not token.is_refresh:
|
||||
raise JwtInvalidError("Token is not a refresh token")
|
||||
self._validate_issuer(token)
|
||||
return token
|
||||
|
||||
async def add_blacklist(self, token: JwtToken) -> None:
|
||||
await self.blacklist.add(token.raw, self.config.blacklist_ttl)
|
||||
|
||||
async def has_blacklist(self, token: JwtToken) -> bool:
|
||||
return await self.blacklist.has(token.raw)
|
||||
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
return getattr(self.config, key, default)
|
||||
|
||||
def _build_token(self, admin_id: str, token_type: str, ttl: int) -> str:
|
||||
now = int(time.time())
|
||||
payload: dict[str, Any] = {
|
||||
"iss": self.config.issuer,
|
||||
"jti": str(admin_id),
|
||||
"iat": now,
|
||||
"nbf": now,
|
||||
"exp": now + ttl,
|
||||
"token_type": token_type,
|
||||
"nonce": secrets.token_urlsafe(16),
|
||||
}
|
||||
if token_type == "refresh":
|
||||
payload["sub"] = "refresh"
|
||||
|
||||
header = {"typ": "JWT", "alg": "HS256"}
|
||||
signing_input = ".".join(
|
||||
[
|
||||
self._b64url_json(header),
|
||||
self._b64url_json(payload),
|
||||
]
|
||||
)
|
||||
signature = hmac.new(
|
||||
self.config.secret.encode("utf-8"),
|
||||
signing_input.encode("ascii"),
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
return f"{signing_input}.{self._b64url_encode(signature)}"
|
||||
|
||||
def _parse(self, raw_token: str) -> JwtToken:
|
||||
try:
|
||||
header_b64, payload_b64, signature_b64 = raw_token.split(".", 2)
|
||||
except ValueError as exc:
|
||||
raise JwtInvalidError("Malformed token") from exc
|
||||
|
||||
signing_input = f"{header_b64}.{payload_b64}".encode("ascii")
|
||||
expected = hmac.new(
|
||||
self.config.secret.encode("utf-8"),
|
||||
signing_input,
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
if not hmac.compare_digest(self._b64url_encode(expected), signature_b64):
|
||||
raise JwtInvalidError("Invalid signature")
|
||||
|
||||
try:
|
||||
header = json.loads(self._b64url_decode(header_b64))
|
||||
payload = json.loads(self._b64url_decode(payload_b64))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as exc:
|
||||
raise JwtInvalidError("Invalid token payload") from exc
|
||||
|
||||
if header.get("alg") != "HS256":
|
||||
raise JwtInvalidError("Unsupported algorithm")
|
||||
if not isinstance(payload, dict):
|
||||
raise JwtInvalidError("Invalid claims")
|
||||
|
||||
token = JwtToken(raw=raw_token, claims=payload)
|
||||
self._validate_time(token)
|
||||
return token
|
||||
|
||||
def _validate_time(self, token: JwtToken) -> None:
|
||||
now = int(time.time())
|
||||
exp = token.claims.get("exp")
|
||||
nbf = token.claims.get("nbf")
|
||||
iat = token.claims.get("iat")
|
||||
|
||||
if not isinstance(exp, int):
|
||||
raise JwtInvalidError("Missing exp")
|
||||
if exp <= now:
|
||||
raise JwtExpiredError("Token expired")
|
||||
if isinstance(nbf, int) and nbf > now:
|
||||
raise JwtInvalidError("Token not active")
|
||||
if isinstance(iat, int) and iat > now + 60:
|
||||
raise JwtInvalidError("Token issued in future")
|
||||
|
||||
def _validate_issuer(self, token: JwtToken) -> None:
|
||||
if token.claims.get("iss") != self.config.issuer:
|
||||
raise JwtInvalidError("Invalid issuer")
|
||||
|
||||
@staticmethod
|
||||
def _b64url_encode(raw: bytes) -> str:
|
||||
return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=")
|
||||
|
||||
@classmethod
|
||||
def _b64url_json(cls, value: dict[str, Any]) -> str:
|
||||
raw = json.dumps(value, separators=(",", ":"), sort_keys=True).encode("utf-8")
|
||||
return cls._b64url_encode(raw)
|
||||
|
||||
@staticmethod
|
||||
def _b64url_decode(value: str) -> str:
|
||||
padding = "=" * (-len(value) % 4)
|
||||
return base64.urlsafe_b64decode(value + padding).decode("utf-8")
|
||||
23
app/lib/jwt/token.py
Normal file
23
app/lib/jwt/token.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class JwtToken:
|
||||
raw: str
|
||||
claims: dict[str, Any]
|
||||
|
||||
@property
|
||||
def jwt_id(self) -> str:
|
||||
return str(self.claims.get("jti", ""))
|
||||
|
||||
@property
|
||||
def admin_id(self) -> int:
|
||||
try:
|
||||
return int(self.jwt_id)
|
||||
except ValueError:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def is_refresh(self) -> bool:
|
||||
return self.claims.get("sub") == "refresh" or self.claims.get("token_type") == "refresh"
|
||||
Reference in New Issue
Block a user