commit 5635da9ea5c0736e5dcefd12d73f34b8f1de0825 Author: ctexthuang Date: Fri Jun 5 17:10:30 2026 +0800 Initial FastAPI admin auth scaffold diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..53340e3 --- /dev/null +++ b/.env.example @@ -0,0 +1,11 @@ +APP_NAME=py_server +APP_ENV=local +DATABASE_PATH=storage/app.db + +JWT_ADMIN_SECRET=change_me_admin_secret +ADMIN_JWT_TTL=3600 +ADMIN_JWT_REFRESH_TTL=7200 +JWT_BLACKLIST_TTL=7201 + +ADMIN_SEED_USERNAME=admin +ADMIN_SEED_PASSWORD=admin diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..38a1804 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +.env +__pycache__/ +*.py[cod] +.pytest_cache/ +.ruff_cache/ +.idea/ +storage/*.db diff --git a/README.md b/README.md new file mode 100644 index 0000000..9af0bca --- /dev/null +++ b/README.md @@ -0,0 +1,145 @@ +# py_server + +这是一个 FastAPI 异步 API 项目骨架,分层方式参考了 Hyperf 项目里的 `server/app`。 + +## 分层对应 + +| Hyperf | FastAPI 当前项目 | 作用 | +| --- | --- | --- | +| `Controller/Admin` | `app/controller/admin` | 接收请求,调用 service | +| `Service/Admin` | `app/service/admin` | 业务逻辑 | +| `Common/Repository` | `app/common/repository` | 数据访问 | +| `Model` | `app/model` | 数据模型 | +| `Middleware/Admin` | `app/middleware/admin` | admin token 和权限校验 | +| `Lib/Jwt` | `app/lib/jwt` | JWT 签发、解析、黑名单 | +| `Request/Admin` | `app/request/admin` | 请求参数校验 | +| `Constants` | `app/constants` | 状态码和业务常量 | +| `Lib/Return` | `app/lib/response` | 统一返回结构 | + +> Python 里 `return` 是关键字,所以参考 Hyperf 的 `Lib/Return` 在这里命名为 `lib/response`。 + +## 启动 + +```bash +.venv/bin/uvicorn app.main:app --reload +``` + +默认启动时会自动创建 SQLite 开发库,并创建一个 admin 用户: + +- username: `admin` +- password: `admin` + +本地配置可以复制: + +```bash +cp .env.example .env +``` + +## Admin 登录接口 + +### 登录 + +```http +POST /admin/login/login +Content-Type: application/json + +{ + "username": "admin", + "password": "admin" +} +``` + +返回: + +```json +{ + "code": 0, + "message": "success", + "data": { + "access_token": "...", + "refresh_token": "...", + "expire_at": 3600 + } +} +``` + +### 刷新 token + +```http +POST /admin/login/refresh +Authorization: Bearer +``` + +刷新成功后会签发新的 `access_token` 和 `refresh_token`,并把旧的 `refresh_token` 加入黑名单,防止重复刷新。 + +### 当前 admin 用户 + +```http +GET /admin/profile/current +Authorization: Bearer +``` + +## JWT 逻辑 + +当前实现和参考 Hyperf 项目保持同样思路: + +1. 登录成功后同时签发 `access_token` 和 `refresh_token`。 +2. 普通 admin 接口只接受 `access_token`。 +3. `/admin/login/refresh` 只接受 `refresh_token`。 +4. refresh 成功后,把旧 refresh token 加入黑名单。 +5. token 会校验签名、过期时间、issuer 和 token 类型。 + +token 读取顺序和 Hyperf 中间件一致: + +1. `Authorization: Bearer ` +2. `token` header +3. query string: `?token=...` + +## FastAPI 依赖注入 + +Controller 里这段: + +```python +service: LoginService = Depends(get_login_service) +``` + +可以理解为 FastAPI 版的 Hyperf 容器注入。 + +请求进来时,FastAPI 会先调用 `get_login_service()`,把创建好的 `LoginService` 传给 `service` 参数。 + +`get_login_service()` 内部会继续组装: + +- `AdminUserRepository` +- `BaseTokenService` +- `AdminReturn` + +所以 controller 只负责接收请求和调用 service。 + +## `@lru_cache` 的作用 + +`app/core/dependencies.py` 里的: + +```python +@lru_cache +def get_database() -> Database: + return Database(get_settings().database_path) +``` + +表示第一次调用时创建对象,后面再次调用时直接复用第一次创建的对象。 + +在这个项目里,它的作用类似 Hyperf 容器里的共享服务/单例服务。比如 JWT 黑名单必须复用同一个对象,否则旧 refresh token 加入黑名单后,下一次请求就查不到了。 + +## 测试 + +```bash +.venv/bin/python -m compileall app tests +.venv/bin/python -m unittest tests.test_admin_login_flow +``` + +测试覆盖: + +- admin 登录成功 +- access token 访问当前用户成功 +- refresh token 换新 token 成功 +- 旧 refresh token 复用失败 +- access token 调用 refresh 接口失败 diff --git a/app/__init__.py b/app/__init__.py new file mode 100644 index 0000000..18b665e --- /dev/null +++ b/app/__init__.py @@ -0,0 +1 @@ +"""Application package.""" diff --git a/app/common/__init__.py b/app/common/__init__.py new file mode 100644 index 0000000..9248be8 --- /dev/null +++ b/app/common/__init__.py @@ -0,0 +1 @@ +"""Common helpers.""" diff --git a/app/common/context.py b/app/common/context.py new file mode 100644 index 0000000..e03ff4b --- /dev/null +++ b/app/common/context.py @@ -0,0 +1,3 @@ +from contextvars import ContextVar + +current_admin_id: ContextVar[int] = ContextVar("current_admin_id", default=0) diff --git a/app/common/repository/__init__.py b/app/common/repository/__init__.py new file mode 100644 index 0000000..4e2fa5e --- /dev/null +++ b/app/common/repository/__init__.py @@ -0,0 +1 @@ +"""Repository layer.""" diff --git a/app/common/repository/admin_user_repository.py b/app/common/repository/admin_user_repository.py new file mode 100644 index 0000000..77c683a --- /dev/null +++ b/app/common/repository/admin_user_repository.py @@ -0,0 +1,79 @@ +from datetime import UTC, datetime + +from app.common.repository.base_repository import BaseRepository +from app.common.security.password_hasher import hash_password +from app.constants.model.admin_user.admin_user_status_code import AdminUserStatusCode +from app.model.admin_user import AdminUser + + +class AdminUserRepository(BaseRepository): + async def find_by_username(self, username: str) -> AdminUser | None: + row = await self.database.fetchone( + """ + SELECT id, username, password, user_type, nickname, phone, email, + status, login_ip, login_time, remark + FROM admin_user + WHERE username = ? + LIMIT 1 + """, + (username,), + ) + return AdminUser.from_row(row) if row else None + + async def find_by_id(self, admin_id: int) -> AdminUser | None: + row = await self.database.fetchone( + """ + SELECT id, username, password, user_type, nickname, phone, email, + status, login_ip, login_time, remark + FROM admin_user + WHERE id = ? + LIMIT 1 + """, + (admin_id,), + ) + return AdminUser.from_row(row) if row else None + + async def record_login(self, admin_id: int, login_ip: str) -> None: + now = self._now() + await self.database.execute( + """ + UPDATE admin_user + SET login_ip = ?, login_time = ?, updated_at = ? + WHERE id = ? + """, + (login_ip, now, now, admin_id), + ) + + async def ensure_seed_admin(self, username: str, password: str) -> None: + exists = await self.find_by_username(username) + if exists is not None: + return + + now = self._now() + await self.database.execute( + """ + INSERT INTO admin_user ( + username, password, user_type, nickname, phone, email, status, + login_ip, login_time, created_at, updated_at, remark + ) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + ( + username, + hash_password(password), + "SuperAdmin", + "Super Admin", + "", + "", + int(AdminUserStatusCode.NORMAL), + "", + "", + now, + now, + "seeded by application startup", + ), + ) + + @staticmethod + def _now() -> str: + return datetime.now(UTC).strftime("%Y-%m-%d %H:%M:%S") diff --git a/app/common/repository/base_repository.py b/app/common/repository/base_repository.py new file mode 100644 index 0000000..12532f3 --- /dev/null +++ b/app/common/repository/base_repository.py @@ -0,0 +1,6 @@ +from app.core.database import Database + + +class BaseRepository: + def __init__(self, database: Database) -> None: + self.database = database diff --git a/app/common/security/__init__.py b/app/common/security/__init__.py new file mode 100644 index 0000000..5ed3b23 --- /dev/null +++ b/app/common/security/__init__.py @@ -0,0 +1 @@ +"""Security helpers.""" diff --git a/app/common/security/password_hasher.py b/app/common/security/password_hasher.py new file mode 100644 index 0000000..1c9d68f --- /dev/null +++ b/app/common/security/password_hasher.py @@ -0,0 +1,40 @@ +import base64 +import hashlib +import hmac +import secrets + +ALGORITHM = "pbkdf2_sha256" +ITERATIONS = 390_000 + + +def _b64encode(raw: bytes) -> str: + return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=") + + +def _b64decode(value: str) -> bytes: + padding = "=" * (-len(value) % 4) + return base64.urlsafe_b64decode(value + padding) + + +def hash_password(password: str, iterations: int = ITERATIONS) -> str: + salt = secrets.token_bytes(16) + digest = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, iterations) + return f"{ALGORITHM}${iterations}${_b64encode(salt)}${_b64encode(digest)}" + + +def verify_password(password: str, stored_hash: str) -> bool: + try: + algorithm, iterations, salt, expected = stored_hash.split("$", 3) + except ValueError: + return hmac.compare_digest(password, stored_hash) + + if algorithm != ALGORITHM: + return False + + digest = hashlib.pbkdf2_hmac( + "sha256", + password.encode("utf-8"), + _b64decode(salt), + int(iterations), + ) + return hmac.compare_digest(_b64encode(digest), expected) diff --git a/app/constants/__init__.py b/app/constants/__init__.py new file mode 100644 index 0000000..278d9b0 --- /dev/null +++ b/app/constants/__init__.py @@ -0,0 +1 @@ +"""Application constants.""" diff --git a/app/constants/admin_code.py b/app/constants/admin_code.py new file mode 100644 index 0000000..e388d51 --- /dev/null +++ b/app/constants/admin_code.py @@ -0,0 +1,6 @@ +from enum import IntEnum + + +class AdminCode(IntEnum): + DISABLED = 30001 + FORBIDDEN = 30002 diff --git a/app/constants/model/__init__.py b/app/constants/model/__init__.py new file mode 100644 index 0000000..a746493 --- /dev/null +++ b/app/constants/model/__init__.py @@ -0,0 +1 @@ +"""Model constants.""" diff --git a/app/constants/model/admin_user/__init__.py b/app/constants/model/admin_user/__init__.py new file mode 100644 index 0000000..e92d0c9 --- /dev/null +++ b/app/constants/model/admin_user/__init__.py @@ -0,0 +1 @@ +"""Admin user constants.""" diff --git a/app/constants/model/admin_user/admin_user_status_code.py b/app/constants/model/admin_user/admin_user_status_code.py new file mode 100644 index 0000000..d749aca --- /dev/null +++ b/app/constants/model/admin_user/admin_user_status_code.py @@ -0,0 +1,12 @@ +from enum import IntEnum + + +class AdminUserStatusCode(IntEnum): + NORMAL = 1 + DISABLE = 2 + + def is_normal(self) -> bool: + return self is AdminUserStatusCode.NORMAL + + def is_disable(self) -> bool: + return self is AdminUserStatusCode.DISABLE diff --git a/app/constants/result_code.py b/app/constants/result_code.py new file mode 100644 index 0000000..48c1d1c --- /dev/null +++ b/app/constants/result_code.py @@ -0,0 +1,12 @@ +from enum import IntEnum + + +class ResultCode(IntEnum): + SUCCESS = 0 + ERROR = 1 + JWT_EXPIRED = 10001 + JWT_ERROR = 10002 + OLD_PASSWORD_ERROR = 10003 + ACCOUNT_DEACTIVATING = 20001 + ACCOUNT_DEACTIVATED = 20002 + ACCOUNT_CANNOT_DEACTIVATE = 20003 diff --git a/app/controller/__init__.py b/app/controller/__init__.py new file mode 100644 index 0000000..79b930d --- /dev/null +++ b/app/controller/__init__.py @@ -0,0 +1 @@ +"""Controller layer.""" diff --git a/app/controller/admin/__init__.py b/app/controller/admin/__init__.py new file mode 100644 index 0000000..1aa52ec --- /dev/null +++ b/app/controller/admin/__init__.py @@ -0,0 +1 @@ +"""Admin controllers.""" diff --git a/app/controller/admin/login_controller.py b/app/controller/admin/login_controller.py new file mode 100644 index 0000000..338b8ee --- /dev/null +++ b/app/controller/admin/login_controller.py @@ -0,0 +1,32 @@ +from fastapi import APIRouter, Depends, Request + +from app.core.dependencies import get_login_service, get_refresh_service +from app.lib.jwt.token import JwtToken +from app.middleware.admin.refresh_admin_token_middleware import RefreshAdminTokenMiddleware +from app.request.admin.login_request import LoginRequest +from app.service.admin.login.login_service import LoginService +from app.service.admin.login.refresh_service import RefreshService + +router = APIRouter(prefix="/admin/login", tags=["admin-login"]) + + +@router.post("/login") +async def login( + payload: LoginRequest, + request: Request, + # FastAPI 的 Depends 类似 Hyperf 里的容器注入。 + # 请求进入这个接口时,框架会先调用 get_login_service(), + # 把组装好的 LoginService 传进 service 参数。 + service: LoginService = Depends(get_login_service), +) -> dict: + return await service.handle(payload, request) + + +@router.post("/refresh") +async def refresh( + # 这里把 RefreshAdminTokenMiddleware 当成依赖使用。 + # FastAPI 会先执行 refresh token 校验,通过后才进入 controller。 + token: JwtToken = Depends(RefreshAdminTokenMiddleware()), + service: RefreshService = Depends(get_refresh_service), +) -> dict: + return await service.handle(token) diff --git a/app/controller/admin/profile_controller.py b/app/controller/admin/profile_controller.py new file mode 100644 index 0000000..b8dd992 --- /dev/null +++ b/app/controller/admin/profile_controller.py @@ -0,0 +1,21 @@ +from fastapi import APIRouter, Depends + +from app.core.dependencies import get_current_user_service +from app.middleware.admin.admin_token_middleware import AdminTokenMiddleware +from app.middleware.admin.permission_middleware import PermissionMiddleware +from app.service.admin.profile.current_user_service import CurrentUserService + +router = APIRouter(prefix="/admin/profile", tags=["admin-profile"]) + + +@router.get( + "/current", + dependencies=[ + Depends(AdminTokenMiddleware()), + Depends(PermissionMiddleware()), + ], +) +async def current( + service: CurrentUserService = Depends(get_current_user_service), +) -> dict: + return await service.handle() diff --git a/app/controller/api/__init__.py b/app/controller/api/__init__.py new file mode 100644 index 0000000..4c0a0af --- /dev/null +++ b/app/controller/api/__init__.py @@ -0,0 +1 @@ +"""Frontend API controllers.""" diff --git a/app/controller/api/health_controller.py b/app/controller/api/health_controller.py new file mode 100644 index 0000000..a19f2cb --- /dev/null +++ b/app/controller/api/health_controller.py @@ -0,0 +1,10 @@ +from fastapi import APIRouter + +from app.lib.response.admin_return import AdminReturn + +router = APIRouter(prefix="/api", tags=["api"]) + + +@router.get("/health") +async def health() -> dict: + return AdminReturn().success("success", {"status": "ok"}) diff --git a/app/core/__init__.py b/app/core/__init__.py new file mode 100644 index 0000000..7ace94b --- /dev/null +++ b/app/core/__init__.py @@ -0,0 +1 @@ +"""Core application wiring.""" diff --git a/app/core/config.py b/app/core/config.py new file mode 100644 index 0000000..132214b --- /dev/null +++ b/app/core/config.py @@ -0,0 +1,39 @@ +from functools import lru_cache +from pathlib import Path + +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + model_config = SettingsConfigDict( + env_file=".env", + env_file_encoding="utf-8", + extra="ignore", + populate_by_name=True, + ) + + app_name: str = Field(default="py_server", alias="APP_NAME") + app_env: str = Field(default="local", alias="APP_ENV") + database_path: Path = Field(default=Path("storage/app.db"), alias="DATABASE_PATH") + + admin_jwt_secret: str = Field( + default="dev_admin_secret_change_me", + alias="JWT_ADMIN_SECRET", + ) + admin_jwt_ttl: int = Field(default=3600, alias="ADMIN_JWT_TTL") + admin_jwt_refresh_ttl: int = Field(default=7200, alias="ADMIN_JWT_REFRESH_TTL") + jwt_blacklist_ttl: int = Field(default=7201, alias="JWT_BLACKLIST_TTL") + + api_jwt_secret: str = Field(default="dev_api_secret_change_me", alias="JWT_SECRET") + api_jwt_ttl: int = Field(default=3600, alias="JWT_TTL") + api_jwt_refresh_ttl: int = Field(default=7200, alias="JWT_REFRESH_TTL") + + admin_seed_username: str = Field(default="admin", alias="ADMIN_SEED_USERNAME") + admin_seed_password: str = Field(default="admin", alias="ADMIN_SEED_PASSWORD") + cors_allow_origins: list[str] = Field(default=["*"], alias="CORS_ALLOW_ORIGINS") + + +@lru_cache +def get_settings() -> Settings: + return Settings() diff --git a/app/core/database.py b/app/core/database.py new file mode 100644 index 0000000..b54b9fb --- /dev/null +++ b/app/core/database.py @@ -0,0 +1,87 @@ +import asyncio +import sqlite3 +from pathlib import Path +from typing import Any + + +class Database: + def __init__(self, path: Path) -> None: + self.path = path + + async def initialize(self) -> None: + await asyncio.to_thread(self._initialize) + + async def execute(self, sql: str, params: tuple[Any, ...] = ()) -> int: + return await asyncio.to_thread(self._execute, sql, params) + + async def fetchone( + self, + sql: str, + params: tuple[Any, ...] = (), + ) -> dict[str, Any] | None: + return await asyncio.to_thread(self._fetchone, sql, params) + + async def fetchall( + self, + sql: str, + params: tuple[Any, ...] = (), + ) -> list[dict[str, Any]]: + return await asyncio.to_thread(self._fetchall, sql, params) + + def _connect(self) -> sqlite3.Connection: + self.path.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(self.path) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA foreign_keys = ON") + return conn + + def _initialize(self) -> None: + conn = self._connect() + try: + conn.execute( + """ + CREATE TABLE IF NOT EXISTS admin_user ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username TEXT NOT NULL UNIQUE, + password TEXT NOT NULL, + user_type TEXT NOT NULL DEFAULT 'admin', + nickname TEXT NOT NULL DEFAULT '', + phone TEXT NOT NULL DEFAULT '', + email TEXT NOT NULL DEFAULT '', + status INTEGER NOT NULL DEFAULT 1, + login_ip TEXT NOT NULL DEFAULT '', + login_time TEXT NOT NULL DEFAULT '', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + remark TEXT NOT NULL DEFAULT '' + ) + """ + ) + conn.commit() + finally: + conn.close() + + def _execute(self, sql: str, params: tuple[Any, ...]) -> int: + conn = self._connect() + try: + cursor = conn.execute(sql, params) + conn.commit() + return int(cursor.lastrowid or 0) + finally: + conn.close() + + def _fetchone(self, sql: str, params: tuple[Any, ...]) -> dict[str, Any] | None: + conn = self._connect() + try: + row = conn.execute(sql, params).fetchone() + return dict(row) if row else None + finally: + conn.close() + + def _fetchall(self, sql: str, params: tuple[Any, ...]) -> list[dict[str, Any]]: + conn = self._connect() + try: + rows = conn.execute(sql, params).fetchall() + return [dict(row) for row in rows] + finally: + conn.close() diff --git a/app/core/dependencies.py b/app/core/dependencies.py new file mode 100644 index 0000000..7f78233 --- /dev/null +++ b/app/core/dependencies.py @@ -0,0 +1,70 @@ +from functools import lru_cache + +from app.common.repository.admin_user_repository import AdminUserRepository +from app.core.config import Settings, get_settings +from app.core.database import Database +from app.lib.jwt.blacklist import InMemoryTokenBlacklist +from app.lib.jwt.factory import JwtFactory +from app.lib.response.admin_return import AdminReturn +from app.service.admin.login.login_service import LoginService +from app.service.admin.login.refresh_service import RefreshService +from app.service.admin.profile.current_user_service import CurrentUserService +from app.service.base_token_service import BaseTokenService + + +# lru_cache 会缓存函数第一次创建出来的对象。 +# 这里用它把 Database、JwtFactory、TokenService 等依赖做成应用级单例, +# 类似 Hyperf 从容器里反复 get 同一个共享服务。 +@lru_cache +def get_database() -> Database: + return Database(get_settings().database_path) + + +@lru_cache +def get_token_blacklist() -> InMemoryTokenBlacklist: + return InMemoryTokenBlacklist() + + +@lru_cache +def get_jwt_factory() -> JwtFactory: + return JwtFactory(get_settings(), get_token_blacklist()) + + +@lru_cache +def get_token_service() -> BaseTokenService: + return BaseTokenService(get_jwt_factory()) + + +@lru_cache +def get_admin_return() -> AdminReturn: + return AdminReturn() + + +@lru_cache +def get_admin_user_repository() -> AdminUserRepository: + return AdminUserRepository(get_database()) + + +def get_login_service() -> LoginService: + return LoginService( + get_admin_user_repository(), + get_token_service(), + get_admin_return(), + ) + + +def get_refresh_service() -> RefreshService: + return RefreshService(get_token_service(), get_admin_return()) + + +def get_current_user_service() -> CurrentUserService: + return CurrentUserService(get_admin_user_repository(), get_admin_return()) + + +async def bootstrap_database(settings: Settings | None = None) -> None: + settings = settings or get_settings() + await get_database().initialize() + await get_admin_user_repository().ensure_seed_admin( + settings.admin_seed_username, + settings.admin_seed_password, + ) diff --git a/app/exception/__init__.py b/app/exception/__init__.py new file mode 100644 index 0000000..c2c901a --- /dev/null +++ b/app/exception/__init__.py @@ -0,0 +1 @@ +"""Exception layer.""" diff --git a/app/exception/err_exception.py b/app/exception/err_exception.py new file mode 100644 index 0000000..89d023f --- /dev/null +++ b/app/exception/err_exception.py @@ -0,0 +1,14 @@ +from app.constants.result_code import ResultCode + + +class ErrException(Exception): + def __init__( + self, + message: str = "failed", + code: int | ResultCode = ResultCode.ERROR, + data: dict | None = None, + ) -> None: + super().__init__(message) + self.message = message + self.code = int(code) + self.data = data or {} diff --git a/app/exception/handler.py b/app/exception/handler.py new file mode 100644 index 0000000..4f6c1de --- /dev/null +++ b/app/exception/handler.py @@ -0,0 +1,34 @@ +from fastapi import FastAPI, Request +from fastapi.exceptions import RequestValidationError +from fastapi.responses import JSONResponse + +from app.constants.result_code import ResultCode +from app.exception.err_exception import ErrException +from app.lib.response.admin_return import AdminReturn + + +async def err_exception_handler(request: Request, exc: ErrException) -> JSONResponse: + return JSONResponse( + status_code=200, + content=AdminReturn().error(exc.message, exc.code, exc.data), + ) + + +async def validation_exception_handler( + request: Request, + exc: RequestValidationError, +) -> JSONResponse: + first_error = exc.errors()[0] if exc.errors() else {} + location = ".".join(str(item) for item in first_error.get("loc", [])) + message = first_error.get("msg", "参数错误") + if location: + message = f"{location}: {message}" + return JSONResponse( + status_code=200, + content=AdminReturn().error(message, ResultCode.ERROR), + ) + + +def register_exception_handlers(app: FastAPI) -> None: + app.add_exception_handler(ErrException, err_exception_handler) + app.add_exception_handler(RequestValidationError, validation_exception_handler) diff --git a/app/lib/__init__.py b/app/lib/__init__.py new file mode 100644 index 0000000..3340949 --- /dev/null +++ b/app/lib/__init__.py @@ -0,0 +1 @@ +"""Library layer.""" diff --git a/app/lib/jwt/__init__.py b/app/lib/jwt/__init__.py new file mode 100644 index 0000000..11a47b9 --- /dev/null +++ b/app/lib/jwt/__init__.py @@ -0,0 +1 @@ +"""JWT helpers.""" diff --git a/app/lib/jwt/blacklist.py b/app/lib/jwt/blacklist.py new file mode 100644 index 0000000..44bef09 --- /dev/null +++ b/app/lib/jwt/blacklist.py @@ -0,0 +1,25 @@ +import asyncio +import time + + +class InMemoryTokenBlacklist: + def __init__(self) -> None: + self._items: dict[str, float] = {} + self._lock = asyncio.Lock() + + async def add(self, token: str, ttl: int) -> None: + now = time.time() + async with self._lock: + self._cleanup(now) + self._items[token] = now + ttl + + async def has(self, token: str) -> bool: + now = time.time() + async with self._lock: + self._cleanup(now) + return token in self._items + + def _cleanup(self, now: float) -> None: + expired = [token for token, expires_at in self._items.items() if expires_at <= now] + for token in expired: + self._items.pop(token, None) diff --git a/app/lib/jwt/exceptions.py b/app/lib/jwt/exceptions.py new file mode 100644 index 0000000..306c102 --- /dev/null +++ b/app/lib/jwt/exceptions.py @@ -0,0 +1,10 @@ +class JwtError(Exception): + pass + + +class JwtExpiredError(JwtError): + pass + + +class JwtInvalidError(JwtError): + pass diff --git a/app/lib/jwt/factory.py b/app/lib/jwt/factory.py new file mode 100644 index 0000000..d17cc38 --- /dev/null +++ b/app/lib/jwt/factory.py @@ -0,0 +1,34 @@ +from app.core.config import Settings +from app.lib.jwt.blacklist import InMemoryTokenBlacklist +from app.lib.jwt.jwt import Jwt, JwtSceneConfig + + +class JwtFactory: + def __init__(self, settings: Settings, blacklist: InMemoryTokenBlacklist) -> None: + self.settings = settings + self.blacklist = blacklist + + def get(self, scene: str = "default") -> Jwt: + if scene == "admin": + return Jwt( + JwtSceneConfig( + secret=self.settings.admin_jwt_secret, + issuer=f"{self.settings.app_name}_admin", + ttl=self.settings.admin_jwt_ttl, + refresh_ttl=self.settings.admin_jwt_refresh_ttl, + blacklist_ttl=self.settings.jwt_blacklist_ttl, + ), + self.blacklist, + ) + + scene_name = "default" if scene == "default" else scene + return Jwt( + JwtSceneConfig( + secret=self.settings.api_jwt_secret, + issuer=f"{self.settings.app_name}_{scene_name}", + ttl=self.settings.api_jwt_ttl, + refresh_ttl=self.settings.api_jwt_refresh_ttl, + blacklist_ttl=self.settings.jwt_blacklist_ttl, + ), + self.blacklist, + ) diff --git a/app/lib/jwt/jwt.py b/app/lib/jwt/jwt.py new file mode 100644 index 0000000..3d6545d --- /dev/null +++ b/app/lib/jwt/jwt.py @@ -0,0 +1,151 @@ +import base64 +import hashlib +import hmac +import json +import secrets +import time +from dataclasses import dataclass +from typing import Any + +from app.lib.jwt.blacklist import InMemoryTokenBlacklist +from app.lib.jwt.exceptions import JwtExpiredError, JwtInvalidError +from app.lib.jwt.token import JwtToken + + +@dataclass(frozen=True, slots=True) +class JwtSceneConfig: + secret: str + issuer: str + ttl: int + refresh_ttl: int + blacklist_ttl: int + + +class Jwt: + def __init__(self, config: JwtSceneConfig, blacklist: InMemoryTokenBlacklist) -> None: + self.config = config + self.blacklist = blacklist + + @property + def issuer(self) -> str: + return self.config.issuer + + def builder_access_token(self, admin_id: str) -> str: + return self._build_token(admin_id, token_type="access", ttl=self.config.ttl) + + def builder_refresh_token(self, admin_id: str) -> str: + return self._build_token(admin_id, token_type="refresh", ttl=self.config.refresh_ttl) + + async def parser_access_token(self, raw_token: str) -> JwtToken: + token = self._parse(raw_token) + if token.is_refresh: + raise JwtInvalidError("Token is a refresh token") + self._validate_issuer(token) + return token + + async def parser_refresh_token(self, raw_token: str) -> JwtToken: + token = self._parse(raw_token) + if not token.is_refresh: + raise JwtInvalidError("Token is not a refresh token") + self._validate_issuer(token) + return token + + async def add_blacklist(self, token: JwtToken) -> None: + await self.blacklist.add(token.raw, self.config.blacklist_ttl) + + async def has_blacklist(self, token: JwtToken) -> bool: + return await self.blacklist.has(token.raw) + + def get_config(self, key: str, default: Any = None) -> Any: + return getattr(self.config, key, default) + + def _build_token(self, admin_id: str, token_type: str, ttl: int) -> str: + now = int(time.time()) + payload: dict[str, Any] = { + "iss": self.config.issuer, + "jti": str(admin_id), + "iat": now, + "nbf": now, + "exp": now + ttl, + "token_type": token_type, + "nonce": secrets.token_urlsafe(16), + } + if token_type == "refresh": + payload["sub"] = "refresh" + + header = {"typ": "JWT", "alg": "HS256"} + signing_input = ".".join( + [ + self._b64url_json(header), + self._b64url_json(payload), + ] + ) + signature = hmac.new( + self.config.secret.encode("utf-8"), + signing_input.encode("ascii"), + hashlib.sha256, + ).digest() + return f"{signing_input}.{self._b64url_encode(signature)}" + + def _parse(self, raw_token: str) -> JwtToken: + try: + header_b64, payload_b64, signature_b64 = raw_token.split(".", 2) + except ValueError as exc: + raise JwtInvalidError("Malformed token") from exc + + signing_input = f"{header_b64}.{payload_b64}".encode("ascii") + expected = hmac.new( + self.config.secret.encode("utf-8"), + signing_input, + hashlib.sha256, + ).digest() + if not hmac.compare_digest(self._b64url_encode(expected), signature_b64): + raise JwtInvalidError("Invalid signature") + + try: + header = json.loads(self._b64url_decode(header_b64)) + payload = json.loads(self._b64url_decode(payload_b64)) + except (json.JSONDecodeError, UnicodeDecodeError) as exc: + raise JwtInvalidError("Invalid token payload") from exc + + if header.get("alg") != "HS256": + raise JwtInvalidError("Unsupported algorithm") + if not isinstance(payload, dict): + raise JwtInvalidError("Invalid claims") + + token = JwtToken(raw=raw_token, claims=payload) + self._validate_time(token) + return token + + def _validate_time(self, token: JwtToken) -> None: + now = int(time.time()) + exp = token.claims.get("exp") + nbf = token.claims.get("nbf") + iat = token.claims.get("iat") + + if not isinstance(exp, int): + raise JwtInvalidError("Missing exp") + if exp <= now: + raise JwtExpiredError("Token expired") + if isinstance(nbf, int) and nbf > now: + raise JwtInvalidError("Token not active") + if isinstance(iat, int) and iat > now + 60: + raise JwtInvalidError("Token issued in future") + + def _validate_issuer(self, token: JwtToken) -> None: + if token.claims.get("iss") != self.config.issuer: + raise JwtInvalidError("Invalid issuer") + + @staticmethod + def _b64url_encode(raw: bytes) -> str: + return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=") + + @classmethod + def _b64url_json(cls, value: dict[str, Any]) -> str: + raw = json.dumps(value, separators=(",", ":"), sort_keys=True).encode("utf-8") + return cls._b64url_encode(raw) + + @staticmethod + def _b64url_decode(value: str) -> str: + padding = "=" * (-len(value) % 4) + return base64.urlsafe_b64decode(value + padding).decode("utf-8") diff --git a/app/lib/jwt/token.py b/app/lib/jwt/token.py new file mode 100644 index 0000000..8e10ded --- /dev/null +++ b/app/lib/jwt/token.py @@ -0,0 +1,23 @@ +from dataclasses import dataclass +from typing import Any + + +@dataclass(frozen=True, slots=True) +class JwtToken: + raw: str + claims: dict[str, Any] + + @property + def jwt_id(self) -> str: + return str(self.claims.get("jti", "")) + + @property + def admin_id(self) -> int: + try: + return int(self.jwt_id) + except ValueError: + return 0 + + @property + def is_refresh(self) -> bool: + return self.claims.get("sub") == "refresh" or self.claims.get("token_type") == "refresh" diff --git a/app/lib/response/__init__.py b/app/lib/response/__init__.py new file mode 100644 index 0000000..410805b --- /dev/null +++ b/app/lib/response/__init__.py @@ -0,0 +1 @@ +"""Response helpers.""" diff --git a/app/lib/response/admin_return.py b/app/lib/response/admin_return.py new file mode 100644 index 0000000..273f0aa --- /dev/null +++ b/app/lib/response/admin_return.py @@ -0,0 +1,6 @@ +from app.lib.response.common_return import CommonReturn + + +class AdminReturn(CommonReturn): + def after_success(self, response: dict) -> dict: + return response diff --git a/app/lib/response/common_return.py b/app/lib/response/common_return.py new file mode 100644 index 0000000..4f6af3b --- /dev/null +++ b/app/lib/response/common_return.py @@ -0,0 +1,34 @@ +from app.constants.result_code import ResultCode + + +class CommonReturn: + def success( + self, + message: str = "success", + data: dict | list | None = None, + code: int | ResultCode = ResultCode.SUCCESS, + ) -> dict: + return self.after_success( + { + "code": int(code), + "message": message, + "data": data if data is not None else {}, + } + ) + + def error( + self, + message: str = "failed", + code: int | ResultCode = ResultCode.ERROR, + data: dict | None = None, + ) -> dict: + return self.after_success( + { + "code": int(code), + "message": message, + "data": data or {}, + } + ) + + def after_success(self, response: dict) -> dict: + raise NotImplementedError diff --git a/app/main.py b/app/main.py new file mode 100644 index 0000000..628d6ae --- /dev/null +++ b/app/main.py @@ -0,0 +1,44 @@ +from contextlib import asynccontextmanager +from typing import AsyncIterator + +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from app.controller.admin.login_controller import router as admin_login_router +from app.controller.admin.profile_controller import router as admin_profile_router +from app.controller.api.health_controller import router as api_health_router +from app.core.config import get_settings +from app.core.dependencies import bootstrap_database +from app.exception.handler import register_exception_handlers +from app.lib.response.admin_return import AdminReturn + + +@asynccontextmanager +async def lifespan(app: FastAPI) -> AsyncIterator[None]: + await bootstrap_database() + yield + + +def create_app() -> FastAPI: + settings = get_settings() + application = FastAPI(title=settings.app_name, lifespan=lifespan) + application.add_middleware( + CORSMiddleware, + allow_origins=settings.cors_allow_origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + register_exception_handlers(application) + application.include_router(admin_login_router) + application.include_router(admin_profile_router) + application.include_router(api_health_router) + + @application.get("/") + async def index() -> dict: + return AdminReturn().success("success", {"name": settings.app_name}) + + return application + + +app = create_app() diff --git a/app/middleware/__init__.py b/app/middleware/__init__.py new file mode 100644 index 0000000..803baa0 --- /dev/null +++ b/app/middleware/__init__.py @@ -0,0 +1 @@ +"""Middleware and route dependencies.""" diff --git a/app/middleware/admin/__init__.py b/app/middleware/admin/__init__.py new file mode 100644 index 0000000..2bb41d8 --- /dev/null +++ b/app/middleware/admin/__init__.py @@ -0,0 +1 @@ +"""Admin middleware.""" diff --git a/app/middleware/admin/admin_token_middleware.py b/app/middleware/admin/admin_token_middleware.py new file mode 100644 index 0000000..a56fc86 --- /dev/null +++ b/app/middleware/admin/admin_token_middleware.py @@ -0,0 +1,16 @@ +from fastapi import Request + +from app.common.context import current_admin_id +from app.lib.jwt.jwt import Jwt +from app.lib.jwt.token import JwtToken +from app.middleware.token.abstract_token_middleware import AbstractTokenMiddleware + + +class AdminTokenMiddleware(AbstractTokenMiddleware): + async def parser_token(self, jwt: Jwt, raw_token: str) -> JwtToken: + return await jwt.parser_access_token(raw_token) + + def set_context(self, request: Request, token: JwtToken) -> None: + admin_id = token.admin_id + current_admin_id.set(admin_id) + request.state.current_admin_id = admin_id diff --git a/app/middleware/admin/permission_middleware.py b/app/middleware/admin/permission_middleware.py new file mode 100644 index 0000000..62aeb21 --- /dev/null +++ b/app/middleware/admin/permission_middleware.py @@ -0,0 +1,34 @@ +from collections.abc import Iterable + +from fastapi import Depends, Request + +from app.common.context import current_admin_id +from app.common.repository.admin_user_repository import AdminUserRepository +from app.constants.admin_code import AdminCode +from app.constants.model.admin_user.admin_user_status_code import AdminUserStatusCode +from app.core.dependencies import get_admin_user_repository +from app.exception.err_exception import ErrException + + +class PermissionMiddleware: + def __init__(self, permissions: Iterable[str] | None = None) -> None: + self.permissions = tuple(permissions or ()) + + async def __call__( + self, + request: Request, + user_repository: AdminUserRepository = Depends(get_admin_user_repository), + ) -> None: + admin_id = getattr(request.state, "current_admin_id", current_admin_id.get()) + if admin_id <= 0: + raise ErrException("账户不存在") + + admin_user = await user_repository.find_by_id(admin_id) + if admin_user is None: + raise ErrException("账户不存在") + if admin_user.status == AdminUserStatusCode.DISABLE: + raise ErrException("账号已禁用", AdminCode.DISABLED) + + request.state.current_admin_user = admin_user + if self.permissions and admin_user.user_type != "SuperAdmin": + raise ErrException("暂无权限", AdminCode.FORBIDDEN) diff --git a/app/middleware/admin/refresh_admin_token_middleware.py b/app/middleware/admin/refresh_admin_token_middleware.py new file mode 100644 index 0000000..de55bc2 --- /dev/null +++ b/app/middleware/admin/refresh_admin_token_middleware.py @@ -0,0 +1,13 @@ +from fastapi import Request + +from app.lib.jwt.jwt import Jwt +from app.lib.jwt.token import JwtToken +from app.middleware.token.abstract_token_middleware import AbstractTokenMiddleware + + +class RefreshAdminTokenMiddleware(AbstractTokenMiddleware): + async def parser_token(self, jwt: Jwt, raw_token: str) -> JwtToken: + return await jwt.parser_refresh_token(raw_token) + + def set_context(self, request: Request, token: JwtToken) -> None: + return None diff --git a/app/middleware/token/__init__.py b/app/middleware/token/__init__.py new file mode 100644 index 0000000..2e2b235 --- /dev/null +++ b/app/middleware/token/__init__.py @@ -0,0 +1 @@ +"""Token middleware.""" diff --git a/app/middleware/token/abstract_token_middleware.py b/app/middleware/token/abstract_token_middleware.py new file mode 100644 index 0000000..720f650 --- /dev/null +++ b/app/middleware/token/abstract_token_middleware.py @@ -0,0 +1,72 @@ +from fastapi import Depends, Request + +from app.constants.result_code import ResultCode +from app.core.dependencies import get_token_service +from app.exception.err_exception import ErrException +from app.lib.jwt.exceptions import JwtError, JwtExpiredError +from app.lib.jwt.jwt import Jwt +from app.lib.jwt.token import JwtToken +from app.service.base_token_service import BaseTokenService + + +class AbstractTokenMiddleware: + scene = "admin" + + async def __call__( + self, + request: Request, + token_service: BaseTokenService = Depends(get_token_service), + ) -> JwtToken: + raw_token = self.get_token(request) + jwt = token_service.get_jwt(self.scene) + + try: + token = await self.parser_token(jwt, raw_token) + await token_service.check_jwt(jwt, token) + self.check_issuer(jwt, token) + except JwtExpiredError as exc: + raise ErrException( + "token过期", + ResultCode.JWT_EXPIRED, + {"err_msg": str(exc)}, + ) from exc + except JwtError as exc: + raise ErrException( + "token错误", + ResultCode.JWT_ERROR, + {"err_msg": str(exc)}, + ) from exc + + self.set_context(request, token) + request.state.token = token + return token + + async def parser_token(self, jwt: Jwt, raw_token: str) -> JwtToken: + raise NotImplementedError + + def set_context(self, request: Request, token: JwtToken) -> None: + raise NotImplementedError + + @staticmethod + def check_issuer(jwt: Jwt, token: JwtToken) -> None: + if token.claims.get("iss") != jwt.issuer: + raise ErrException("token错误", ResultCode.JWT_ERROR) + + @staticmethod + def get_token(request: Request) -> str: + authorization = request.headers.get("authorization") + if authorization: + scheme, _, token = authorization.partition(" ") + if scheme.lower() == "bearer" and token: + return token.strip() + return authorization.replace("Bearer ", "", 1).strip() + + token_header = request.headers.get("token") + if token_header: + return token_header.strip() + + token_query = request.query_params.get("token") + if token_query: + return token_query.strip() + + raise ErrException("token缺失", ResultCode.JWT_ERROR) diff --git a/app/model/__init__.py b/app/model/__init__.py new file mode 100644 index 0000000..3968465 --- /dev/null +++ b/app/model/__init__.py @@ -0,0 +1 @@ +"""Model layer.""" diff --git a/app/model/admin_user.py b/app/model/admin_user.py new file mode 100644 index 0000000..56f1334 --- /dev/null +++ b/app/model/admin_user.py @@ -0,0 +1,53 @@ +from dataclasses import dataclass +from typing import Any + +from app.common.security.password_hasher import verify_password +from app.constants.model.admin_user.admin_user_status_code import AdminUserStatusCode + + +@dataclass(slots=True) +class AdminUser: + id: int + username: str + password: str + user_type: str + nickname: str + phone: str + email: str + status: AdminUserStatusCode + login_ip: str + login_time: str + remark: str + + @classmethod + def from_row(cls, row: dict[str, Any]) -> "AdminUser": + return cls( + id=int(row["id"]), + username=str(row["username"]), + password=str(row["password"]), + user_type=str(row["user_type"]), + nickname=str(row["nickname"]), + phone=str(row["phone"]), + email=str(row["email"]), + status=AdminUserStatusCode(int(row["status"])), + login_ip=str(row["login_ip"]), + login_time=str(row["login_time"]), + remark=str(row["remark"]), + ) + + def verify_password(self, password: str) -> bool: + return verify_password(password, self.password) + + def to_public_dict(self) -> dict: + return { + "id": self.id, + "username": self.username, + "user_type": self.user_type, + "nickname": self.nickname, + "phone": self.phone, + "email": self.email, + "status": int(self.status), + "login_ip": self.login_ip, + "login_time": self.login_time, + "remark": self.remark, + } diff --git a/app/request/__init__.py b/app/request/__init__.py new file mode 100644 index 0000000..6937b80 --- /dev/null +++ b/app/request/__init__.py @@ -0,0 +1 @@ +"""Request schemas.""" diff --git a/app/request/admin/__init__.py b/app/request/admin/__init__.py new file mode 100644 index 0000000..08fcc12 --- /dev/null +++ b/app/request/admin/__init__.py @@ -0,0 +1 @@ +"""Admin request schemas.""" diff --git a/app/request/admin/login_request.py b/app/request/admin/login_request.py new file mode 100644 index 0000000..4b04f98 --- /dev/null +++ b/app/request/admin/login_request.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel, Field + + +class LoginRequest(BaseModel): + username: str = Field(min_length=1) + password: str = Field(min_length=1) diff --git a/app/service/__init__.py b/app/service/__init__.py new file mode 100644 index 0000000..02dea84 --- /dev/null +++ b/app/service/__init__.py @@ -0,0 +1 @@ +"""Service layer.""" diff --git a/app/service/admin/__init__.py b/app/service/admin/__init__.py new file mode 100644 index 0000000..adc44c3 --- /dev/null +++ b/app/service/admin/__init__.py @@ -0,0 +1 @@ +"""Admin services.""" diff --git a/app/service/admin/base_admin_service.py b/app/service/admin/base_admin_service.py new file mode 100644 index 0000000..106cf79 --- /dev/null +++ b/app/service/admin/base_admin_service.py @@ -0,0 +1,15 @@ +from app.common.context import current_admin_id +from app.exception.err_exception import ErrException +from app.lib.response.admin_return import AdminReturn + + +class BaseAdminService: + def __init__(self, admin_return: AdminReturn) -> None: + self.admin_return = admin_return + + @property + def admin_id(self) -> int: + admin_id = current_admin_id.get() + if admin_id <= 0: + raise ErrException("账户不存在") + return admin_id diff --git a/app/service/admin/login/__init__.py b/app/service/admin/login/__init__.py new file mode 100644 index 0000000..2498382 --- /dev/null +++ b/app/service/admin/login/__init__.py @@ -0,0 +1 @@ +"""Admin login services.""" diff --git a/app/service/admin/login/login_service.py b/app/service/admin/login/login_service.py new file mode 100644 index 0000000..b3832d1 --- /dev/null +++ b/app/service/admin/login/login_service.py @@ -0,0 +1,56 @@ +import asyncio + +from fastapi import Request + +from app.common.repository.admin_user_repository import AdminUserRepository +from app.constants.model.admin_user.admin_user_status_code import AdminUserStatusCode +from app.exception.err_exception import ErrException +from app.lib.response.admin_return import AdminReturn +from app.request.admin.login_request import LoginRequest +from app.service.admin.base_admin_service import BaseAdminService +from app.service.base_token_service import BaseTokenService + + +class LoginService(BaseAdminService): + def __init__( + self, + user_repository: AdminUserRepository, + token_service: BaseTokenService, + admin_return: AdminReturn, + ) -> None: + super().__init__(admin_return) + self.user_repository = user_repository + self.token_service = token_service + + async def handle(self, payload: LoginRequest, request: Request) -> dict: + admin_info = await self.user_repository.find_by_username(payload.username) + if admin_info is None: + raise ErrException("后台管理员不存在") + + password_valid = await asyncio.to_thread( + admin_info.verify_password, + payload.password, + ) + if not password_valid: + raise ErrException("密码错误") + + if admin_info.status == AdminUserStatusCode.DISABLE: + raise ErrException("用户已禁用") + + await self.user_repository.record_login(admin_info.id, self._client_ip(request)) + jwt = self.token_service.get_jwt("admin") + return self.admin_return.success( + "success", + { + "access_token": jwt.builder_access_token(str(admin_info.id)), + "refresh_token": jwt.builder_refresh_token(str(admin_info.id)), + "expire_at": int(jwt.get_config("ttl", 0)), + }, + ) + + @staticmethod + def _client_ip(request: Request) -> str: + forwarded = request.headers.get("x-forwarded-for") + if forwarded: + return forwarded.split(",", 1)[0].strip() + return request.client.host if request.client else "" diff --git a/app/service/admin/login/refresh_service.py b/app/service/admin/login/refresh_service.py new file mode 100644 index 0000000..22ba272 --- /dev/null +++ b/app/service/admin/login/refresh_service.py @@ -0,0 +1,22 @@ +from app.lib.jwt.token import JwtToken +from app.lib.response.admin_return import AdminReturn +from app.service.admin.base_admin_service import BaseAdminService +from app.service.base_token_service import BaseTokenService + + +class RefreshService(BaseAdminService): + def __init__(self, token_service: BaseTokenService, admin_return: AdminReturn) -> None: + super().__init__(admin_return) + self.token_service = token_service + + async def handle(self, token: JwtToken) -> dict: + return self.admin_return.success("success", await self.refresh_token(token)) + + async def refresh_token(self, token: JwtToken) -> dict[str, int | str]: + jwt = self.token_service.get_jwt("admin") + await jwt.add_blacklist(token) + return { + "access_token": jwt.builder_access_token(token.jwt_id), + "refresh_token": jwt.builder_refresh_token(token.jwt_id), + "expire_at": int(jwt.get_config("ttl", 0)), + } diff --git a/app/service/admin/profile/__init__.py b/app/service/admin/profile/__init__.py new file mode 100644 index 0000000..c5c1564 --- /dev/null +++ b/app/service/admin/profile/__init__.py @@ -0,0 +1 @@ +"""Admin profile services.""" diff --git a/app/service/admin/profile/current_user_service.py b/app/service/admin/profile/current_user_service.py new file mode 100644 index 0000000..95fef60 --- /dev/null +++ b/app/service/admin/profile/current_user_service.py @@ -0,0 +1,20 @@ +from app.common.repository.admin_user_repository import AdminUserRepository +from app.exception.err_exception import ErrException +from app.lib.response.admin_return import AdminReturn +from app.service.admin.base_admin_service import BaseAdminService + + +class CurrentUserService(BaseAdminService): + def __init__( + self, + user_repository: AdminUserRepository, + admin_return: AdminReturn, + ) -> None: + super().__init__(admin_return) + self.user_repository = user_repository + + async def handle(self) -> dict: + admin_user = await self.user_repository.find_by_id(self.admin_id) + if admin_user is None: + raise ErrException("账户不存在") + return self.admin_return.success("success", admin_user.to_public_dict()) diff --git a/app/service/base_token_service.py b/app/service/base_token_service.py new file mode 100644 index 0000000..d1d5528 --- /dev/null +++ b/app/service/base_token_service.py @@ -0,0 +1,17 @@ +from app.constants.result_code import ResultCode +from app.exception.err_exception import ErrException +from app.lib.jwt.factory import JwtFactory +from app.lib.jwt.jwt import Jwt +from app.lib.jwt.token import JwtToken + + +class BaseTokenService: + def __init__(self, jwt_factory: JwtFactory) -> None: + self.jwt_factory = jwt_factory + + def get_jwt(self, scene: str = "admin") -> Jwt: + return self.jwt_factory.get(scene) + + async def check_jwt(self, jwt: Jwt, token: JwtToken) -> None: + if await jwt.has_blacklist(token): + raise ErrException("token已过期", ResultCode.JWT_EXPIRED) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..ed78f77 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,14 @@ +[project] +name = "py-server" +version = "0.1.0" +description = "Async API server scaffold inspired by the Hyperf app layering." +requires-python = ">=3.14" +dependencies = [ + "fastapi>=0.136.0", + "pydantic-settings>=2.14.0", + "uvicorn[standard]>=0.49.0", +] + +[tool.pyright] +venvPath = "." +venv = ".venv" diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..38bb211 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1 @@ +"""Test package.""" diff --git a/tests/test_admin_login_flow.py b/tests/test_admin_login_flow.py new file mode 100644 index 0000000..e895b4d --- /dev/null +++ b/tests/test_admin_login_flow.py @@ -0,0 +1,88 @@ +import os +import tempfile +import unittest +from pathlib import Path + +DB_PATH = Path(tempfile.gettempdir()) / "py_server_admin_login_test.db" +DB_PATH.unlink(missing_ok=True) +os.environ["DATABASE_PATH"] = str(DB_PATH) +os.environ["JWT_ADMIN_SECRET"] = "test_admin_secret" +os.environ["ADMIN_SEED_USERNAME"] = "admin" +os.environ["ADMIN_SEED_PASSWORD"] = "admin" + +from httpx import ASGITransport, AsyncClient + +from app.core.dependencies import bootstrap_database +from app.main import app + + +class AdminLoginFlowTest(unittest.IsolatedAsyncioTestCase): + async def asyncSetUp(self) -> None: + await bootstrap_database() + self.client = AsyncClient( + transport=ASGITransport(app=app), + base_url="http://testserver", + ) + + async def asyncTearDown(self) -> None: + await self.client.aclose() + + async def test_login_access_and_refresh_flow(self) -> None: + login_response = await self.client.post( + "/admin/login/login", + json={"username": "admin", "password": "admin"}, + ) + login_payload = login_response.json() + + self.assertEqual(login_payload["code"], 0) + self.assertIn("access_token", login_payload["data"]) + self.assertIn("refresh_token", login_payload["data"]) + self.assertEqual(login_payload["data"]["expire_at"], 3600) + + access_token = login_payload["data"]["access_token"] + refresh_token = login_payload["data"]["refresh_token"] + + current_response = await self.client.get( + "/admin/profile/current", + headers={"Authorization": f"Bearer {access_token}"}, + ) + current_payload = current_response.json() + + self.assertEqual(current_payload["code"], 0) + self.assertEqual(current_payload["data"]["username"], "admin") + + refresh_response = await self.client.post( + "/admin/login/refresh", + headers={"Authorization": f"Bearer {refresh_token}"}, + ) + refresh_payload = refresh_response.json() + + self.assertEqual(refresh_payload["code"], 0) + self.assertNotEqual(refresh_payload["data"]["refresh_token"], refresh_token) + + reused_response = await self.client.post( + "/admin/login/refresh", + headers={"Authorization": f"Bearer {refresh_token}"}, + ) + reused_payload = reused_response.json() + + self.assertEqual(reused_payload["code"], 10001) + + async def test_refresh_endpoint_rejects_access_token(self) -> None: + login_response = await self.client.post( + "/admin/login/login", + json={"username": "admin", "password": "admin"}, + ) + access_token = login_response.json()["data"]["access_token"] + + refresh_response = await self.client.post( + "/admin/login/refresh", + headers={"Authorization": f"Bearer {access_token}"}, + ) + refresh_payload = refresh_response.json() + + self.assertEqual(refresh_payload["code"], 10002) + + +if __name__ == "__main__": + unittest.main()