Initial FastAPI admin auth scaffold

This commit is contained in:
2026-06-05 17:10:30 +08:00
commit 5635da9ea5
65 changed files with 1407 additions and 0 deletions

11
.env.example Normal file
View File

@@ -0,0 +1,11 @@
APP_NAME=py_server
APP_ENV=local
DATABASE_PATH=storage/app.db
JWT_ADMIN_SECRET=change_me_admin_secret
ADMIN_JWT_TTL=3600
ADMIN_JWT_REFRESH_TTL=7200
JWT_BLACKLIST_TTL=7201
ADMIN_SEED_USERNAME=admin
ADMIN_SEED_PASSWORD=admin

7
.gitignore vendored Normal file
View File

@@ -0,0 +1,7 @@
.env
__pycache__/
*.py[cod]
.pytest_cache/
.ruff_cache/
.idea/
storage/*.db

145
README.md Normal file
View File

@@ -0,0 +1,145 @@
# py_server
这是一个 FastAPI 异步 API 项目骨架,分层方式参考了 Hyperf 项目里的 `server/app`
## 分层对应
| Hyperf | FastAPI 当前项目 | 作用 |
| --- | --- | --- |
| `Controller/Admin` | `app/controller/admin` | 接收请求,调用 service |
| `Service/Admin` | `app/service/admin` | 业务逻辑 |
| `Common/Repository` | `app/common/repository` | 数据访问 |
| `Model` | `app/model` | 数据模型 |
| `Middleware/Admin` | `app/middleware/admin` | admin token 和权限校验 |
| `Lib/Jwt` | `app/lib/jwt` | JWT 签发、解析、黑名单 |
| `Request/Admin` | `app/request/admin` | 请求参数校验 |
| `Constants` | `app/constants` | 状态码和业务常量 |
| `Lib/Return` | `app/lib/response` | 统一返回结构 |
> Python 里 `return` 是关键字,所以参考 Hyperf 的 `Lib/Return` 在这里命名为 `lib/response`。
## 启动
```bash
.venv/bin/uvicorn app.main:app --reload
```
默认启动时会自动创建 SQLite 开发库,并创建一个 admin 用户:
- username: `admin`
- password: `admin`
本地配置可以复制:
```bash
cp .env.example .env
```
## Admin 登录接口
### 登录
```http
POST /admin/login/login
Content-Type: application/json
{
"username": "admin",
"password": "admin"
}
```
返回:
```json
{
"code": 0,
"message": "success",
"data": {
"access_token": "...",
"refresh_token": "...",
"expire_at": 3600
}
}
```
### 刷新 token
```http
POST /admin/login/refresh
Authorization: Bearer <refresh_token>
```
刷新成功后会签发新的 `access_token``refresh_token`,并把旧的 `refresh_token` 加入黑名单,防止重复刷新。
### 当前 admin 用户
```http
GET /admin/profile/current
Authorization: Bearer <access_token>
```
## JWT 逻辑
当前实现和参考 Hyperf 项目保持同样思路:
1. 登录成功后同时签发 `access_token``refresh_token`
2. 普通 admin 接口只接受 `access_token`
3. `/admin/login/refresh` 只接受 `refresh_token`
4. refresh 成功后,把旧 refresh token 加入黑名单。
5. token 会校验签名、过期时间、issuer 和 token 类型。
token 读取顺序和 Hyperf 中间件一致:
1. `Authorization: Bearer <token>`
2. `token` header
3. query string: `?token=...`
## FastAPI 依赖注入
Controller 里这段:
```python
service: LoginService = Depends(get_login_service)
```
可以理解为 FastAPI 版的 Hyperf 容器注入。
请求进来时FastAPI 会先调用 `get_login_service()`,把创建好的 `LoginService` 传给 `service` 参数。
`get_login_service()` 内部会继续组装:
- `AdminUserRepository`
- `BaseTokenService`
- `AdminReturn`
所以 controller 只负责接收请求和调用 service。
## `@lru_cache` 的作用
`app/core/dependencies.py` 里的:
```python
@lru_cache
def get_database() -> Database:
return Database(get_settings().database_path)
```
表示第一次调用时创建对象,后面再次调用时直接复用第一次创建的对象。
在这个项目里,它的作用类似 Hyperf 容器里的共享服务/单例服务。比如 JWT 黑名单必须复用同一个对象,否则旧 refresh token 加入黑名单后,下一次请求就查不到了。
## 测试
```bash
.venv/bin/python -m compileall app tests
.venv/bin/python -m unittest tests.test_admin_login_flow
```
测试覆盖:
- admin 登录成功
- access token 访问当前用户成功
- refresh token 换新 token 成功
- 旧 refresh token 复用失败
- access token 调用 refresh 接口失败

1
app/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Application package."""

1
app/common/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Common helpers."""

3
app/common/context.py Normal file
View File

@@ -0,0 +1,3 @@
from contextvars import ContextVar
current_admin_id: ContextVar[int] = ContextVar("current_admin_id", default=0)

View File

@@ -0,0 +1 @@
"""Repository layer."""

View File

@@ -0,0 +1,79 @@
from datetime import UTC, datetime
from app.common.repository.base_repository import BaseRepository
from app.common.security.password_hasher import hash_password
from app.constants.model.admin_user.admin_user_status_code import AdminUserStatusCode
from app.model.admin_user import AdminUser
class AdminUserRepository(BaseRepository):
async def find_by_username(self, username: str) -> AdminUser | None:
row = await self.database.fetchone(
"""
SELECT id, username, password, user_type, nickname, phone, email,
status, login_ip, login_time, remark
FROM admin_user
WHERE username = ?
LIMIT 1
""",
(username,),
)
return AdminUser.from_row(row) if row else None
async def find_by_id(self, admin_id: int) -> AdminUser | None:
row = await self.database.fetchone(
"""
SELECT id, username, password, user_type, nickname, phone, email,
status, login_ip, login_time, remark
FROM admin_user
WHERE id = ?
LIMIT 1
""",
(admin_id,),
)
return AdminUser.from_row(row) if row else None
async def record_login(self, admin_id: int, login_ip: str) -> None:
now = self._now()
await self.database.execute(
"""
UPDATE admin_user
SET login_ip = ?, login_time = ?, updated_at = ?
WHERE id = ?
""",
(login_ip, now, now, admin_id),
)
async def ensure_seed_admin(self, username: str, password: str) -> None:
exists = await self.find_by_username(username)
if exists is not None:
return
now = self._now()
await self.database.execute(
"""
INSERT INTO admin_user (
username, password, user_type, nickname, phone, email, status,
login_ip, login_time, created_at, updated_at, remark
)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
username,
hash_password(password),
"SuperAdmin",
"Super Admin",
"",
"",
int(AdminUserStatusCode.NORMAL),
"",
"",
now,
now,
"seeded by application startup",
),
)
@staticmethod
def _now() -> str:
return datetime.now(UTC).strftime("%Y-%m-%d %H:%M:%S")

View File

@@ -0,0 +1,6 @@
from app.core.database import Database
class BaseRepository:
def __init__(self, database: Database) -> None:
self.database = database

View File

@@ -0,0 +1 @@
"""Security helpers."""

View File

@@ -0,0 +1,40 @@
import base64
import hashlib
import hmac
import secrets
ALGORITHM = "pbkdf2_sha256"
ITERATIONS = 390_000
def _b64encode(raw: bytes) -> str:
return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=")
def _b64decode(value: str) -> bytes:
padding = "=" * (-len(value) % 4)
return base64.urlsafe_b64decode(value + padding)
def hash_password(password: str, iterations: int = ITERATIONS) -> str:
salt = secrets.token_bytes(16)
digest = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, iterations)
return f"{ALGORITHM}${iterations}${_b64encode(salt)}${_b64encode(digest)}"
def verify_password(password: str, stored_hash: str) -> bool:
try:
algorithm, iterations, salt, expected = stored_hash.split("$", 3)
except ValueError:
return hmac.compare_digest(password, stored_hash)
if algorithm != ALGORITHM:
return False
digest = hashlib.pbkdf2_hmac(
"sha256",
password.encode("utf-8"),
_b64decode(salt),
int(iterations),
)
return hmac.compare_digest(_b64encode(digest), expected)

View File

@@ -0,0 +1 @@
"""Application constants."""

View File

@@ -0,0 +1,6 @@
from enum import IntEnum
class AdminCode(IntEnum):
DISABLED = 30001
FORBIDDEN = 30002

View File

@@ -0,0 +1 @@
"""Model constants."""

View File

@@ -0,0 +1 @@
"""Admin user constants."""

View File

@@ -0,0 +1,12 @@
from enum import IntEnum
class AdminUserStatusCode(IntEnum):
NORMAL = 1
DISABLE = 2
def is_normal(self) -> bool:
return self is AdminUserStatusCode.NORMAL
def is_disable(self) -> bool:
return self is AdminUserStatusCode.DISABLE

View File

@@ -0,0 +1,12 @@
from enum import IntEnum
class ResultCode(IntEnum):
SUCCESS = 0
ERROR = 1
JWT_EXPIRED = 10001
JWT_ERROR = 10002
OLD_PASSWORD_ERROR = 10003
ACCOUNT_DEACTIVATING = 20001
ACCOUNT_DEACTIVATED = 20002
ACCOUNT_CANNOT_DEACTIVATE = 20003

View File

@@ -0,0 +1 @@
"""Controller layer."""

View File

@@ -0,0 +1 @@
"""Admin controllers."""

View File

@@ -0,0 +1,32 @@
from fastapi import APIRouter, Depends, Request
from app.core.dependencies import get_login_service, get_refresh_service
from app.lib.jwt.token import JwtToken
from app.middleware.admin.refresh_admin_token_middleware import RefreshAdminTokenMiddleware
from app.request.admin.login_request import LoginRequest
from app.service.admin.login.login_service import LoginService
from app.service.admin.login.refresh_service import RefreshService
router = APIRouter(prefix="/admin/login", tags=["admin-login"])
@router.post("/login")
async def login(
payload: LoginRequest,
request: Request,
# FastAPI 的 Depends 类似 Hyperf 里的容器注入。
# 请求进入这个接口时,框架会先调用 get_login_service()
# 把组装好的 LoginService 传进 service 参数。
service: LoginService = Depends(get_login_service),
) -> dict:
return await service.handle(payload, request)
@router.post("/refresh")
async def refresh(
# 这里把 RefreshAdminTokenMiddleware 当成依赖使用。
# FastAPI 会先执行 refresh token 校验,通过后才进入 controller。
token: JwtToken = Depends(RefreshAdminTokenMiddleware()),
service: RefreshService = Depends(get_refresh_service),
) -> dict:
return await service.handle(token)

View File

@@ -0,0 +1,21 @@
from fastapi import APIRouter, Depends
from app.core.dependencies import get_current_user_service
from app.middleware.admin.admin_token_middleware import AdminTokenMiddleware
from app.middleware.admin.permission_middleware import PermissionMiddleware
from app.service.admin.profile.current_user_service import CurrentUserService
router = APIRouter(prefix="/admin/profile", tags=["admin-profile"])
@router.get(
"/current",
dependencies=[
Depends(AdminTokenMiddleware()),
Depends(PermissionMiddleware()),
],
)
async def current(
service: CurrentUserService = Depends(get_current_user_service),
) -> dict:
return await service.handle()

View File

@@ -0,0 +1 @@
"""Frontend API controllers."""

View File

@@ -0,0 +1,10 @@
from fastapi import APIRouter
from app.lib.response.admin_return import AdminReturn
router = APIRouter(prefix="/api", tags=["api"])
@router.get("/health")
async def health() -> dict:
return AdminReturn().success("success", {"status": "ok"})

1
app/core/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Core application wiring."""

39
app/core/config.py Normal file
View File

@@ -0,0 +1,39 @@
from functools import lru_cache
from pathlib import Path
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_file=".env",
env_file_encoding="utf-8",
extra="ignore",
populate_by_name=True,
)
app_name: str = Field(default="py_server", alias="APP_NAME")
app_env: str = Field(default="local", alias="APP_ENV")
database_path: Path = Field(default=Path("storage/app.db"), alias="DATABASE_PATH")
admin_jwt_secret: str = Field(
default="dev_admin_secret_change_me",
alias="JWT_ADMIN_SECRET",
)
admin_jwt_ttl: int = Field(default=3600, alias="ADMIN_JWT_TTL")
admin_jwt_refresh_ttl: int = Field(default=7200, alias="ADMIN_JWT_REFRESH_TTL")
jwt_blacklist_ttl: int = Field(default=7201, alias="JWT_BLACKLIST_TTL")
api_jwt_secret: str = Field(default="dev_api_secret_change_me", alias="JWT_SECRET")
api_jwt_ttl: int = Field(default=3600, alias="JWT_TTL")
api_jwt_refresh_ttl: int = Field(default=7200, alias="JWT_REFRESH_TTL")
admin_seed_username: str = Field(default="admin", alias="ADMIN_SEED_USERNAME")
admin_seed_password: str = Field(default="admin", alias="ADMIN_SEED_PASSWORD")
cors_allow_origins: list[str] = Field(default=["*"], alias="CORS_ALLOW_ORIGINS")
@lru_cache
def get_settings() -> Settings:
return Settings()

87
app/core/database.py Normal file
View File

@@ -0,0 +1,87 @@
import asyncio
import sqlite3
from pathlib import Path
from typing import Any
class Database:
def __init__(self, path: Path) -> None:
self.path = path
async def initialize(self) -> None:
await asyncio.to_thread(self._initialize)
async def execute(self, sql: str, params: tuple[Any, ...] = ()) -> int:
return await asyncio.to_thread(self._execute, sql, params)
async def fetchone(
self,
sql: str,
params: tuple[Any, ...] = (),
) -> dict[str, Any] | None:
return await asyncio.to_thread(self._fetchone, sql, params)
async def fetchall(
self,
sql: str,
params: tuple[Any, ...] = (),
) -> list[dict[str, Any]]:
return await asyncio.to_thread(self._fetchall, sql, params)
def _connect(self) -> sqlite3.Connection:
self.path.parent.mkdir(parents=True, exist_ok=True)
conn = sqlite3.connect(self.path)
conn.row_factory = sqlite3.Row
conn.execute("PRAGMA foreign_keys = ON")
return conn
def _initialize(self) -> None:
conn = self._connect()
try:
conn.execute(
"""
CREATE TABLE IF NOT EXISTS admin_user (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username TEXT NOT NULL UNIQUE,
password TEXT NOT NULL,
user_type TEXT NOT NULL DEFAULT 'admin',
nickname TEXT NOT NULL DEFAULT '',
phone TEXT NOT NULL DEFAULT '',
email TEXT NOT NULL DEFAULT '',
status INTEGER NOT NULL DEFAULT 1,
login_ip TEXT NOT NULL DEFAULT '',
login_time TEXT NOT NULL DEFAULT '',
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL,
remark TEXT NOT NULL DEFAULT ''
)
"""
)
conn.commit()
finally:
conn.close()
def _execute(self, sql: str, params: tuple[Any, ...]) -> int:
conn = self._connect()
try:
cursor = conn.execute(sql, params)
conn.commit()
return int(cursor.lastrowid or 0)
finally:
conn.close()
def _fetchone(self, sql: str, params: tuple[Any, ...]) -> dict[str, Any] | None:
conn = self._connect()
try:
row = conn.execute(sql, params).fetchone()
return dict(row) if row else None
finally:
conn.close()
def _fetchall(self, sql: str, params: tuple[Any, ...]) -> list[dict[str, Any]]:
conn = self._connect()
try:
rows = conn.execute(sql, params).fetchall()
return [dict(row) for row in rows]
finally:
conn.close()

70
app/core/dependencies.py Normal file
View File

@@ -0,0 +1,70 @@
from functools import lru_cache
from app.common.repository.admin_user_repository import AdminUserRepository
from app.core.config import Settings, get_settings
from app.core.database import Database
from app.lib.jwt.blacklist import InMemoryTokenBlacklist
from app.lib.jwt.factory import JwtFactory
from app.lib.response.admin_return import AdminReturn
from app.service.admin.login.login_service import LoginService
from app.service.admin.login.refresh_service import RefreshService
from app.service.admin.profile.current_user_service import CurrentUserService
from app.service.base_token_service import BaseTokenService
# lru_cache 会缓存函数第一次创建出来的对象。
# 这里用它把 Database、JwtFactory、TokenService 等依赖做成应用级单例,
# 类似 Hyperf 从容器里反复 get 同一个共享服务。
@lru_cache
def get_database() -> Database:
return Database(get_settings().database_path)
@lru_cache
def get_token_blacklist() -> InMemoryTokenBlacklist:
return InMemoryTokenBlacklist()
@lru_cache
def get_jwt_factory() -> JwtFactory:
return JwtFactory(get_settings(), get_token_blacklist())
@lru_cache
def get_token_service() -> BaseTokenService:
return BaseTokenService(get_jwt_factory())
@lru_cache
def get_admin_return() -> AdminReturn:
return AdminReturn()
@lru_cache
def get_admin_user_repository() -> AdminUserRepository:
return AdminUserRepository(get_database())
def get_login_service() -> LoginService:
return LoginService(
get_admin_user_repository(),
get_token_service(),
get_admin_return(),
)
def get_refresh_service() -> RefreshService:
return RefreshService(get_token_service(), get_admin_return())
def get_current_user_service() -> CurrentUserService:
return CurrentUserService(get_admin_user_repository(), get_admin_return())
async def bootstrap_database(settings: Settings | None = None) -> None:
settings = settings or get_settings()
await get_database().initialize()
await get_admin_user_repository().ensure_seed_admin(
settings.admin_seed_username,
settings.admin_seed_password,
)

View File

@@ -0,0 +1 @@
"""Exception layer."""

View File

@@ -0,0 +1,14 @@
from app.constants.result_code import ResultCode
class ErrException(Exception):
def __init__(
self,
message: str = "failed",
code: int | ResultCode = ResultCode.ERROR,
data: dict | None = None,
) -> None:
super().__init__(message)
self.message = message
self.code = int(code)
self.data = data or {}

34
app/exception/handler.py Normal file
View File

@@ -0,0 +1,34 @@
from fastapi import FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from app.constants.result_code import ResultCode
from app.exception.err_exception import ErrException
from app.lib.response.admin_return import AdminReturn
async def err_exception_handler(request: Request, exc: ErrException) -> JSONResponse:
return JSONResponse(
status_code=200,
content=AdminReturn().error(exc.message, exc.code, exc.data),
)
async def validation_exception_handler(
request: Request,
exc: RequestValidationError,
) -> JSONResponse:
first_error = exc.errors()[0] if exc.errors() else {}
location = ".".join(str(item) for item in first_error.get("loc", []))
message = first_error.get("msg", "参数错误")
if location:
message = f"{location}: {message}"
return JSONResponse(
status_code=200,
content=AdminReturn().error(message, ResultCode.ERROR),
)
def register_exception_handlers(app: FastAPI) -> None:
app.add_exception_handler(ErrException, err_exception_handler)
app.add_exception_handler(RequestValidationError, validation_exception_handler)

1
app/lib/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Library layer."""

1
app/lib/jwt/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""JWT helpers."""

25
app/lib/jwt/blacklist.py Normal file
View File

@@ -0,0 +1,25 @@
import asyncio
import time
class InMemoryTokenBlacklist:
def __init__(self) -> None:
self._items: dict[str, float] = {}
self._lock = asyncio.Lock()
async def add(self, token: str, ttl: int) -> None:
now = time.time()
async with self._lock:
self._cleanup(now)
self._items[token] = now + ttl
async def has(self, token: str) -> bool:
now = time.time()
async with self._lock:
self._cleanup(now)
return token in self._items
def _cleanup(self, now: float) -> None:
expired = [token for token, expires_at in self._items.items() if expires_at <= now]
for token in expired:
self._items.pop(token, None)

10
app/lib/jwt/exceptions.py Normal file
View File

@@ -0,0 +1,10 @@
class JwtError(Exception):
pass
class JwtExpiredError(JwtError):
pass
class JwtInvalidError(JwtError):
pass

34
app/lib/jwt/factory.py Normal file
View File

@@ -0,0 +1,34 @@
from app.core.config import Settings
from app.lib.jwt.blacklist import InMemoryTokenBlacklist
from app.lib.jwt.jwt import Jwt, JwtSceneConfig
class JwtFactory:
def __init__(self, settings: Settings, blacklist: InMemoryTokenBlacklist) -> None:
self.settings = settings
self.blacklist = blacklist
def get(self, scene: str = "default") -> Jwt:
if scene == "admin":
return Jwt(
JwtSceneConfig(
secret=self.settings.admin_jwt_secret,
issuer=f"{self.settings.app_name}_admin",
ttl=self.settings.admin_jwt_ttl,
refresh_ttl=self.settings.admin_jwt_refresh_ttl,
blacklist_ttl=self.settings.jwt_blacklist_ttl,
),
self.blacklist,
)
scene_name = "default" if scene == "default" else scene
return Jwt(
JwtSceneConfig(
secret=self.settings.api_jwt_secret,
issuer=f"{self.settings.app_name}_{scene_name}",
ttl=self.settings.api_jwt_ttl,
refresh_ttl=self.settings.api_jwt_refresh_ttl,
blacklist_ttl=self.settings.jwt_blacklist_ttl,
),
self.blacklist,
)

151
app/lib/jwt/jwt.py Normal file
View File

@@ -0,0 +1,151 @@
import base64
import hashlib
import hmac
import json
import secrets
import time
from dataclasses import dataclass
from typing import Any
from app.lib.jwt.blacklist import InMemoryTokenBlacklist
from app.lib.jwt.exceptions import JwtExpiredError, JwtInvalidError
from app.lib.jwt.token import JwtToken
@dataclass(frozen=True, slots=True)
class JwtSceneConfig:
secret: str
issuer: str
ttl: int
refresh_ttl: int
blacklist_ttl: int
class Jwt:
def __init__(self, config: JwtSceneConfig, blacklist: InMemoryTokenBlacklist) -> None:
self.config = config
self.blacklist = blacklist
@property
def issuer(self) -> str:
return self.config.issuer
def builder_access_token(self, admin_id: str) -> str:
return self._build_token(admin_id, token_type="access", ttl=self.config.ttl)
def builder_refresh_token(self, admin_id: str) -> str:
return self._build_token(admin_id, token_type="refresh", ttl=self.config.refresh_ttl)
async def parser_access_token(self, raw_token: str) -> JwtToken:
token = self._parse(raw_token)
if token.is_refresh:
raise JwtInvalidError("Token is a refresh token")
self._validate_issuer(token)
return token
async def parser_refresh_token(self, raw_token: str) -> JwtToken:
token = self._parse(raw_token)
if not token.is_refresh:
raise JwtInvalidError("Token is not a refresh token")
self._validate_issuer(token)
return token
async def add_blacklist(self, token: JwtToken) -> None:
await self.blacklist.add(token.raw, self.config.blacklist_ttl)
async def has_blacklist(self, token: JwtToken) -> bool:
return await self.blacklist.has(token.raw)
def get_config(self, key: str, default: Any = None) -> Any:
return getattr(self.config, key, default)
def _build_token(self, admin_id: str, token_type: str, ttl: int) -> str:
now = int(time.time())
payload: dict[str, Any] = {
"iss": self.config.issuer,
"jti": str(admin_id),
"iat": now,
"nbf": now,
"exp": now + ttl,
"token_type": token_type,
"nonce": secrets.token_urlsafe(16),
}
if token_type == "refresh":
payload["sub"] = "refresh"
header = {"typ": "JWT", "alg": "HS256"}
signing_input = ".".join(
[
self._b64url_json(header),
self._b64url_json(payload),
]
)
signature = hmac.new(
self.config.secret.encode("utf-8"),
signing_input.encode("ascii"),
hashlib.sha256,
).digest()
return f"{signing_input}.{self._b64url_encode(signature)}"
def _parse(self, raw_token: str) -> JwtToken:
try:
header_b64, payload_b64, signature_b64 = raw_token.split(".", 2)
except ValueError as exc:
raise JwtInvalidError("Malformed token") from exc
signing_input = f"{header_b64}.{payload_b64}".encode("ascii")
expected = hmac.new(
self.config.secret.encode("utf-8"),
signing_input,
hashlib.sha256,
).digest()
if not hmac.compare_digest(self._b64url_encode(expected), signature_b64):
raise JwtInvalidError("Invalid signature")
try:
header = json.loads(self._b64url_decode(header_b64))
payload = json.loads(self._b64url_decode(payload_b64))
except (json.JSONDecodeError, UnicodeDecodeError) as exc:
raise JwtInvalidError("Invalid token payload") from exc
if header.get("alg") != "HS256":
raise JwtInvalidError("Unsupported algorithm")
if not isinstance(payload, dict):
raise JwtInvalidError("Invalid claims")
token = JwtToken(raw=raw_token, claims=payload)
self._validate_time(token)
return token
def _validate_time(self, token: JwtToken) -> None:
now = int(time.time())
exp = token.claims.get("exp")
nbf = token.claims.get("nbf")
iat = token.claims.get("iat")
if not isinstance(exp, int):
raise JwtInvalidError("Missing exp")
if exp <= now:
raise JwtExpiredError("Token expired")
if isinstance(nbf, int) and nbf > now:
raise JwtInvalidError("Token not active")
if isinstance(iat, int) and iat > now + 60:
raise JwtInvalidError("Token issued in future")
def _validate_issuer(self, token: JwtToken) -> None:
if token.claims.get("iss") != self.config.issuer:
raise JwtInvalidError("Invalid issuer")
@staticmethod
def _b64url_encode(raw: bytes) -> str:
return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=")
@classmethod
def _b64url_json(cls, value: dict[str, Any]) -> str:
raw = json.dumps(value, separators=(",", ":"), sort_keys=True).encode("utf-8")
return cls._b64url_encode(raw)
@staticmethod
def _b64url_decode(value: str) -> str:
padding = "=" * (-len(value) % 4)
return base64.urlsafe_b64decode(value + padding).decode("utf-8")

23
app/lib/jwt/token.py Normal file
View File

@@ -0,0 +1,23 @@
from dataclasses import dataclass
from typing import Any
@dataclass(frozen=True, slots=True)
class JwtToken:
raw: str
claims: dict[str, Any]
@property
def jwt_id(self) -> str:
return str(self.claims.get("jti", ""))
@property
def admin_id(self) -> int:
try:
return int(self.jwt_id)
except ValueError:
return 0
@property
def is_refresh(self) -> bool:
return self.claims.get("sub") == "refresh" or self.claims.get("token_type") == "refresh"

View File

@@ -0,0 +1 @@
"""Response helpers."""

View File

@@ -0,0 +1,6 @@
from app.lib.response.common_return import CommonReturn
class AdminReturn(CommonReturn):
def after_success(self, response: dict) -> dict:
return response

View File

@@ -0,0 +1,34 @@
from app.constants.result_code import ResultCode
class CommonReturn:
def success(
self,
message: str = "success",
data: dict | list | None = None,
code: int | ResultCode = ResultCode.SUCCESS,
) -> dict:
return self.after_success(
{
"code": int(code),
"message": message,
"data": data if data is not None else {},
}
)
def error(
self,
message: str = "failed",
code: int | ResultCode = ResultCode.ERROR,
data: dict | None = None,
) -> dict:
return self.after_success(
{
"code": int(code),
"message": message,
"data": data or {},
}
)
def after_success(self, response: dict) -> dict:
raise NotImplementedError

44
app/main.py Normal file
View File

@@ -0,0 +1,44 @@
from contextlib import asynccontextmanager
from typing import AsyncIterator
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.controller.admin.login_controller import router as admin_login_router
from app.controller.admin.profile_controller import router as admin_profile_router
from app.controller.api.health_controller import router as api_health_router
from app.core.config import get_settings
from app.core.dependencies import bootstrap_database
from app.exception.handler import register_exception_handlers
from app.lib.response.admin_return import AdminReturn
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
await bootstrap_database()
yield
def create_app() -> FastAPI:
settings = get_settings()
application = FastAPI(title=settings.app_name, lifespan=lifespan)
application.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_allow_origins,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
register_exception_handlers(application)
application.include_router(admin_login_router)
application.include_router(admin_profile_router)
application.include_router(api_health_router)
@application.get("/")
async def index() -> dict:
return AdminReturn().success("success", {"name": settings.app_name})
return application
app = create_app()

View File

@@ -0,0 +1 @@
"""Middleware and route dependencies."""

View File

@@ -0,0 +1 @@
"""Admin middleware."""

View File

@@ -0,0 +1,16 @@
from fastapi import Request
from app.common.context import current_admin_id
from app.lib.jwt.jwt import Jwt
from app.lib.jwt.token import JwtToken
from app.middleware.token.abstract_token_middleware import AbstractTokenMiddleware
class AdminTokenMiddleware(AbstractTokenMiddleware):
async def parser_token(self, jwt: Jwt, raw_token: str) -> JwtToken:
return await jwt.parser_access_token(raw_token)
def set_context(self, request: Request, token: JwtToken) -> None:
admin_id = token.admin_id
current_admin_id.set(admin_id)
request.state.current_admin_id = admin_id

View File

@@ -0,0 +1,34 @@
from collections.abc import Iterable
from fastapi import Depends, Request
from app.common.context import current_admin_id
from app.common.repository.admin_user_repository import AdminUserRepository
from app.constants.admin_code import AdminCode
from app.constants.model.admin_user.admin_user_status_code import AdminUserStatusCode
from app.core.dependencies import get_admin_user_repository
from app.exception.err_exception import ErrException
class PermissionMiddleware:
def __init__(self, permissions: Iterable[str] | None = None) -> None:
self.permissions = tuple(permissions or ())
async def __call__(
self,
request: Request,
user_repository: AdminUserRepository = Depends(get_admin_user_repository),
) -> None:
admin_id = getattr(request.state, "current_admin_id", current_admin_id.get())
if admin_id <= 0:
raise ErrException("账户不存在")
admin_user = await user_repository.find_by_id(admin_id)
if admin_user is None:
raise ErrException("账户不存在")
if admin_user.status == AdminUserStatusCode.DISABLE:
raise ErrException("账号已禁用", AdminCode.DISABLED)
request.state.current_admin_user = admin_user
if self.permissions and admin_user.user_type != "SuperAdmin":
raise ErrException("暂无权限", AdminCode.FORBIDDEN)

View File

@@ -0,0 +1,13 @@
from fastapi import Request
from app.lib.jwt.jwt import Jwt
from app.lib.jwt.token import JwtToken
from app.middleware.token.abstract_token_middleware import AbstractTokenMiddleware
class RefreshAdminTokenMiddleware(AbstractTokenMiddleware):
async def parser_token(self, jwt: Jwt, raw_token: str) -> JwtToken:
return await jwt.parser_refresh_token(raw_token)
def set_context(self, request: Request, token: JwtToken) -> None:
return None

View File

@@ -0,0 +1 @@
"""Token middleware."""

View File

@@ -0,0 +1,72 @@
from fastapi import Depends, Request
from app.constants.result_code import ResultCode
from app.core.dependencies import get_token_service
from app.exception.err_exception import ErrException
from app.lib.jwt.exceptions import JwtError, JwtExpiredError
from app.lib.jwt.jwt import Jwt
from app.lib.jwt.token import JwtToken
from app.service.base_token_service import BaseTokenService
class AbstractTokenMiddleware:
scene = "admin"
async def __call__(
self,
request: Request,
token_service: BaseTokenService = Depends(get_token_service),
) -> JwtToken:
raw_token = self.get_token(request)
jwt = token_service.get_jwt(self.scene)
try:
token = await self.parser_token(jwt, raw_token)
await token_service.check_jwt(jwt, token)
self.check_issuer(jwt, token)
except JwtExpiredError as exc:
raise ErrException(
"token过期",
ResultCode.JWT_EXPIRED,
{"err_msg": str(exc)},
) from exc
except JwtError as exc:
raise ErrException(
"token错误",
ResultCode.JWT_ERROR,
{"err_msg": str(exc)},
) from exc
self.set_context(request, token)
request.state.token = token
return token
async def parser_token(self, jwt: Jwt, raw_token: str) -> JwtToken:
raise NotImplementedError
def set_context(self, request: Request, token: JwtToken) -> None:
raise NotImplementedError
@staticmethod
def check_issuer(jwt: Jwt, token: JwtToken) -> None:
if token.claims.get("iss") != jwt.issuer:
raise ErrException("token错误", ResultCode.JWT_ERROR)
@staticmethod
def get_token(request: Request) -> str:
authorization = request.headers.get("authorization")
if authorization:
scheme, _, token = authorization.partition(" ")
if scheme.lower() == "bearer" and token:
return token.strip()
return authorization.replace("Bearer ", "", 1).strip()
token_header = request.headers.get("token")
if token_header:
return token_header.strip()
token_query = request.query_params.get("token")
if token_query:
return token_query.strip()
raise ErrException("token缺失", ResultCode.JWT_ERROR)

1
app/model/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Model layer."""

53
app/model/admin_user.py Normal file
View File

@@ -0,0 +1,53 @@
from dataclasses import dataclass
from typing import Any
from app.common.security.password_hasher import verify_password
from app.constants.model.admin_user.admin_user_status_code import AdminUserStatusCode
@dataclass(slots=True)
class AdminUser:
id: int
username: str
password: str
user_type: str
nickname: str
phone: str
email: str
status: AdminUserStatusCode
login_ip: str
login_time: str
remark: str
@classmethod
def from_row(cls, row: dict[str, Any]) -> "AdminUser":
return cls(
id=int(row["id"]),
username=str(row["username"]),
password=str(row["password"]),
user_type=str(row["user_type"]),
nickname=str(row["nickname"]),
phone=str(row["phone"]),
email=str(row["email"]),
status=AdminUserStatusCode(int(row["status"])),
login_ip=str(row["login_ip"]),
login_time=str(row["login_time"]),
remark=str(row["remark"]),
)
def verify_password(self, password: str) -> bool:
return verify_password(password, self.password)
def to_public_dict(self) -> dict:
return {
"id": self.id,
"username": self.username,
"user_type": self.user_type,
"nickname": self.nickname,
"phone": self.phone,
"email": self.email,
"status": int(self.status),
"login_ip": self.login_ip,
"login_time": self.login_time,
"remark": self.remark,
}

1
app/request/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Request schemas."""

View File

@@ -0,0 +1 @@
"""Admin request schemas."""

View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel, Field
class LoginRequest(BaseModel):
username: str = Field(min_length=1)
password: str = Field(min_length=1)

1
app/service/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Service layer."""

View File

@@ -0,0 +1 @@
"""Admin services."""

View File

@@ -0,0 +1,15 @@
from app.common.context import current_admin_id
from app.exception.err_exception import ErrException
from app.lib.response.admin_return import AdminReturn
class BaseAdminService:
def __init__(self, admin_return: AdminReturn) -> None:
self.admin_return = admin_return
@property
def admin_id(self) -> int:
admin_id = current_admin_id.get()
if admin_id <= 0:
raise ErrException("账户不存在")
return admin_id

View File

@@ -0,0 +1 @@
"""Admin login services."""

View File

@@ -0,0 +1,56 @@
import asyncio
from fastapi import Request
from app.common.repository.admin_user_repository import AdminUserRepository
from app.constants.model.admin_user.admin_user_status_code import AdminUserStatusCode
from app.exception.err_exception import ErrException
from app.lib.response.admin_return import AdminReturn
from app.request.admin.login_request import LoginRequest
from app.service.admin.base_admin_service import BaseAdminService
from app.service.base_token_service import BaseTokenService
class LoginService(BaseAdminService):
def __init__(
self,
user_repository: AdminUserRepository,
token_service: BaseTokenService,
admin_return: AdminReturn,
) -> None:
super().__init__(admin_return)
self.user_repository = user_repository
self.token_service = token_service
async def handle(self, payload: LoginRequest, request: Request) -> dict:
admin_info = await self.user_repository.find_by_username(payload.username)
if admin_info is None:
raise ErrException("后台管理员不存在")
password_valid = await asyncio.to_thread(
admin_info.verify_password,
payload.password,
)
if not password_valid:
raise ErrException("密码错误")
if admin_info.status == AdminUserStatusCode.DISABLE:
raise ErrException("用户已禁用")
await self.user_repository.record_login(admin_info.id, self._client_ip(request))
jwt = self.token_service.get_jwt("admin")
return self.admin_return.success(
"success",
{
"access_token": jwt.builder_access_token(str(admin_info.id)),
"refresh_token": jwt.builder_refresh_token(str(admin_info.id)),
"expire_at": int(jwt.get_config("ttl", 0)),
},
)
@staticmethod
def _client_ip(request: Request) -> str:
forwarded = request.headers.get("x-forwarded-for")
if forwarded:
return forwarded.split(",", 1)[0].strip()
return request.client.host if request.client else ""

View File

@@ -0,0 +1,22 @@
from app.lib.jwt.token import JwtToken
from app.lib.response.admin_return import AdminReturn
from app.service.admin.base_admin_service import BaseAdminService
from app.service.base_token_service import BaseTokenService
class RefreshService(BaseAdminService):
def __init__(self, token_service: BaseTokenService, admin_return: AdminReturn) -> None:
super().__init__(admin_return)
self.token_service = token_service
async def handle(self, token: JwtToken) -> dict:
return self.admin_return.success("success", await self.refresh_token(token))
async def refresh_token(self, token: JwtToken) -> dict[str, int | str]:
jwt = self.token_service.get_jwt("admin")
await jwt.add_blacklist(token)
return {
"access_token": jwt.builder_access_token(token.jwt_id),
"refresh_token": jwt.builder_refresh_token(token.jwt_id),
"expire_at": int(jwt.get_config("ttl", 0)),
}

View File

@@ -0,0 +1 @@
"""Admin profile services."""

View File

@@ -0,0 +1,20 @@
from app.common.repository.admin_user_repository import AdminUserRepository
from app.exception.err_exception import ErrException
from app.lib.response.admin_return import AdminReturn
from app.service.admin.base_admin_service import BaseAdminService
class CurrentUserService(BaseAdminService):
def __init__(
self,
user_repository: AdminUserRepository,
admin_return: AdminReturn,
) -> None:
super().__init__(admin_return)
self.user_repository = user_repository
async def handle(self) -> dict:
admin_user = await self.user_repository.find_by_id(self.admin_id)
if admin_user is None:
raise ErrException("账户不存在")
return self.admin_return.success("success", admin_user.to_public_dict())

View File

@@ -0,0 +1,17 @@
from app.constants.result_code import ResultCode
from app.exception.err_exception import ErrException
from app.lib.jwt.factory import JwtFactory
from app.lib.jwt.jwt import Jwt
from app.lib.jwt.token import JwtToken
class BaseTokenService:
def __init__(self, jwt_factory: JwtFactory) -> None:
self.jwt_factory = jwt_factory
def get_jwt(self, scene: str = "admin") -> Jwt:
return self.jwt_factory.get(scene)
async def check_jwt(self, jwt: Jwt, token: JwtToken) -> None:
if await jwt.has_blacklist(token):
raise ErrException("token已过期", ResultCode.JWT_EXPIRED)

14
pyproject.toml Normal file
View File

@@ -0,0 +1,14 @@
[project]
name = "py-server"
version = "0.1.0"
description = "Async API server scaffold inspired by the Hyperf app layering."
requires-python = ">=3.14"
dependencies = [
"fastapi>=0.136.0",
"pydantic-settings>=2.14.0",
"uvicorn[standard]>=0.49.0",
]
[tool.pyright]
venvPath = "."
venv = ".venv"

1
tests/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Test package."""

View File

@@ -0,0 +1,88 @@
import os
import tempfile
import unittest
from pathlib import Path
DB_PATH = Path(tempfile.gettempdir()) / "py_server_admin_login_test.db"
DB_PATH.unlink(missing_ok=True)
os.environ["DATABASE_PATH"] = str(DB_PATH)
os.environ["JWT_ADMIN_SECRET"] = "test_admin_secret"
os.environ["ADMIN_SEED_USERNAME"] = "admin"
os.environ["ADMIN_SEED_PASSWORD"] = "admin"
from httpx import ASGITransport, AsyncClient
from app.core.dependencies import bootstrap_database
from app.main import app
class AdminLoginFlowTest(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self) -> None:
await bootstrap_database()
self.client = AsyncClient(
transport=ASGITransport(app=app),
base_url="http://testserver",
)
async def asyncTearDown(self) -> None:
await self.client.aclose()
async def test_login_access_and_refresh_flow(self) -> None:
login_response = await self.client.post(
"/admin/login/login",
json={"username": "admin", "password": "admin"},
)
login_payload = login_response.json()
self.assertEqual(login_payload["code"], 0)
self.assertIn("access_token", login_payload["data"])
self.assertIn("refresh_token", login_payload["data"])
self.assertEqual(login_payload["data"]["expire_at"], 3600)
access_token = login_payload["data"]["access_token"]
refresh_token = login_payload["data"]["refresh_token"]
current_response = await self.client.get(
"/admin/profile/current",
headers={"Authorization": f"Bearer {access_token}"},
)
current_payload = current_response.json()
self.assertEqual(current_payload["code"], 0)
self.assertEqual(current_payload["data"]["username"], "admin")
refresh_response = await self.client.post(
"/admin/login/refresh",
headers={"Authorization": f"Bearer {refresh_token}"},
)
refresh_payload = refresh_response.json()
self.assertEqual(refresh_payload["code"], 0)
self.assertNotEqual(refresh_payload["data"]["refresh_token"], refresh_token)
reused_response = await self.client.post(
"/admin/login/refresh",
headers={"Authorization": f"Bearer {refresh_token}"},
)
reused_payload = reused_response.json()
self.assertEqual(reused_payload["code"], 10001)
async def test_refresh_endpoint_rejects_access_token(self) -> None:
login_response = await self.client.post(
"/admin/login/login",
json={"username": "admin", "password": "admin"},
)
access_token = login_response.json()["data"]["access_token"]
refresh_response = await self.client.post(
"/admin/login/refresh",
headers={"Authorization": f"Bearer {access_token}"},
)
refresh_payload = refresh_response.json()
self.assertEqual(refresh_payload["code"], 10002)
if __name__ == "__main__":
unittest.main()