Initial FastAPI admin auth scaffold

This commit is contained in:
2026-06-05 17:10:30 +08:00
commit 5635da9ea5
65 changed files with 1407 additions and 0 deletions

1
app/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Application package."""

1
app/common/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Common helpers."""

3
app/common/context.py Normal file
View File

@@ -0,0 +1,3 @@
from contextvars import ContextVar
current_admin_id: ContextVar[int] = ContextVar("current_admin_id", default=0)

View File

@@ -0,0 +1 @@
"""Repository layer."""

View File

@@ -0,0 +1,79 @@
from datetime import UTC, datetime
from app.common.repository.base_repository import BaseRepository
from app.common.security.password_hasher import hash_password
from app.constants.model.admin_user.admin_user_status_code import AdminUserStatusCode
from app.model.admin_user import AdminUser
class AdminUserRepository(BaseRepository):
async def find_by_username(self, username: str) -> AdminUser | None:
row = await self.database.fetchone(
"""
SELECT id, username, password, user_type, nickname, phone, email,
status, login_ip, login_time, remark
FROM admin_user
WHERE username = ?
LIMIT 1
""",
(username,),
)
return AdminUser.from_row(row) if row else None
async def find_by_id(self, admin_id: int) -> AdminUser | None:
row = await self.database.fetchone(
"""
SELECT id, username, password, user_type, nickname, phone, email,
status, login_ip, login_time, remark
FROM admin_user
WHERE id = ?
LIMIT 1
""",
(admin_id,),
)
return AdminUser.from_row(row) if row else None
async def record_login(self, admin_id: int, login_ip: str) -> None:
now = self._now()
await self.database.execute(
"""
UPDATE admin_user
SET login_ip = ?, login_time = ?, updated_at = ?
WHERE id = ?
""",
(login_ip, now, now, admin_id),
)
async def ensure_seed_admin(self, username: str, password: str) -> None:
exists = await self.find_by_username(username)
if exists is not None:
return
now = self._now()
await self.database.execute(
"""
INSERT INTO admin_user (
username, password, user_type, nickname, phone, email, status,
login_ip, login_time, created_at, updated_at, remark
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
username,
hash_password(password),
"SuperAdmin",
"Super Admin",
"",
"",
int(AdminUserStatusCode.NORMAL),
"",
"",
now,
now,
"seeded by application startup",
),
)
@staticmethod
def _now() -> str:
return datetime.now(UTC).strftime("%Y-%m-%d %H:%M:%S")

View File

@@ -0,0 +1,6 @@
from app.core.database import Database
class BaseRepository:
def __init__(self, database: Database) -> None:
self.database = database

View File

@@ -0,0 +1 @@
"""Security helpers."""

View File

@@ -0,0 +1,40 @@
import base64
import hashlib
import hmac
import secrets
ALGORITHM = "pbkdf2_sha256"
ITERATIONS = 390_000
def _b64encode(raw: bytes) -> str:
return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=")
def _b64decode(value: str) -> bytes:
padding = "=" * (-len(value) % 4)
return base64.urlsafe_b64decode(value + padding)
def hash_password(password: str, iterations: int = ITERATIONS) -> str:
salt = secrets.token_bytes(16)
digest = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, iterations)
return f"{ALGORITHM}${iterations}${_b64encode(salt)}${_b64encode(digest)}"
def verify_password(password: str, stored_hash: str) -> bool:
try:
algorithm, iterations, salt, expected = stored_hash.split("$", 3)
except ValueError:
return hmac.compare_digest(password, stored_hash)
if algorithm != ALGORITHM:
return False
digest = hashlib.pbkdf2_hmac(
"sha256",
password.encode("utf-8"),
_b64decode(salt),
int(iterations),
)
return hmac.compare_digest(_b64encode(digest), expected)

View File

@@ -0,0 +1 @@
"""Application constants."""

View File

@@ -0,0 +1,6 @@
from enum import IntEnum
class AdminCode(IntEnum):
DISABLED = 30001
FORBIDDEN = 30002

View File

@@ -0,0 +1 @@
"""Model constants."""

View File

@@ -0,0 +1 @@
"""Admin user constants."""

View File

@@ -0,0 +1,12 @@
from enum import IntEnum
class AdminUserStatusCode(IntEnum):
NORMAL = 1
DISABLE = 2
def is_normal(self) -> bool:
return self is AdminUserStatusCode.NORMAL
def is_disable(self) -> bool:
return self is AdminUserStatusCode.DISABLE

View File

@@ -0,0 +1,12 @@
from enum import IntEnum
class ResultCode(IntEnum):
SUCCESS = 0
ERROR = 1
JWT_EXPIRED = 10001
JWT_ERROR = 10002
OLD_PASSWORD_ERROR = 10003
ACCOUNT_DEACTIVATING = 20001
ACCOUNT_DEACTIVATED = 20002
ACCOUNT_CANNOT_DEACTIVATE = 20003

View File

@@ -0,0 +1 @@
"""Controller layer."""

View File

@@ -0,0 +1 @@
"""Admin controllers."""

View File

@@ -0,0 +1,32 @@
from fastapi import APIRouter, Depends, Request
from app.core.dependencies import get_login_service, get_refresh_service
from app.lib.jwt.token import JwtToken
from app.middleware.admin.refresh_admin_token_middleware import RefreshAdminTokenMiddleware
from app.request.admin.login_request import LoginRequest
from app.service.admin.login.login_service import LoginService
from app.service.admin.login.refresh_service import RefreshService
router = APIRouter(prefix="/admin/login", tags=["admin-login"])
@router.post("/login")
async def login(
payload: LoginRequest,
request: Request,
# FastAPI 的 Depends 类似 Hyperf 里的容器注入。
# 请求进入这个接口时,框架会先调用 get_login_service()
# 把组装好的 LoginService 传进 service 参数。
service: LoginService = Depends(get_login_service),
) -> dict:
return await service.handle(payload, request)
@router.post("/refresh")
async def refresh(
# 这里把 RefreshAdminTokenMiddleware 当成依赖使用。
# FastAPI 会先执行 refresh token 校验,通过后才进入 controller。
token: JwtToken = Depends(RefreshAdminTokenMiddleware()),
service: RefreshService = Depends(get_refresh_service),
) -> dict:
return await service.handle(token)

View File

@@ -0,0 +1,21 @@
from fastapi import APIRouter, Depends
from app.core.dependencies import get_current_user_service
from app.middleware.admin.admin_token_middleware import AdminTokenMiddleware
from app.middleware.admin.permission_middleware import PermissionMiddleware
from app.service.admin.profile.current_user_service import CurrentUserService
router = APIRouter(prefix="/admin/profile", tags=["admin-profile"])
@router.get(
"/current",
dependencies=[
Depends(AdminTokenMiddleware()),
Depends(PermissionMiddleware()),
],
)
async def current(
service: CurrentUserService = Depends(get_current_user_service),
) -> dict:
return await service.handle()

View File

@@ -0,0 +1 @@
"""Frontend API controllers."""

View File

@@ -0,0 +1,10 @@
from fastapi import APIRouter
from app.lib.response.admin_return import AdminReturn
router = APIRouter(prefix="/api", tags=["api"])
@router.get("/health")
async def health() -> dict:
return AdminReturn().success("success", {"status": "ok"})

1
app/core/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Core application wiring."""

39
app/core/config.py Normal file
View File

@@ -0,0 +1,39 @@
from functools import lru_cache
from pathlib import Path
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
populate_by_name=True,
)
app_name: str = Field(default="py_server", alias="APP_NAME")
app_env: str = Field(default="local", alias="APP_ENV")
database_path: Path = Field(default=Path("storage/app.db"), alias="DATABASE_PATH")
admin_jwt_secret: str = Field(
default="dev_admin_secret_change_me",
alias="JWT_ADMIN_SECRET",
)
admin_jwt_ttl: int = Field(default=3600, alias="ADMIN_JWT_TTL")
admin_jwt_refresh_ttl: int = Field(default=7200, alias="ADMIN_JWT_REFRESH_TTL")
jwt_blacklist_ttl: int = Field(default=7201, alias="JWT_BLACKLIST_TTL")
api_jwt_secret: str = Field(default="dev_api_secret_change_me", alias="JWT_SECRET")
api_jwt_ttl: int = Field(default=3600, alias="JWT_TTL")
api_jwt_refresh_ttl: int = Field(default=7200, alias="JWT_REFRESH_TTL")
admin_seed_username: str = Field(default="admin", alias="ADMIN_SEED_USERNAME")
admin_seed_password: str = Field(default="admin", alias="ADMIN_SEED_PASSWORD")
cors_allow_origins: list[str] = Field(default=["*"], alias="CORS_ALLOW_ORIGINS")
@lru_cache
def get_settings() -> Settings:
return Settings()

87
app/core/database.py Normal file
View File

@@ -0,0 +1,87 @@
import asyncio
import sqlite3
from pathlib import Path
from typing import Any
class Database:
def __init__(self, path: Path) -> None:
self.path = path
async def initialize(self) -> None:
await asyncio.to_thread(self._initialize)
async def execute(self, sql: str, params: tuple[Any, ...] = ()) -> int:
return await asyncio.to_thread(self._execute, sql, params)
async def fetchone(
self,
sql: str,
params: tuple[Any, ...] = (),
) -> dict[str, Any] | None:
return await asyncio.to_thread(self._fetchone, sql, params)
async def fetchall(
self,
sql: str,
params: tuple[Any, ...] = (),
) -> list[dict[str, Any]]:
return await asyncio.to_thread(self._fetchall, sql, params)
def _connect(self) -> sqlite3.Connection:
self.path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(self.path)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA foreign_keys = ON")
return conn
def _initialize(self) -> None:
conn = self._connect()
try:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS admin_user (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
password TEXT NOT NULL,
user_type TEXT NOT NULL DEFAULT 'admin',
nickname TEXT NOT NULL DEFAULT '',
phone TEXT NOT NULL DEFAULT '',
email TEXT NOT NULL DEFAULT '',
status INTEGER NOT NULL DEFAULT 1,
login_ip TEXT NOT NULL DEFAULT '',
login_time TEXT NOT NULL DEFAULT '',
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
remark TEXT NOT NULL DEFAULT ''
)
"""
)
conn.commit()
finally:
conn.close()
def _execute(self, sql: str, params: tuple[Any, ...]) -> int:
conn = self._connect()
try:
cursor = conn.execute(sql, params)
conn.commit()
return int(cursor.lastrowid or 0)
finally:
conn.close()
def _fetchone(self, sql: str, params: tuple[Any, ...]) -> dict[str, Any] | None:
conn = self._connect()
try:
row = conn.execute(sql, params).fetchone()
return dict(row) if row else None
finally:
conn.close()
def _fetchall(self, sql: str, params: tuple[Any, ...]) -> list[dict[str, Any]]:
conn = self._connect()
try:
rows = conn.execute(sql, params).fetchall()
return [dict(row) for row in rows]
finally:
conn.close()

70
app/core/dependencies.py Normal file
View File

@@ -0,0 +1,70 @@
from functools import lru_cache
from app.common.repository.admin_user_repository import AdminUserRepository
from app.core.config import Settings, get_settings
from app.core.database import Database
from app.lib.jwt.blacklist import InMemoryTokenBlacklist
from app.lib.jwt.factory import JwtFactory
from app.lib.response.admin_return import AdminReturn
from app.service.admin.login.login_service import LoginService
from app.service.admin.login.refresh_service import RefreshService
from app.service.admin.profile.current_user_service import CurrentUserService
from app.service.base_token_service import BaseTokenService
# lru_cache 会缓存函数第一次创建出来的对象。
# 这里用它把 Database、JwtFactory、TokenService 等依赖做成应用级单例,
# 类似 Hyperf 从容器里反复 get 同一个共享服务。
@lru_cache
def get_database() -> Database:
return Database(get_settings().database_path)
@lru_cache
def get_token_blacklist() -> InMemoryTokenBlacklist:
return InMemoryTokenBlacklist()
@lru_cache
def get_jwt_factory() -> JwtFactory:
return JwtFactory(get_settings(), get_token_blacklist())
@lru_cache
def get_token_service() -> BaseTokenService:
return BaseTokenService(get_jwt_factory())
@lru_cache
def get_admin_return() -> AdminReturn:
return AdminReturn()
@lru_cache
def get_admin_user_repository() -> AdminUserRepository:
return AdminUserRepository(get_database())
def get_login_service() -> LoginService:
return LoginService(
get_admin_user_repository(),
get_token_service(),
get_admin_return(),
)
def get_refresh_service() -> RefreshService:
return RefreshService(get_token_service(), get_admin_return())
def get_current_user_service() -> CurrentUserService:
return CurrentUserService(get_admin_user_repository(), get_admin_return())
async def bootstrap_database(settings: Settings | None = None) -> None:
settings = settings or get_settings()
await get_database().initialize()
await get_admin_user_repository().ensure_seed_admin(
settings.admin_seed_username,
settings.admin_seed_password,
)

View File

@@ -0,0 +1 @@
"""Exception layer."""

View File

@@ -0,0 +1,14 @@
from app.constants.result_code import ResultCode
class ErrException(Exception):
def __init__(
self,
message: str = "failed",
code: int | ResultCode = ResultCode.ERROR,
data: dict | None = None,
) -> None:
super().__init__(message)
self.message = message
self.code = int(code)
self.data = data or {}

34
app/exception/handler.py Normal file
View File

@@ -0,0 +1,34 @@
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.constants.result_code import ResultCode
from app.exception.err_exception import ErrException
from app.lib.response.admin_return import AdminReturn
async def err_exception_handler(request: Request, exc: ErrException) -> JSONResponse:
return JSONResponse(
status_code=200,
content=AdminReturn().error(exc.message, exc.code, exc.data),
)
async def validation_exception_handler(
request: Request,
exc: RequestValidationError,
) -> JSONResponse:
first_error = exc.errors()[0] if exc.errors() else {}
location = ".".join(str(item) for item in first_error.get("loc", []))
message = first_error.get("msg", "参数错误")
if location:
message = f"{location}: {message}"
return JSONResponse(
status_code=200,
content=AdminReturn().error(message, ResultCode.ERROR),
)
def register_exception_handlers(app: FastAPI) -> None:
app.add_exception_handler(ErrException, err_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)

1
app/lib/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Library layer."""

1
app/lib/jwt/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""JWT helpers."""

25
app/lib/jwt/blacklist.py Normal file
View File

@@ -0,0 +1,25 @@
import asyncio
import time
class InMemoryTokenBlacklist:
def __init__(self) -> None:
self._items: dict[str, float] = {}
self._lock = asyncio.Lock()
async def add(self, token: str, ttl: int) -> None:
now = time.time()
async with self._lock:
self._cleanup(now)
self._items[token] = now + ttl
async def has(self, token: str) -> bool:
now = time.time()
async with self._lock:
self._cleanup(now)
return token in self._items
def _cleanup(self, now: float) -> None:
expired = [token for token, expires_at in self._items.items() if expires_at <= now]
for token in expired:
self._items.pop(token, None)

10
app/lib/jwt/exceptions.py Normal file
View File

@@ -0,0 +1,10 @@
class JwtError(Exception):
pass
class JwtExpiredError(JwtError):
pass
class JwtInvalidError(JwtError):
pass

34
app/lib/jwt/factory.py Normal file
View File

@@ -0,0 +1,34 @@
from app.core.config import Settings
from app.lib.jwt.blacklist import InMemoryTokenBlacklist
from app.lib.jwt.jwt import Jwt, JwtSceneConfig
class JwtFactory:
def __init__(self, settings: Settings, blacklist: InMemoryTokenBlacklist) -> None:
self.settings = settings
self.blacklist = blacklist
def get(self, scene: str = "default") -> Jwt:
if scene == "admin":
return Jwt(
JwtSceneConfig(
secret=self.settings.admin_jwt_secret,
issuer=f"{self.settings.app_name}_admin",
ttl=self.settings.admin_jwt_ttl,
refresh_ttl=self.settings.admin_jwt_refresh_ttl,
blacklist_ttl=self.settings.jwt_blacklist_ttl,
),
self.blacklist,
)
scene_name = "default" if scene == "default" else scene
return Jwt(
JwtSceneConfig(
secret=self.settings.api_jwt_secret,
issuer=f"{self.settings.app_name}_{scene_name}",
ttl=self.settings.api_jwt_ttl,
refresh_ttl=self.settings.api_jwt_refresh_ttl,
blacklist_ttl=self.settings.jwt_blacklist_ttl,
),
self.blacklist,
)

151
app/lib/jwt/jwt.py Normal file
View File

@@ -0,0 +1,151 @@
import base64
import hashlib
import hmac
import json
import secrets
import time
from dataclasses import dataclass
from typing import Any
from app.lib.jwt.blacklist import InMemoryTokenBlacklist
from app.lib.jwt.exceptions import JwtExpiredError, JwtInvalidError
from app.lib.jwt.token import JwtToken
@dataclass(frozen=True, slots=True)
class JwtSceneConfig:
secret: str
issuer: str
ttl: int
refresh_ttl: int
blacklist_ttl: int
class Jwt:
def __init__(self, config: JwtSceneConfig, blacklist: InMemoryTokenBlacklist) -> None:
self.config = config
self.blacklist = blacklist
@property
def issuer(self) -> str:
return self.config.issuer
def builder_access_token(self, admin_id: str) -> str:
return self._build_token(admin_id, token_type="access", ttl=self.config.ttl)
def builder_refresh_token(self, admin_id: str) -> str:
return self._build_token(admin_id, token_type="refresh", ttl=self.config.refresh_ttl)
async def parser_access_token(self, raw_token: str) -> JwtToken:
token = self._parse(raw_token)
if token.is_refresh:
raise JwtInvalidError("Token is a refresh token")
self._validate_issuer(token)
return token
async def parser_refresh_token(self, raw_token: str) -> JwtToken:
token = self._parse(raw_token)
if not token.is_refresh:
raise JwtInvalidError("Token is not a refresh token")
self._validate_issuer(token)
return token
async def add_blacklist(self, token: JwtToken) -> None:
await self.blacklist.add(token.raw, self.config.blacklist_ttl)
async def has_blacklist(self, token: JwtToken) -> bool:
return await self.blacklist.has(token.raw)
def get_config(self, key: str, default: Any = None) -> Any:
return getattr(self.config, key, default)
def _build_token(self, admin_id: str, token_type: str, ttl: int) -> str:
now = int(time.time())
payload: dict[str, Any] = {
"iss": self.config.issuer,
"jti": str(admin_id),
"iat": now,
"nbf": now,
"exp": now + ttl,
"token_type": token_type,
"nonce": secrets.token_urlsafe(16),
}
if token_type == "refresh":
payload["sub"] = "refresh"
header = {"typ": "JWT", "alg": "HS256"}
signing_input = ".".join(
[
self._b64url_json(header),
self._b64url_json(payload),
]
)
signature = hmac.new(
self.config.secret.encode("utf-8"),
signing_input.encode("ascii"),
hashlib.sha256,
).digest()
return f"{signing_input}.{self._b64url_encode(signature)}"
def _parse(self, raw_token: str) -> JwtToken:
try:
header_b64, payload_b64, signature_b64 = raw_token.split(".", 2)
except ValueError as exc:
raise JwtInvalidError("Malformed token") from exc
signing_input = f"{header_b64}.{payload_b64}".encode("ascii")
expected = hmac.new(
self.config.secret.encode("utf-8"),
signing_input,
hashlib.sha256,
).digest()
if not hmac.compare_digest(self._b64url_encode(expected), signature_b64):
raise JwtInvalidError("Invalid signature")
try:
header = json.loads(self._b64url_decode(header_b64))
payload = json.loads(self._b64url_decode(payload_b64))
except (json.JSONDecodeError, UnicodeDecodeError) as exc:
raise JwtInvalidError("Invalid token payload") from exc
if header.get("alg") != "HS256":
raise JwtInvalidError("Unsupported algorithm")
if not isinstance(payload, dict):
raise JwtInvalidError("Invalid claims")
token = JwtToken(raw=raw_token, claims=payload)
self._validate_time(token)
return token
def _validate_time(self, token: JwtToken) -> None:
now = int(time.time())
exp = token.claims.get("exp")
nbf = token.claims.get("nbf")
iat = token.claims.get("iat")
if not isinstance(exp, int):
raise JwtInvalidError("Missing exp")
if exp <= now:
raise JwtExpiredError("Token expired")
if isinstance(nbf, int) and nbf > now:
raise JwtInvalidError("Token not active")
if isinstance(iat, int) and iat > now + 60:
raise JwtInvalidError("Token issued in future")
def _validate_issuer(self, token: JwtToken) -> None:
if token.claims.get("iss") != self.config.issuer:
raise JwtInvalidError("Invalid issuer")
@staticmethod
def _b64url_encode(raw: bytes) -> str:
return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=")
@classmethod
def _b64url_json(cls, value: dict[str, Any]) -> str:
raw = json.dumps(value, separators=(",", ":"), sort_keys=True).encode("utf-8")
return cls._b64url_encode(raw)
@staticmethod
def _b64url_decode(value: str) -> str:
padding = "=" * (-len(value) % 4)
return base64.urlsafe_b64decode(value + padding).decode("utf-8")

23
app/lib/jwt/token.py Normal file
View File

@@ -0,0 +1,23 @@
from dataclasses import dataclass
from typing import Any
@dataclass(frozen=True, slots=True)
class JwtToken:
raw: str
claims: dict[str, Any]
@property
def jwt_id(self) -> str:
return str(self.claims.get("jti", ""))
@property
def admin_id(self) -> int:
try:
return int(self.jwt_id)
except ValueError:
return 0
@property
def is_refresh(self) -> bool:
return self.claims.get("sub") == "refresh" or self.claims.get("token_type") == "refresh"

View File

@@ -0,0 +1 @@
"""Response helpers."""

View File

@@ -0,0 +1,6 @@
from app.lib.response.common_return import CommonReturn
class AdminReturn(CommonReturn):
def after_success(self, response: dict) -> dict:
return response

View File

@@ -0,0 +1,34 @@
from app.constants.result_code import ResultCode
class CommonReturn:
def success(
self,
message: str = "success",
data: dict | list | None = None,
code: int | ResultCode = ResultCode.SUCCESS,
) -> dict:
return self.after_success(
{
"code": int(code),
"message": message,
"data": data if data is not None else {},
}
)
def error(
self,
message: str = "failed",
code: int | ResultCode = ResultCode.ERROR,
data: dict | None = None,
) -> dict:
return self.after_success(
{
"code": int(code),
"message": message,
"data": data or {},
}
)
def after_success(self, response: dict) -> dict:
raise NotImplementedError

44
app/main.py Normal file
View File

@@ -0,0 +1,44 @@
from contextlib import asynccontextmanager
from typing import AsyncIterator
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.controller.admin.login_controller import router as admin_login_router
from app.controller.admin.profile_controller import router as admin_profile_router
from app.controller.api.health_controller import router as api_health_router
from app.core.config import get_settings
from app.core.dependencies import bootstrap_database
from app.exception.handler import register_exception_handlers
from app.lib.response.admin_return import AdminReturn
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
await bootstrap_database()
yield
def create_app() -> FastAPI:
settings = get_settings()
application = FastAPI(title=settings.app_name, lifespan=lifespan)
application.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_allow_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
register_exception_handlers(application)
application.include_router(admin_login_router)
application.include_router(admin_profile_router)
application.include_router(api_health_router)
@application.get("/")
async def index() -> dict:
return AdminReturn().success("success", {"name": settings.app_name})
return application
app = create_app()

View File

@@ -0,0 +1 @@
"""Middleware and route dependencies."""

View File

@@ -0,0 +1 @@
"""Admin middleware."""

View File

@@ -0,0 +1,16 @@
from fastapi import Request
from app.common.context import current_admin_id
from app.lib.jwt.jwt import Jwt
from app.lib.jwt.token import JwtToken
from app.middleware.token.abstract_token_middleware import AbstractTokenMiddleware
class AdminTokenMiddleware(AbstractTokenMiddleware):
async def parser_token(self, jwt: Jwt, raw_token: str) -> JwtToken:
return await jwt.parser_access_token(raw_token)
def set_context(self, request: Request, token: JwtToken) -> None:
admin_id = token.admin_id
current_admin_id.set(admin_id)
request.state.current_admin_id = admin_id

View File

@@ -0,0 +1,34 @@
from collections.abc import Iterable
from fastapi import Depends, Request
from app.common.context import current_admin_id
from app.common.repository.admin_user_repository import AdminUserRepository
from app.constants.admin_code import AdminCode
from app.constants.model.admin_user.admin_user_status_code import AdminUserStatusCode
from app.core.dependencies import get_admin_user_repository
from app.exception.err_exception import ErrException
class PermissionMiddleware:
def __init__(self, permissions: Iterable[str] | None = None) -> None:
self.permissions = tuple(permissions or ())
async def __call__(
self,
request: Request,
user_repository: AdminUserRepository = Depends(get_admin_user_repository),
) -> None:
admin_id = getattr(request.state, "current_admin_id", current_admin_id.get())
if admin_id <= 0:
raise ErrException("账户不存在")
admin_user = await user_repository.find_by_id(admin_id)
if admin_user is None:
raise ErrException("账户不存在")
if admin_user.status == AdminUserStatusCode.DISABLE:
raise ErrException("账号已禁用", AdminCode.DISABLED)
request.state.current_admin_user = admin_user
if self.permissions and admin_user.user_type != "SuperAdmin":
raise ErrException("暂无权限", AdminCode.FORBIDDEN)

View File

@@ -0,0 +1,13 @@
from fastapi import Request
from app.lib.jwt.jwt import Jwt
from app.lib.jwt.token import JwtToken
from app.middleware.token.abstract_token_middleware import AbstractTokenMiddleware
class RefreshAdminTokenMiddleware(AbstractTokenMiddleware):
async def parser_token(self, jwt: Jwt, raw_token: str) -> JwtToken:
return await jwt.parser_refresh_token(raw_token)
def set_context(self, request: Request, token: JwtToken) -> None:
return None

View File

@@ -0,0 +1 @@
"""Token middleware."""

View File

@@ -0,0 +1,72 @@
from fastapi import Depends, Request
from app.constants.result_code import ResultCode
from app.core.dependencies import get_token_service
from app.exception.err_exception import ErrException
from app.lib.jwt.exceptions import JwtError, JwtExpiredError
from app.lib.jwt.jwt import Jwt
from app.lib.jwt.token import JwtToken
from app.service.base_token_service import BaseTokenService
class AbstractTokenMiddleware:
scene = "admin"
async def __call__(
self,
request: Request,
token_service: BaseTokenService = Depends(get_token_service),
) -> JwtToken:
raw_token = self.get_token(request)
jwt = token_service.get_jwt(self.scene)
try:
token = await self.parser_token(jwt, raw_token)
await token_service.check_jwt(jwt, token)
self.check_issuer(jwt, token)
except JwtExpiredError as exc:
raise ErrException(
"token过期",
ResultCode.JWT_EXPIRED,
{"err_msg": str(exc)},
) from exc
except JwtError as exc:
raise ErrException(
"token错误",
ResultCode.JWT_ERROR,
{"err_msg": str(exc)},
) from exc
self.set_context(request, token)
request.state.token = token
return token
async def parser_token(self, jwt: Jwt, raw_token: str) -> JwtToken:
raise NotImplementedError
def set_context(self, request: Request, token: JwtToken) -> None:
raise NotImplementedError
@staticmethod
def check_issuer(jwt: Jwt, token: JwtToken) -> None:
if token.claims.get("iss") != jwt.issuer:
raise ErrException("token错误", ResultCode.JWT_ERROR)
@staticmethod
def get_token(request: Request) -> str:
authorization = request.headers.get("authorization")
if authorization:
scheme, _, token = authorization.partition(" ")
if scheme.lower() == "bearer" and token:
return token.strip()
return authorization.replace("Bearer ", "", 1).strip()
token_header = request.headers.get("token")
if token_header:
return token_header.strip()
token_query = request.query_params.get("token")
if token_query:
return token_query.strip()
raise ErrException("token缺失", ResultCode.JWT_ERROR)

1
app/model/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Model layer."""

53
app/model/admin_user.py Normal file
View File

@@ -0,0 +1,53 @@
from dataclasses import dataclass
from typing import Any
from app.common.security.password_hasher import verify_password
from app.constants.model.admin_user.admin_user_status_code import AdminUserStatusCode
@dataclass(slots=True)
class AdminUser:
id: int
username: str
password: str
user_type: str
nickname: str
phone: str
email: str
status: AdminUserStatusCode
login_ip: str
login_time: str
remark: str
@classmethod
def from_row(cls, row: dict[str, Any]) -> "AdminUser":
return cls(
id=int(row["id"]),
username=str(row["username"]),
password=str(row["password"]),
user_type=str(row["user_type"]),
nickname=str(row["nickname"]),
phone=str(row["phone"]),
email=str(row["email"]),
status=AdminUserStatusCode(int(row["status"])),
login_ip=str(row["login_ip"]),
login_time=str(row["login_time"]),
remark=str(row["remark"]),
)
def verify_password(self, password: str) -> bool:
return verify_password(password, self.password)
def to_public_dict(self) -> dict:
return {
"id": self.id,
"username": self.username,
"user_type": self.user_type,
"nickname": self.nickname,
"phone": self.phone,
"email": self.email,
"status": int(self.status),
"login_ip": self.login_ip,
"login_time": self.login_time,
"remark": self.remark,
}

1
app/request/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Request schemas."""

View File

@@ -0,0 +1 @@
"""Admin request schemas."""

View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel, Field
class LoginRequest(BaseModel):
username: str = Field(min_length=1)
password: str = Field(min_length=1)

1
app/service/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Service layer."""

View File

@@ -0,0 +1 @@
"""Admin services."""

View File

@@ -0,0 +1,15 @@
from app.common.context import current_admin_id
from app.exception.err_exception import ErrException
from app.lib.response.admin_return import AdminReturn
class BaseAdminService:
def __init__(self, admin_return: AdminReturn) -> None:
self.admin_return = admin_return
@property
def admin_id(self) -> int:
admin_id = current_admin_id.get()
if admin_id <= 0:
raise ErrException("账户不存在")
return admin_id

View File

@@ -0,0 +1 @@
"""Admin login services."""

View File

@@ -0,0 +1,56 @@
import asyncio
from fastapi import Request
from app.common.repository.admin_user_repository import AdminUserRepository
from app.constants.model.admin_user.admin_user_status_code import AdminUserStatusCode
from app.exception.err_exception import ErrException
from app.lib.response.admin_return import AdminReturn
from app.request.admin.login_request import LoginRequest
from app.service.admin.base_admin_service import BaseAdminService
from app.service.base_token_service import BaseTokenService
class LoginService(BaseAdminService):
def __init__(
self,
user_repository: AdminUserRepository,
token_service: BaseTokenService,
admin_return: AdminReturn,
) -> None:
super().__init__(admin_return)
self.user_repository = user_repository
self.token_service = token_service
async def handle(self, payload: LoginRequest, request: Request) -> dict:
admin_info = await self.user_repository.find_by_username(payload.username)
if admin_info is None:
raise ErrException("后台管理员不存在")
password_valid = await asyncio.to_thread(
admin_info.verify_password,
payload.password,
)
if not password_valid:
raise ErrException("密码错误")
if admin_info.status == AdminUserStatusCode.DISABLE:
raise ErrException("用户已禁用")
await self.user_repository.record_login(admin_info.id, self._client_ip(request))
jwt = self.token_service.get_jwt("admin")
return self.admin_return.success(
"success",
{
"access_token": jwt.builder_access_token(str(admin_info.id)),
"refresh_token": jwt.builder_refresh_token(str(admin_info.id)),
"expire_at": int(jwt.get_config("ttl", 0)),
},
)
@staticmethod
def _client_ip(request: Request) -> str:
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return forwarded.split(",", 1)[0].strip()
return request.client.host if request.client else ""

View File

@@ -0,0 +1,22 @@
from app.lib.jwt.token import JwtToken
from app.lib.response.admin_return import AdminReturn
from app.service.admin.base_admin_service import BaseAdminService
from app.service.base_token_service import BaseTokenService
class RefreshService(BaseAdminService):
def __init__(self, token_service: BaseTokenService, admin_return: AdminReturn) -> None:
super().__init__(admin_return)
self.token_service = token_service
async def handle(self, token: JwtToken) -> dict:
return self.admin_return.success("success", await self.refresh_token(token))
async def refresh_token(self, token: JwtToken) -> dict[str, int | str]:
jwt = self.token_service.get_jwt("admin")
await jwt.add_blacklist(token)
return {
"access_token": jwt.builder_access_token(token.jwt_id),
"refresh_token": jwt.builder_refresh_token(token.jwt_id),
"expire_at": int(jwt.get_config("ttl", 0)),
}

View File

@@ -0,0 +1 @@
"""Admin profile services."""

View File

@@ -0,0 +1,20 @@
from app.common.repository.admin_user_repository import AdminUserRepository
from app.exception.err_exception import ErrException
from app.lib.response.admin_return import AdminReturn
from app.service.admin.base_admin_service import BaseAdminService
class CurrentUserService(BaseAdminService):
def __init__(
self,
user_repository: AdminUserRepository,
admin_return: AdminReturn,
) -> None:
super().__init__(admin_return)
self.user_repository = user_repository
async def handle(self) -> dict:
admin_user = await self.user_repository.find_by_id(self.admin_id)
if admin_user is None:
raise ErrException("账户不存在")
return self.admin_return.success("success", admin_user.to_public_dict())

View File

@@ -0,0 +1,17 @@
from app.constants.result_code import ResultCode
from app.exception.err_exception import ErrException
from app.lib.jwt.factory import JwtFactory
from app.lib.jwt.jwt import Jwt
from app.lib.jwt.token import JwtToken
class BaseTokenService:
def __init__(self, jwt_factory: JwtFactory) -> None:
self.jwt_factory = jwt_factory
def get_jwt(self, scene: str = "admin") -> Jwt:
return self.jwt_factory.get(scene)
async def check_jwt(self, jwt: Jwt, token: JwtToken) -> None:
if await jwt.has_blacklist(token):
raise ErrException("token已过期", ResultCode.JWT_EXPIRED)