152 lines
5.1 KiB
Python
152 lines
5.1 KiB
Python
import base64
|
|
import hashlib
|
|
import hmac
|
|
import json
|
|
import secrets
|
|
import time
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
from app.lib.jwt.blacklist import InMemoryTokenBlacklist
|
|
from app.lib.jwt.exceptions import JwtExpiredError, JwtInvalidError
|
|
from app.lib.jwt.token import JwtToken
|
|
|
|
|
|
@dataclass(frozen=True, slots=True)
|
|
class JwtSceneConfig:
|
|
secret: str
|
|
issuer: str
|
|
ttl: int
|
|
refresh_ttl: int
|
|
blacklist_ttl: int
|
|
|
|
|
|
class Jwt:
    """Issue and verify HS256-signed JSON Web Tokens for one configuration.

    Tokens are assembled and checked by hand with ``hmac``/``hashlib``, using
    ``config.secret`` as the key. The blacklist is only touched by
    :meth:`add_blacklist` and :meth:`has_blacklist`; the ``parser_*`` methods
    do not consult it, so callers that need revocation must check it
    themselves.
    """

    def __init__(self, config: JwtSceneConfig, blacklist: InMemoryTokenBlacklist) -> None:
        """Store the signing configuration and the blacklist backend."""
        self.config = config
        self.blacklist = blacklist

    @property
    def issuer(self) -> str:
        """Return the issuer configured for this instance."""
        return self.config.issuer

    def builder_access_token(self, admin_id: str) -> str:
        """Return a signed access token for ``admin_id`` living ``config.ttl``."""
        return self._build_token(admin_id, token_type="access", ttl=self.config.ttl)

    def builder_refresh_token(self, admin_id: str) -> str:
        """Return a signed refresh token for ``admin_id`` living ``config.refresh_ttl``."""
        return self._build_token(admin_id, token_type="refresh", ttl=self.config.refresh_ttl)

    async def parser_access_token(self, raw_token: str) -> JwtToken:
        """Verify ``raw_token`` and return it, rejecting refresh tokens.

        Raises:
            JwtInvalidError: Malformed token, bad signature, wrong issuer, or
                the parsed token reports ``is_refresh``.
            JwtExpiredError: The ``exp`` claim is not in the future.
        """
        token = self._parse(raw_token)
        if token.is_refresh:
            raise JwtInvalidError("Token is a refresh token")
        self._validate_issuer(token)
        return token

    async def parser_refresh_token(self, raw_token: str) -> JwtToken:
        """Verify ``raw_token`` and return it, accepting only refresh tokens.

        Raises:
            JwtInvalidError: Malformed token, bad signature, wrong issuer, or
                the parsed token does not report ``is_refresh``.
            JwtExpiredError: The ``exp`` claim is not in the future.
        """
        token = self._parse(raw_token)
        if not token.is_refresh:
            raise JwtInvalidError("Token is not a refresh token")
        self._validate_issuer(token)
        return token

    async def add_blacklist(self, token: JwtToken) -> None:
        """Add the token's raw string to the blacklist with ``config.blacklist_ttl``."""
        await self.blacklist.add(token.raw, self.config.blacklist_ttl)

    async def has_blacklist(self, token: JwtToken) -> bool:
        """Return whether the blacklist reports the token's raw string."""
        return await self.blacklist.has(token.raw)

    def get_config(self, key: str, default: Any = None) -> Any:
        """Return attribute ``key`` of the config, or ``default`` if absent."""
        return getattr(self.config, key, default)

    def _build_token(self, admin_id: str, token_type: str, ttl: int) -> str:
        """Assemble and sign a compact ``header.payload.signature`` token.

        Args:
            admin_id: Stored (stringified) in the ``jti`` claim.
            token_type: Stored in ``token_type``; ``"refresh"`` additionally
                sets ``sub`` to ``"refresh"``.
            ttl: Added to the current Unix time to form ``exp``.
        """
        now = int(time.time())
        payload: dict[str, Any] = {
            "iss": self.config.issuer,
            "jti": str(admin_id),
            "iat": now,
            "nbf": now,
            "exp": now + ttl,
            "token_type": token_type,
            # Random per-token value so two tokens issued in the same second differ.
            "nonce": secrets.token_urlsafe(16),
        }
        if token_type == "refresh":
            payload["sub"] = "refresh"

        header = {"typ": "JWT", "alg": "HS256"}
        signing_input = ".".join(
            [
                self._b64url_json(header),
                self._b64url_json(payload),
            ]
        )
        return f"{signing_input}.{self._sign(signing_input.encode('ascii'))}"

    def _parse(self, raw_token: str) -> JwtToken:
        """Verify signature, structure and time claims of ``raw_token``.

        The signature is checked before any part of the token is decoded.
        Every rejection of a bad token surfaces as ``JwtInvalidError`` or
        ``JwtExpiredError``; no encoding or decoding error is allowed to
        escape, because the input is untrusted.

        Raises:
            JwtInvalidError: Not a string, not three dot-separated parts,
                signature mismatch, undecodable header/payload, header not an
                object, ``alg`` other than ``HS256``, claims not an object, or
                a time claim rejected by :meth:`_validate_time`.
            JwtExpiredError: The ``exp`` claim is not in the future.
        """
        if not isinstance(raw_token, str):
            raise JwtInvalidError("Malformed token")
        try:
            header_b64, payload_b64, signature_b64 = raw_token.split(".", 2)
        except ValueError as exc:
            raise JwtInvalidError("Malformed token") from exc

        if not self._signature_matches(header_b64, payload_b64, signature_b64):
            raise JwtInvalidError("Invalid signature")

        try:
            header = json.loads(self._b64url_decode(header_b64))
            payload = json.loads(self._b64url_decode(payload_b64))
        except ValueError as exc:
            # ValueError is the common base of json.JSONDecodeError,
            # UnicodeDecodeError and binascii.Error (bad base64 padding).
            raise JwtInvalidError("Invalid token payload") from exc

        if not isinstance(header, dict):
            raise JwtInvalidError("Invalid token payload")
        if header.get("alg") != "HS256":
            raise JwtInvalidError("Unsupported algorithm")
        if not isinstance(payload, dict):
            raise JwtInvalidError("Invalid claims")

        token = JwtToken(raw=raw_token, claims=payload)
        self._validate_time(token)
        return token

    def _sign(self, signing_input: bytes) -> str:
        """Return the unpadded base64url HMAC-SHA256 of ``signing_input``."""
        digest = hmac.new(
            self.config.secret.encode("utf-8"),
            signing_input,
            hashlib.sha256,
        ).digest()
        return self._b64url_encode(digest)

    def _signature_matches(self, header_b64: str, payload_b64: str, signature_b64: str) -> bool:
        """Return whether ``signature_b64`` signs ``header_b64.payload_b64``.

        Any non-ASCII character makes the token unverifiable, so it yields
        ``False`` rather than an encoding error.
        """
        try:
            signing_input = f"{header_b64}.{payload_b64}".encode("ascii")
            presented = signature_b64.encode("ascii")
        except UnicodeEncodeError:
            return False
        expected = self._sign(signing_input).encode("ascii")
        # Compare as bytes: compare_digest raises TypeError on non-ASCII str.
        return hmac.compare_digest(expected, presented)

    def _validate_time(self, token: JwtToken) -> None:
        """Check the ``exp``, ``nbf`` and ``iat`` claims against the clock.

        ``exp`` is mandatory and must be an int; ``nbf`` and ``iat`` are only
        checked when present as ints. ``iat`` tolerates 60 seconds of skew.

        Raises:
            JwtInvalidError: ``exp`` missing/non-int, ``nbf`` in the future,
                or ``iat`` more than 60 seconds in the future.
            JwtExpiredError: ``exp`` is at or before the current time.
        """
        now = int(time.time())
        exp = token.claims.get("exp")
        nbf = token.claims.get("nbf")
        iat = token.claims.get("iat")

        if not isinstance(exp, int):
            raise JwtInvalidError("Missing exp")
        if exp <= now:
            raise JwtExpiredError("Token expired")
        if isinstance(nbf, int) and nbf > now:
            raise JwtInvalidError("Token not active")
        if isinstance(iat, int) and iat > now + 60:
            raise JwtInvalidError("Token issued in future")

    def _validate_issuer(self, token: JwtToken) -> None:
        """Raise ``JwtInvalidError`` unless ``iss`` equals the configured issuer."""
        if token.claims.get("iss") != self.config.issuer:
            raise JwtInvalidError("Invalid issuer")

    @staticmethod
    def _b64url_encode(raw: bytes) -> str:
        """Return ``raw`` as URL-safe base64 text with ``=`` padding removed."""
        return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=")

    @classmethod
    def _b64url_json(cls, value: dict[str, Any]) -> str:
        """Serialize ``value`` as compact, key-sorted JSON and base64url-encode it."""
        raw = json.dumps(value, separators=(",", ":"), sort_keys=True).encode("utf-8")
        return cls._b64url_encode(raw)

    @staticmethod
    def _b64url_decode(value: str) -> str:
        """Decode unpadded URL-safe base64 ``value`` into UTF-8 text.

        Raises:
            ValueError: Invalid base64 (``binascii.Error``) or invalid UTF-8
                (``UnicodeDecodeError``); both are ``ValueError`` subclasses.
        """
        # Restore the padding that _b64url_encode strips.
        padding = "=" * (-len(value) % 4)
        return base64.urlsafe_b64decode(value + padding).decode("utf-8")
|