Initial FastAPI admin auth scaffold
This commit is contained in:
11
.env.example
Normal file
11
.env.example
Normal file
@@ -0,0 +1,11 @@
|
||||
APP_NAME=py_server
|
||||
APP_ENV=local
|
||||
DATABASE_PATH=storage/app.db
|
||||
|
||||
JWT_ADMIN_SECRET=change_me_admin_secret
|
||||
ADMIN_JWT_TTL=3600
|
||||
ADMIN_JWT_REFRESH_TTL=7200
|
||||
JWT_BLACKLIST_TTL=7201
|
||||
|
||||
ADMIN_SEED_USERNAME=admin
|
||||
ADMIN_SEED_PASSWORD=admin
|
||||
7
.gitignore
vendored
Normal file
7
.gitignore
vendored
Normal file
@@ -0,0 +1,7 @@
|
||||
.env
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
.pytest_cache/
|
||||
.ruff_cache/
|
||||
.idea/
|
||||
storage/*.db
|
||||
145
README.md
Normal file
145
README.md
Normal file
@@ -0,0 +1,145 @@
|
||||
# py_server
|
||||
|
||||
这是一个 FastAPI 异步 API 项目骨架,分层方式参考了 Hyperf 项目里的 `server/app`。
|
||||
|
||||
## 分层对应
|
||||
|
||||
| Hyperf | FastAPI 当前项目 | 作用 |
|
||||
| --- | --- | --- |
|
||||
| `Controller/Admin` | `app/controller/admin` | 接收请求,调用 service |
|
||||
| `Service/Admin` | `app/service/admin` | 业务逻辑 |
|
||||
| `Common/Repository` | `app/common/repository` | 数据访问 |
|
||||
| `Model` | `app/model` | 数据模型 |
|
||||
| `Middleware/Admin` | `app/middleware/admin` | admin token 和权限校验 |
|
||||
| `Lib/Jwt` | `app/lib/jwt` | JWT 签发、解析、黑名单 |
|
||||
| `Request/Admin` | `app/request/admin` | 请求参数校验 |
|
||||
| `Constants` | `app/constants` | 状态码和业务常量 |
|
||||
| `Lib/Return` | `app/lib/response` | 统一返回结构 |
|
||||
|
||||
> Python 里 `return` 是关键字,所以参考 Hyperf 的 `Lib/Return` 在这里命名为 `lib/response`。
|
||||
|
||||
## 启动
|
||||
|
||||
```bash
|
||||
.venv/bin/uvicorn app.main:app --reload
|
||||
```
|
||||
|
||||
默认启动时会自动创建 SQLite 开发库,并创建一个 admin 用户:
|
||||
|
||||
- username: `admin`
|
||||
- password: `admin`
|
||||
|
||||
本地配置可以复制:
|
||||
|
||||
```bash
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
## Admin 登录接口
|
||||
|
||||
### 登录
|
||||
|
||||
```http
|
||||
POST /admin/login/login
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"username": "admin",
|
||||
"password": "admin"
|
||||
}
|
||||
```
|
||||
|
||||
返回:
|
||||
|
||||
```json
|
||||
{
|
||||
"code": 0,
|
||||
"message": "success",
|
||||
"data": {
|
||||
"access_token": "...",
|
||||
"refresh_token": "...",
|
||||
"expire_at": 3600
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 刷新 token
|
||||
|
||||
```http
|
||||
POST /admin/login/refresh
|
||||
Authorization: Bearer <refresh_token>
|
||||
```
|
||||
|
||||
刷新成功后会签发新的 `access_token` 和 `refresh_token`,并把旧的 `refresh_token` 加入黑名单,防止重复刷新。
|
||||
|
||||
### 当前 admin 用户
|
||||
|
||||
```http
|
||||
GET /admin/profile/current
|
||||
Authorization: Bearer <access_token>
|
||||
```
|
||||
|
||||
## JWT 逻辑
|
||||
|
||||
当前实现和参考 Hyperf 项目保持同样思路:
|
||||
|
||||
1. 登录成功后同时签发 `access_token` 和 `refresh_token`。
|
||||
2. 普通 admin 接口只接受 `access_token`。
|
||||
3. `/admin/login/refresh` 只接受 `refresh_token`。
|
||||
4. refresh 成功后,把旧 refresh token 加入黑名单。
|
||||
5. token 会校验签名、过期时间、issuer 和 token 类型。
|
||||
|
||||
token 读取顺序和 Hyperf 中间件一致:
|
||||
|
||||
1. `Authorization: Bearer <token>`
|
||||
2. `token` header
|
||||
3. query string: `?token=...`
|
||||
|
||||
## FastAPI 依赖注入
|
||||
|
||||
Controller 里这段:
|
||||
|
||||
```python
|
||||
service: LoginService = Depends(get_login_service)
|
||||
```
|
||||
|
||||
可以理解为 FastAPI 版的 Hyperf 容器注入。
|
||||
|
||||
请求进来时,FastAPI 会先调用 `get_login_service()`,把创建好的 `LoginService` 传给 `service` 参数。
|
||||
|
||||
`get_login_service()` 内部会继续组装:
|
||||
|
||||
- `AdminUserRepository`
|
||||
- `BaseTokenService`
|
||||
- `AdminReturn`
|
||||
|
||||
所以 controller 只负责接收请求和调用 service。
|
||||
|
||||
## `@lru_cache` 的作用
|
||||
|
||||
`app/core/dependencies.py` 里的:
|
||||
|
||||
```python
|
||||
@lru_cache
|
||||
def get_database() -> Database:
|
||||
return Database(get_settings().database_path)
|
||||
```
|
||||
|
||||
表示第一次调用时创建对象,后面再次调用时直接复用第一次创建的对象。
|
||||
|
||||
在这个项目里,它的作用类似 Hyperf 容器里的共享服务/单例服务。比如 JWT 黑名单必须复用同一个对象,否则旧 refresh token 加入黑名单后,下一次请求就查不到了。
|
||||
|
||||
## 测试
|
||||
|
||||
```bash
|
||||
.venv/bin/python -m compileall app tests
|
||||
.venv/bin/python -m unittest tests.test_admin_login_flow
|
||||
```
|
||||
|
||||
测试覆盖:
|
||||
|
||||
- admin 登录成功
|
||||
- access token 访问当前用户成功
|
||||
- refresh token 换新 token 成功
|
||||
- 旧 refresh token 复用失败
|
||||
- access token 调用 refresh 接口失败
|
||||
1
app/__init__.py
Normal file
1
app/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Application package."""
|
||||
1
app/common/__init__.py
Normal file
1
app/common/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Common helpers."""
|
||||
3
app/common/context.py
Normal file
3
app/common/context.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from contextvars import ContextVar
|
||||
|
||||
current_admin_id: ContextVar[int] = ContextVar("current_admin_id", default=0)
|
||||
1
app/common/repository/__init__.py
Normal file
1
app/common/repository/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Repository layer."""
|
||||
79
app/common/repository/admin_user_repository.py
Normal file
79
app/common/repository/admin_user_repository.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.common.repository.base_repository import BaseRepository
|
||||
from app.common.security.password_hasher import hash_password
|
||||
from app.constants.model.admin_user.admin_user_status_code import AdminUserStatusCode
|
||||
from app.model.admin_user import AdminUser
|
||||
|
||||
|
||||
class AdminUserRepository(BaseRepository):
|
||||
async def find_by_username(self, username: str) -> AdminUser | None:
|
||||
row = await self.database.fetchone(
|
||||
"""
|
||||
SELECT id, username, password, user_type, nickname, phone, email,
|
||||
status, login_ip, login_time, remark
|
||||
FROM admin_user
|
||||
WHERE username = ?
|
||||
LIMIT 1
|
||||
""",
|
||||
(username,),
|
||||
)
|
||||
return AdminUser.from_row(row) if row else None
|
||||
|
||||
async def find_by_id(self, admin_id: int) -> AdminUser | None:
|
||||
row = await self.database.fetchone(
|
||||
"""
|
||||
SELECT id, username, password, user_type, nickname, phone, email,
|
||||
status, login_ip, login_time, remark
|
||||
FROM admin_user
|
||||
WHERE id = ?
|
||||
LIMIT 1
|
||||
""",
|
||||
(admin_id,),
|
||||
)
|
||||
return AdminUser.from_row(row) if row else None
|
||||
|
||||
async def record_login(self, admin_id: int, login_ip: str) -> None:
|
||||
now = self._now()
|
||||
await self.database.execute(
|
||||
"""
|
||||
UPDATE admin_user
|
||||
SET login_ip = ?, login_time = ?, updated_at = ?
|
||||
WHERE id = ?
|
||||
""",
|
||||
(login_ip, now, now, admin_id),
|
||||
)
|
||||
|
||||
async def ensure_seed_admin(self, username: str, password: str) -> None:
|
||||
exists = await self.find_by_username(username)
|
||||
if exists is not None:
|
||||
return
|
||||
|
||||
now = self._now()
|
||||
await self.database.execute(
|
||||
"""
|
||||
INSERT INTO admin_user (
|
||||
username, password, user_type, nickname, phone, email, status,
|
||||
login_ip, login_time, created_at, updated_at, remark
|
||||
)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
username,
|
||||
hash_password(password),
|
||||
"SuperAdmin",
|
||||
"Super Admin",
|
||||
"",
|
||||
"",
|
||||
int(AdminUserStatusCode.NORMAL),
|
||||
"",
|
||||
"",
|
||||
now,
|
||||
now,
|
||||
"seeded by application startup",
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _now() -> str:
|
||||
return datetime.now(UTC).strftime("%Y-%m-%d %H:%M:%S")
|
||||
6
app/common/repository/base_repository.py
Normal file
6
app/common/repository/base_repository.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from app.core.database import Database
|
||||
|
||||
|
||||
class BaseRepository:
|
||||
def __init__(self, database: Database) -> None:
|
||||
self.database = database
|
||||
1
app/common/security/__init__.py
Normal file
1
app/common/security/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Security helpers."""
|
||||
40
app/common/security/password_hasher.py
Normal file
40
app/common/security/password_hasher.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import secrets
|
||||
|
||||
ALGORITHM = "pbkdf2_sha256"
|
||||
ITERATIONS = 390_000
|
||||
|
||||
|
||||
def _b64encode(raw: bytes) -> str:
|
||||
return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=")
|
||||
|
||||
|
||||
def _b64decode(value: str) -> bytes:
|
||||
padding = "=" * (-len(value) % 4)
|
||||
return base64.urlsafe_b64decode(value + padding)
|
||||
|
||||
|
||||
def hash_password(password: str, iterations: int = ITERATIONS) -> str:
|
||||
salt = secrets.token_bytes(16)
|
||||
digest = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, iterations)
|
||||
return f"{ALGORITHM}${iterations}${_b64encode(salt)}${_b64encode(digest)}"
|
||||
|
||||
|
||||
def verify_password(password: str, stored_hash: str) -> bool:
|
||||
try:
|
||||
algorithm, iterations, salt, expected = stored_hash.split("$", 3)
|
||||
except ValueError:
|
||||
return hmac.compare_digest(password, stored_hash)
|
||||
|
||||
if algorithm != ALGORITHM:
|
||||
return False
|
||||
|
||||
digest = hashlib.pbkdf2_hmac(
|
||||
"sha256",
|
||||
password.encode("utf-8"),
|
||||
_b64decode(salt),
|
||||
int(iterations),
|
||||
)
|
||||
return hmac.compare_digest(_b64encode(digest), expected)
|
||||
1
app/constants/__init__.py
Normal file
1
app/constants/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Application constants."""
|
||||
6
app/constants/admin_code.py
Normal file
6
app/constants/admin_code.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from enum import IntEnum
|
||||
|
||||
|
||||
class AdminCode(IntEnum):
|
||||
DISABLED = 30001
|
||||
FORBIDDEN = 30002
|
||||
1
app/constants/model/__init__.py
Normal file
1
app/constants/model/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Model constants."""
|
||||
1
app/constants/model/admin_user/__init__.py
Normal file
1
app/constants/model/admin_user/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Admin user constants."""
|
||||
12
app/constants/model/admin_user/admin_user_status_code.py
Normal file
12
app/constants/model/admin_user/admin_user_status_code.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from enum import IntEnum
|
||||
|
||||
|
||||
class AdminUserStatusCode(IntEnum):
|
||||
NORMAL = 1
|
||||
DISABLE = 2
|
||||
|
||||
def is_normal(self) -> bool:
|
||||
return self is AdminUserStatusCode.NORMAL
|
||||
|
||||
def is_disable(self) -> bool:
|
||||
return self is AdminUserStatusCode.DISABLE
|
||||
12
app/constants/result_code.py
Normal file
12
app/constants/result_code.py
Normal file
@@ -0,0 +1,12 @@
|
||||
from enum import IntEnum
|
||||
|
||||
|
||||
class ResultCode(IntEnum):
|
||||
SUCCESS = 0
|
||||
ERROR = 1
|
||||
JWT_EXPIRED = 10001
|
||||
JWT_ERROR = 10002
|
||||
OLD_PASSWORD_ERROR = 10003
|
||||
ACCOUNT_DEACTIVATING = 20001
|
||||
ACCOUNT_DEACTIVATED = 20002
|
||||
ACCOUNT_CANNOT_DEACTIVATE = 20003
|
||||
1
app/controller/__init__.py
Normal file
1
app/controller/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Controller layer."""
|
||||
1
app/controller/admin/__init__.py
Normal file
1
app/controller/admin/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Admin controllers."""
|
||||
32
app/controller/admin/login_controller.py
Normal file
32
app/controller/admin/login_controller.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from app.core.dependencies import get_login_service, get_refresh_service
|
||||
from app.lib.jwt.token import JwtToken
|
||||
from app.middleware.admin.refresh_admin_token_middleware import RefreshAdminTokenMiddleware
|
||||
from app.request.admin.login_request import LoginRequest
|
||||
from app.service.admin.login.login_service import LoginService
|
||||
from app.service.admin.login.refresh_service import RefreshService
|
||||
|
||||
router = APIRouter(prefix="/admin/login", tags=["admin-login"])
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
async def login(
|
||||
payload: LoginRequest,
|
||||
request: Request,
|
||||
# FastAPI 的 Depends 类似 Hyperf 里的容器注入。
|
||||
# 请求进入这个接口时,框架会先调用 get_login_service(),
|
||||
# 把组装好的 LoginService 传进 service 参数。
|
||||
service: LoginService = Depends(get_login_service),
|
||||
) -> dict:
|
||||
return await service.handle(payload, request)
|
||||
|
||||
|
||||
@router.post("/refresh")
|
||||
async def refresh(
|
||||
# 这里把 RefreshAdminTokenMiddleware 当成依赖使用。
|
||||
# FastAPI 会先执行 refresh token 校验,通过后才进入 controller。
|
||||
token: JwtToken = Depends(RefreshAdminTokenMiddleware()),
|
||||
service: RefreshService = Depends(get_refresh_service),
|
||||
) -> dict:
|
||||
return await service.handle(token)
|
||||
21
app/controller/admin/profile_controller.py
Normal file
21
app/controller/admin/profile_controller.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from app.core.dependencies import get_current_user_service
|
||||
from app.middleware.admin.admin_token_middleware import AdminTokenMiddleware
|
||||
from app.middleware.admin.permission_middleware import PermissionMiddleware
|
||||
from app.service.admin.profile.current_user_service import CurrentUserService
|
||||
|
||||
router = APIRouter(prefix="/admin/profile", tags=["admin-profile"])
|
||||
|
||||
|
||||
@router.get(
|
||||
"/current",
|
||||
dependencies=[
|
||||
Depends(AdminTokenMiddleware()),
|
||||
Depends(PermissionMiddleware()),
|
||||
],
|
||||
)
|
||||
async def current(
|
||||
service: CurrentUserService = Depends(get_current_user_service),
|
||||
) -> dict:
|
||||
return await service.handle()
|
||||
1
app/controller/api/__init__.py
Normal file
1
app/controller/api/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Frontend API controllers."""
|
||||
10
app/controller/api/health_controller.py
Normal file
10
app/controller/api/health_controller.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
from app.lib.response.admin_return import AdminReturn
|
||||
|
||||
router = APIRouter(prefix="/api", tags=["api"])
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health() -> dict:
|
||||
return AdminReturn().success("success", {"status": "ok"})
|
||||
1
app/core/__init__.py
Normal file
1
app/core/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Core application wiring."""
|
||||
39
app/core/config.py
Normal file
39
app/core/config.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
extra="ignore",
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
app_name: str = Field(default="py_server", alias="APP_NAME")
|
||||
app_env: str = Field(default="local", alias="APP_ENV")
|
||||
database_path: Path = Field(default=Path("storage/app.db"), alias="DATABASE_PATH")
|
||||
|
||||
admin_jwt_secret: str = Field(
|
||||
default="dev_admin_secret_change_me",
|
||||
alias="JWT_ADMIN_SECRET",
|
||||
)
|
||||
admin_jwt_ttl: int = Field(default=3600, alias="ADMIN_JWT_TTL")
|
||||
admin_jwt_refresh_ttl: int = Field(default=7200, alias="ADMIN_JWT_REFRESH_TTL")
|
||||
jwt_blacklist_ttl: int = Field(default=7201, alias="JWT_BLACKLIST_TTL")
|
||||
|
||||
api_jwt_secret: str = Field(default="dev_api_secret_change_me", alias="JWT_SECRET")
|
||||
api_jwt_ttl: int = Field(default=3600, alias="JWT_TTL")
|
||||
api_jwt_refresh_ttl: int = Field(default=7200, alias="JWT_REFRESH_TTL")
|
||||
|
||||
admin_seed_username: str = Field(default="admin", alias="ADMIN_SEED_USERNAME")
|
||||
admin_seed_password: str = Field(default="admin", alias="ADMIN_SEED_PASSWORD")
|
||||
cors_allow_origins: list[str] = Field(default=["*"], alias="CORS_ALLOW_ORIGINS")
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_settings() -> Settings:
|
||||
return Settings()
|
||||
87
app/core/database.py
Normal file
87
app/core/database.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import asyncio
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
class Database:
|
||||
def __init__(self, path: Path) -> None:
|
||||
self.path = path
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await asyncio.to_thread(self._initialize)
|
||||
|
||||
async def execute(self, sql: str, params: tuple[Any, ...] = ()) -> int:
|
||||
return await asyncio.to_thread(self._execute, sql, params)
|
||||
|
||||
async def fetchone(
|
||||
self,
|
||||
sql: str,
|
||||
params: tuple[Any, ...] = (),
|
||||
) -> dict[str, Any] | None:
|
||||
return await asyncio.to_thread(self._fetchone, sql, params)
|
||||
|
||||
async def fetchall(
|
||||
self,
|
||||
sql: str,
|
||||
params: tuple[Any, ...] = (),
|
||||
) -> list[dict[str, Any]]:
|
||||
return await asyncio.to_thread(self._fetchall, sql, params)
|
||||
|
||||
def _connect(self) -> sqlite3.Connection:
|
||||
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||
conn = sqlite3.connect(self.path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
conn.execute("PRAGMA foreign_keys = ON")
|
||||
return conn
|
||||
|
||||
def _initialize(self) -> None:
|
||||
conn = self._connect()
|
||||
try:
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS admin_user (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username TEXT NOT NULL UNIQUE,
|
||||
password TEXT NOT NULL,
|
||||
user_type TEXT NOT NULL DEFAULT 'admin',
|
||||
nickname TEXT NOT NULL DEFAULT '',
|
||||
phone TEXT NOT NULL DEFAULT '',
|
||||
email TEXT NOT NULL DEFAULT '',
|
||||
status INTEGER NOT NULL DEFAULT 1,
|
||||
login_ip TEXT NOT NULL DEFAULT '',
|
||||
login_time TEXT NOT NULL DEFAULT '',
|
||||
created_at TEXT NOT NULL,
|
||||
updated_at TEXT NOT NULL,
|
||||
remark TEXT NOT NULL DEFAULT ''
|
||||
)
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _execute(self, sql: str, params: tuple[Any, ...]) -> int:
|
||||
conn = self._connect()
|
||||
try:
|
||||
cursor = conn.execute(sql, params)
|
||||
conn.commit()
|
||||
return int(cursor.lastrowid or 0)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _fetchone(self, sql: str, params: tuple[Any, ...]) -> dict[str, Any] | None:
|
||||
conn = self._connect()
|
||||
try:
|
||||
row = conn.execute(sql, params).fetchone()
|
||||
return dict(row) if row else None
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
def _fetchall(self, sql: str, params: tuple[Any, ...]) -> list[dict[str, Any]]:
|
||||
conn = self._connect()
|
||||
try:
|
||||
rows = conn.execute(sql, params).fetchall()
|
||||
return [dict(row) for row in rows]
|
||||
finally:
|
||||
conn.close()
|
||||
70
app/core/dependencies.py
Normal file
70
app/core/dependencies.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from functools import lru_cache
|
||||
|
||||
from app.common.repository.admin_user_repository import AdminUserRepository
|
||||
from app.core.config import Settings, get_settings
|
||||
from app.core.database import Database
|
||||
from app.lib.jwt.blacklist import InMemoryTokenBlacklist
|
||||
from app.lib.jwt.factory import JwtFactory
|
||||
from app.lib.response.admin_return import AdminReturn
|
||||
from app.service.admin.login.login_service import LoginService
|
||||
from app.service.admin.login.refresh_service import RefreshService
|
||||
from app.service.admin.profile.current_user_service import CurrentUserService
|
||||
from app.service.base_token_service import BaseTokenService
|
||||
|
||||
|
||||
# lru_cache 会缓存函数第一次创建出来的对象。
|
||||
# 这里用它把 Database、JwtFactory、TokenService 等依赖做成应用级单例,
|
||||
# 类似 Hyperf 从容器里反复 get 同一个共享服务。
|
||||
@lru_cache
|
||||
def get_database() -> Database:
|
||||
return Database(get_settings().database_path)
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_token_blacklist() -> InMemoryTokenBlacklist:
|
||||
return InMemoryTokenBlacklist()
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_jwt_factory() -> JwtFactory:
|
||||
return JwtFactory(get_settings(), get_token_blacklist())
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_token_service() -> BaseTokenService:
|
||||
return BaseTokenService(get_jwt_factory())
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_admin_return() -> AdminReturn:
|
||||
return AdminReturn()
|
||||
|
||||
|
||||
@lru_cache
|
||||
def get_admin_user_repository() -> AdminUserRepository:
|
||||
return AdminUserRepository(get_database())
|
||||
|
||||
|
||||
def get_login_service() -> LoginService:
|
||||
return LoginService(
|
||||
get_admin_user_repository(),
|
||||
get_token_service(),
|
||||
get_admin_return(),
|
||||
)
|
||||
|
||||
|
||||
def get_refresh_service() -> RefreshService:
|
||||
return RefreshService(get_token_service(), get_admin_return())
|
||||
|
||||
|
||||
def get_current_user_service() -> CurrentUserService:
|
||||
return CurrentUserService(get_admin_user_repository(), get_admin_return())
|
||||
|
||||
|
||||
async def bootstrap_database(settings: Settings | None = None) -> None:
|
||||
settings = settings or get_settings()
|
||||
await get_database().initialize()
|
||||
await get_admin_user_repository().ensure_seed_admin(
|
||||
settings.admin_seed_username,
|
||||
settings.admin_seed_password,
|
||||
)
|
||||
1
app/exception/__init__.py
Normal file
1
app/exception/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Exception layer."""
|
||||
14
app/exception/err_exception.py
Normal file
14
app/exception/err_exception.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from app.constants.result_code import ResultCode
|
||||
|
||||
|
||||
class ErrException(Exception):
|
||||
def __init__(
|
||||
self,
|
||||
message: str = "failed",
|
||||
code: int | ResultCode = ResultCode.ERROR,
|
||||
data: dict | None = None,
|
||||
) -> None:
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
self.code = int(code)
|
||||
self.data = data or {}
|
||||
34
app/exception/handler.py
Normal file
34
app/exception/handler.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from app.constants.result_code import ResultCode
|
||||
from app.exception.err_exception import ErrException
|
||||
from app.lib.response.admin_return import AdminReturn
|
||||
|
||||
|
||||
async def err_exception_handler(request: Request, exc: ErrException) -> JSONResponse:
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content=AdminReturn().error(exc.message, exc.code, exc.data),
|
||||
)
|
||||
|
||||
|
||||
async def validation_exception_handler(
|
||||
request: Request,
|
||||
exc: RequestValidationError,
|
||||
) -> JSONResponse:
|
||||
first_error = exc.errors()[0] if exc.errors() else {}
|
||||
location = ".".join(str(item) for item in first_error.get("loc", []))
|
||||
message = first_error.get("msg", "参数错误")
|
||||
if location:
|
||||
message = f"{location}: {message}"
|
||||
return JSONResponse(
|
||||
status_code=200,
|
||||
content=AdminReturn().error(message, ResultCode.ERROR),
|
||||
)
|
||||
|
||||
|
||||
def register_exception_handlers(app: FastAPI) -> None:
|
||||
app.add_exception_handler(ErrException, err_exception_handler)
|
||||
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||
1
app/lib/__init__.py
Normal file
1
app/lib/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Library layer."""
|
||||
1
app/lib/jwt/__init__.py
Normal file
1
app/lib/jwt/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""JWT helpers."""
|
||||
25
app/lib/jwt/blacklist.py
Normal file
25
app/lib/jwt/blacklist.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
|
||||
class InMemoryTokenBlacklist:
|
||||
def __init__(self) -> None:
|
||||
self._items: dict[str, float] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def add(self, token: str, ttl: int) -> None:
|
||||
now = time.time()
|
||||
async with self._lock:
|
||||
self._cleanup(now)
|
||||
self._items[token] = now + ttl
|
||||
|
||||
async def has(self, token: str) -> bool:
|
||||
now = time.time()
|
||||
async with self._lock:
|
||||
self._cleanup(now)
|
||||
return token in self._items
|
||||
|
||||
def _cleanup(self, now: float) -> None:
|
||||
expired = [token for token, expires_at in self._items.items() if expires_at <= now]
|
||||
for token in expired:
|
||||
self._items.pop(token, None)
|
||||
10
app/lib/jwt/exceptions.py
Normal file
10
app/lib/jwt/exceptions.py
Normal file
@@ -0,0 +1,10 @@
|
||||
class JwtError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class JwtExpiredError(JwtError):
|
||||
pass
|
||||
|
||||
|
||||
class JwtInvalidError(JwtError):
|
||||
pass
|
||||
34
app/lib/jwt/factory.py
Normal file
34
app/lib/jwt/factory.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from app.core.config import Settings
|
||||
from app.lib.jwt.blacklist import InMemoryTokenBlacklist
|
||||
from app.lib.jwt.jwt import Jwt, JwtSceneConfig
|
||||
|
||||
|
||||
class JwtFactory:
|
||||
def __init__(self, settings: Settings, blacklist: InMemoryTokenBlacklist) -> None:
|
||||
self.settings = settings
|
||||
self.blacklist = blacklist
|
||||
|
||||
def get(self, scene: str = "default") -> Jwt:
|
||||
if scene == "admin":
|
||||
return Jwt(
|
||||
JwtSceneConfig(
|
||||
secret=self.settings.admin_jwt_secret,
|
||||
issuer=f"{self.settings.app_name}_admin",
|
||||
ttl=self.settings.admin_jwt_ttl,
|
||||
refresh_ttl=self.settings.admin_jwt_refresh_ttl,
|
||||
blacklist_ttl=self.settings.jwt_blacklist_ttl,
|
||||
),
|
||||
self.blacklist,
|
||||
)
|
||||
|
||||
scene_name = "default" if scene == "default" else scene
|
||||
return Jwt(
|
||||
JwtSceneConfig(
|
||||
secret=self.settings.api_jwt_secret,
|
||||
issuer=f"{self.settings.app_name}_{scene_name}",
|
||||
ttl=self.settings.api_jwt_ttl,
|
||||
refresh_ttl=self.settings.api_jwt_refresh_ttl,
|
||||
blacklist_ttl=self.settings.jwt_blacklist_ttl,
|
||||
),
|
||||
self.blacklist,
|
||||
)
|
||||
151
app/lib/jwt/jwt.py
Normal file
151
app/lib/jwt/jwt.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import secrets
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from app.lib.jwt.blacklist import InMemoryTokenBlacklist
|
||||
from app.lib.jwt.exceptions import JwtExpiredError, JwtInvalidError
|
||||
from app.lib.jwt.token import JwtToken
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class JwtSceneConfig:
|
||||
secret: str
|
||||
issuer: str
|
||||
ttl: int
|
||||
refresh_ttl: int
|
||||
blacklist_ttl: int
|
||||
|
||||
|
||||
class Jwt:
|
||||
def __init__(self, config: JwtSceneConfig, blacklist: InMemoryTokenBlacklist) -> None:
|
||||
self.config = config
|
||||
self.blacklist = blacklist
|
||||
|
||||
@property
|
||||
def issuer(self) -> str:
|
||||
return self.config.issuer
|
||||
|
||||
def builder_access_token(self, admin_id: str) -> str:
|
||||
return self._build_token(admin_id, token_type="access", ttl=self.config.ttl)
|
||||
|
||||
def builder_refresh_token(self, admin_id: str) -> str:
|
||||
return self._build_token(admin_id, token_type="refresh", ttl=self.config.refresh_ttl)
|
||||
|
||||
async def parser_access_token(self, raw_token: str) -> JwtToken:
|
||||
token = self._parse(raw_token)
|
||||
if token.is_refresh:
|
||||
raise JwtInvalidError("Token is a refresh token")
|
||||
self._validate_issuer(token)
|
||||
return token
|
||||
|
||||
async def parser_refresh_token(self, raw_token: str) -> JwtToken:
|
||||
token = self._parse(raw_token)
|
||||
if not token.is_refresh:
|
||||
raise JwtInvalidError("Token is not a refresh token")
|
||||
self._validate_issuer(token)
|
||||
return token
|
||||
|
||||
async def add_blacklist(self, token: JwtToken) -> None:
|
||||
await self.blacklist.add(token.raw, self.config.blacklist_ttl)
|
||||
|
||||
async def has_blacklist(self, token: JwtToken) -> bool:
|
||||
return await self.blacklist.has(token.raw)
|
||||
|
||||
def get_config(self, key: str, default: Any = None) -> Any:
|
||||
return getattr(self.config, key, default)
|
||||
|
||||
def _build_token(self, admin_id: str, token_type: str, ttl: int) -> str:
|
||||
now = int(time.time())
|
||||
payload: dict[str, Any] = {
|
||||
"iss": self.config.issuer,
|
||||
"jti": str(admin_id),
|
||||
"iat": now,
|
||||
"nbf": now,
|
||||
"exp": now + ttl,
|
||||
"token_type": token_type,
|
||||
"nonce": secrets.token_urlsafe(16),
|
||||
}
|
||||
if token_type == "refresh":
|
||||
payload["sub"] = "refresh"
|
||||
|
||||
header = {"typ": "JWT", "alg": "HS256"}
|
||||
signing_input = ".".join(
|
||||
[
|
||||
self._b64url_json(header),
|
||||
self._b64url_json(payload),
|
||||
]
|
||||
)
|
||||
signature = hmac.new(
|
||||
self.config.secret.encode("utf-8"),
|
||||
signing_input.encode("ascii"),
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
return f"{signing_input}.{self._b64url_encode(signature)}"
|
||||
|
||||
def _parse(self, raw_token: str) -> JwtToken:
|
||||
try:
|
||||
header_b64, payload_b64, signature_b64 = raw_token.split(".", 2)
|
||||
except ValueError as exc:
|
||||
raise JwtInvalidError("Malformed token") from exc
|
||||
|
||||
signing_input = f"{header_b64}.{payload_b64}".encode("ascii")
|
||||
expected = hmac.new(
|
||||
self.config.secret.encode("utf-8"),
|
||||
signing_input,
|
||||
hashlib.sha256,
|
||||
).digest()
|
||||
if not hmac.compare_digest(self._b64url_encode(expected), signature_b64):
|
||||
raise JwtInvalidError("Invalid signature")
|
||||
|
||||
try:
|
||||
header = json.loads(self._b64url_decode(header_b64))
|
||||
payload = json.loads(self._b64url_decode(payload_b64))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as exc:
|
||||
raise JwtInvalidError("Invalid token payload") from exc
|
||||
|
||||
if header.get("alg") != "HS256":
|
||||
raise JwtInvalidError("Unsupported algorithm")
|
||||
if not isinstance(payload, dict):
|
||||
raise JwtInvalidError("Invalid claims")
|
||||
|
||||
token = JwtToken(raw=raw_token, claims=payload)
|
||||
self._validate_time(token)
|
||||
return token
|
||||
|
||||
def _validate_time(self, token: JwtToken) -> None:
|
||||
now = int(time.time())
|
||||
exp = token.claims.get("exp")
|
||||
nbf = token.claims.get("nbf")
|
||||
iat = token.claims.get("iat")
|
||||
|
||||
if not isinstance(exp, int):
|
||||
raise JwtInvalidError("Missing exp")
|
||||
if exp <= now:
|
||||
raise JwtExpiredError("Token expired")
|
||||
if isinstance(nbf, int) and nbf > now:
|
||||
raise JwtInvalidError("Token not active")
|
||||
if isinstance(iat, int) and iat > now + 60:
|
||||
raise JwtInvalidError("Token issued in future")
|
||||
|
||||
def _validate_issuer(self, token: JwtToken) -> None:
|
||||
if token.claims.get("iss") != self.config.issuer:
|
||||
raise JwtInvalidError("Invalid issuer")
|
||||
|
||||
@staticmethod
|
||||
def _b64url_encode(raw: bytes) -> str:
|
||||
return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=")
|
||||
|
||||
@classmethod
|
||||
def _b64url_json(cls, value: dict[str, Any]) -> str:
|
||||
raw = json.dumps(value, separators=(",", ":"), sort_keys=True).encode("utf-8")
|
||||
return cls._b64url_encode(raw)
|
||||
|
||||
@staticmethod
|
||||
def _b64url_decode(value: str) -> str:
|
||||
padding = "=" * (-len(value) % 4)
|
||||
return base64.urlsafe_b64decode(value + padding).decode("utf-8")
|
||||
23
app/lib/jwt/token.py
Normal file
23
app/lib/jwt/token.py
Normal file
@@ -0,0 +1,23 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
|
||||
@dataclass(frozen=True, slots=True)
|
||||
class JwtToken:
|
||||
raw: str
|
||||
claims: dict[str, Any]
|
||||
|
||||
@property
|
||||
def jwt_id(self) -> str:
|
||||
return str(self.claims.get("jti", ""))
|
||||
|
||||
@property
|
||||
def admin_id(self) -> int:
|
||||
try:
|
||||
return int(self.jwt_id)
|
||||
except ValueError:
|
||||
return 0
|
||||
|
||||
@property
|
||||
def is_refresh(self) -> bool:
|
||||
return self.claims.get("sub") == "refresh" or self.claims.get("token_type") == "refresh"
|
||||
1
app/lib/response/__init__.py
Normal file
1
app/lib/response/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Response helpers."""
|
||||
6
app/lib/response/admin_return.py
Normal file
6
app/lib/response/admin_return.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from app.lib.response.common_return import CommonReturn
|
||||
|
||||
|
||||
class AdminReturn(CommonReturn):
|
||||
def after_success(self, response: dict) -> dict:
|
||||
return response
|
||||
34
app/lib/response/common_return.py
Normal file
34
app/lib/response/common_return.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from app.constants.result_code import ResultCode
|
||||
|
||||
|
||||
class CommonReturn:
|
||||
def success(
|
||||
self,
|
||||
message: str = "success",
|
||||
data: dict | list | None = None,
|
||||
code: int | ResultCode = ResultCode.SUCCESS,
|
||||
) -> dict:
|
||||
return self.after_success(
|
||||
{
|
||||
"code": int(code),
|
||||
"message": message,
|
||||
"data": data if data is not None else {},
|
||||
}
|
||||
)
|
||||
|
||||
def error(
|
||||
self,
|
||||
message: str = "failed",
|
||||
code: int | ResultCode = ResultCode.ERROR,
|
||||
data: dict | None = None,
|
||||
) -> dict:
|
||||
return self.after_success(
|
||||
{
|
||||
"code": int(code),
|
||||
"message": message,
|
||||
"data": data or {},
|
||||
}
|
||||
)
|
||||
|
||||
def after_success(self, response: dict) -> dict:
|
||||
raise NotImplementedError
|
||||
44
app/main.py
Normal file
44
app/main.py
Normal file
@@ -0,0 +1,44 @@
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncIterator
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.controller.admin.login_controller import router as admin_login_router
|
||||
from app.controller.admin.profile_controller import router as admin_profile_router
|
||||
from app.controller.api.health_controller import router as api_health_router
|
||||
from app.core.config import get_settings
|
||||
from app.core.dependencies import bootstrap_database
|
||||
from app.exception.handler import register_exception_handlers
|
||||
from app.lib.response.admin_return import AdminReturn
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
|
||||
await bootstrap_database()
|
||||
yield
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
settings = get_settings()
|
||||
application = FastAPI(title=settings.app_name, lifespan=lifespan)
|
||||
application.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.cors_allow_origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
register_exception_handlers(application)
|
||||
application.include_router(admin_login_router)
|
||||
application.include_router(admin_profile_router)
|
||||
application.include_router(api_health_router)
|
||||
|
||||
@application.get("/")
|
||||
async def index() -> dict:
|
||||
return AdminReturn().success("success", {"name": settings.app_name})
|
||||
|
||||
return application
|
||||
|
||||
|
||||
app = create_app()
|
||||
1
app/middleware/__init__.py
Normal file
1
app/middleware/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Middleware and route dependencies."""
|
||||
1
app/middleware/admin/__init__.py
Normal file
1
app/middleware/admin/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Admin middleware."""
|
||||
16
app/middleware/admin/admin_token_middleware.py
Normal file
16
app/middleware/admin/admin_token_middleware.py
Normal file
@@ -0,0 +1,16 @@
|
||||
from fastapi import Request
|
||||
|
||||
from app.common.context import current_admin_id
|
||||
from app.lib.jwt.jwt import Jwt
|
||||
from app.lib.jwt.token import JwtToken
|
||||
from app.middleware.token.abstract_token_middleware import AbstractTokenMiddleware
|
||||
|
||||
|
||||
class AdminTokenMiddleware(AbstractTokenMiddleware):
|
||||
async def parser_token(self, jwt: Jwt, raw_token: str) -> JwtToken:
|
||||
return await jwt.parser_access_token(raw_token)
|
||||
|
||||
def set_context(self, request: Request, token: JwtToken) -> None:
|
||||
admin_id = token.admin_id
|
||||
current_admin_id.set(admin_id)
|
||||
request.state.current_admin_id = admin_id
|
||||
34
app/middleware/admin/permission_middleware.py
Normal file
34
app/middleware/admin/permission_middleware.py
Normal file
@@ -0,0 +1,34 @@
|
||||
from collections.abc import Iterable
|
||||
|
||||
from fastapi import Depends, Request
|
||||
|
||||
from app.common.context import current_admin_id
|
||||
from app.common.repository.admin_user_repository import AdminUserRepository
|
||||
from app.constants.admin_code import AdminCode
|
||||
from app.constants.model.admin_user.admin_user_status_code import AdminUserStatusCode
|
||||
from app.core.dependencies import get_admin_user_repository
|
||||
from app.exception.err_exception import ErrException
|
||||
|
||||
|
||||
class PermissionMiddleware:
|
||||
def __init__(self, permissions: Iterable[str] | None = None) -> None:
|
||||
self.permissions = tuple(permissions or ())
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
request: Request,
|
||||
user_repository: AdminUserRepository = Depends(get_admin_user_repository),
|
||||
) -> None:
|
||||
admin_id = getattr(request.state, "current_admin_id", current_admin_id.get())
|
||||
if admin_id <= 0:
|
||||
raise ErrException("账户不存在")
|
||||
|
||||
admin_user = await user_repository.find_by_id(admin_id)
|
||||
if admin_user is None:
|
||||
raise ErrException("账户不存在")
|
||||
if admin_user.status == AdminUserStatusCode.DISABLE:
|
||||
raise ErrException("账号已禁用", AdminCode.DISABLED)
|
||||
|
||||
request.state.current_admin_user = admin_user
|
||||
if self.permissions and admin_user.user_type != "SuperAdmin":
|
||||
raise ErrException("暂无权限", AdminCode.FORBIDDEN)
|
||||
13
app/middleware/admin/refresh_admin_token_middleware.py
Normal file
13
app/middleware/admin/refresh_admin_token_middleware.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from fastapi import Request
|
||||
|
||||
from app.lib.jwt.jwt import Jwt
|
||||
from app.lib.jwt.token import JwtToken
|
||||
from app.middleware.token.abstract_token_middleware import AbstractTokenMiddleware
|
||||
|
||||
|
||||
class RefreshAdminTokenMiddleware(AbstractTokenMiddleware):
|
||||
async def parser_token(self, jwt: Jwt, raw_token: str) -> JwtToken:
|
||||
return await jwt.parser_refresh_token(raw_token)
|
||||
|
||||
def set_context(self, request: Request, token: JwtToken) -> None:
|
||||
return None
|
||||
1
app/middleware/token/__init__.py
Normal file
1
app/middleware/token/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Token middleware."""
|
||||
72
app/middleware/token/abstract_token_middleware.py
Normal file
72
app/middleware/token/abstract_token_middleware.py
Normal file
@@ -0,0 +1,72 @@
|
||||
from fastapi import Depends, Request
|
||||
|
||||
from app.constants.result_code import ResultCode
|
||||
from app.core.dependencies import get_token_service
|
||||
from app.exception.err_exception import ErrException
|
||||
from app.lib.jwt.exceptions import JwtError, JwtExpiredError
|
||||
from app.lib.jwt.jwt import Jwt
|
||||
from app.lib.jwt.token import JwtToken
|
||||
from app.service.base_token_service import BaseTokenService
|
||||
|
||||
|
||||
class AbstractTokenMiddleware:
|
||||
scene = "admin"
|
||||
|
||||
async def __call__(
|
||||
self,
|
||||
request: Request,
|
||||
token_service: BaseTokenService = Depends(get_token_service),
|
||||
) -> JwtToken:
|
||||
raw_token = self.get_token(request)
|
||||
jwt = token_service.get_jwt(self.scene)
|
||||
|
||||
try:
|
||||
token = await self.parser_token(jwt, raw_token)
|
||||
await token_service.check_jwt(jwt, token)
|
||||
self.check_issuer(jwt, token)
|
||||
except JwtExpiredError as exc:
|
||||
raise ErrException(
|
||||
"token过期",
|
||||
ResultCode.JWT_EXPIRED,
|
||||
{"err_msg": str(exc)},
|
||||
) from exc
|
||||
except JwtError as exc:
|
||||
raise ErrException(
|
||||
"token错误",
|
||||
ResultCode.JWT_ERROR,
|
||||
{"err_msg": str(exc)},
|
||||
) from exc
|
||||
|
||||
self.set_context(request, token)
|
||||
request.state.token = token
|
||||
return token
|
||||
|
||||
async def parser_token(self, jwt: Jwt, raw_token: str) -> JwtToken:
|
||||
raise NotImplementedError
|
||||
|
||||
def set_context(self, request: Request, token: JwtToken) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def check_issuer(jwt: Jwt, token: JwtToken) -> None:
|
||||
if token.claims.get("iss") != jwt.issuer:
|
||||
raise ErrException("token错误", ResultCode.JWT_ERROR)
|
||||
|
||||
@staticmethod
|
||||
def get_token(request: Request) -> str:
|
||||
authorization = request.headers.get("authorization")
|
||||
if authorization:
|
||||
scheme, _, token = authorization.partition(" ")
|
||||
if scheme.lower() == "bearer" and token:
|
||||
return token.strip()
|
||||
return authorization.replace("Bearer ", "", 1).strip()
|
||||
|
||||
token_header = request.headers.get("token")
|
||||
if token_header:
|
||||
return token_header.strip()
|
||||
|
||||
token_query = request.query_params.get("token")
|
||||
if token_query:
|
||||
return token_query.strip()
|
||||
|
||||
raise ErrException("token缺失", ResultCode.JWT_ERROR)
|
||||
1
app/model/__init__.py
Normal file
1
app/model/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Model layer."""
|
||||
53
app/model/admin_user.py
Normal file
53
app/model/admin_user.py
Normal file
@@ -0,0 +1,53 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from app.common.security.password_hasher import verify_password
|
||||
from app.constants.model.admin_user.admin_user_status_code import AdminUserStatusCode
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AdminUser:
|
||||
id: int
|
||||
username: str
|
||||
password: str
|
||||
user_type: str
|
||||
nickname: str
|
||||
phone: str
|
||||
email: str
|
||||
status: AdminUserStatusCode
|
||||
login_ip: str
|
||||
login_time: str
|
||||
remark: str
|
||||
|
||||
@classmethod
|
||||
def from_row(cls, row: dict[str, Any]) -> "AdminUser":
|
||||
return cls(
|
||||
id=int(row["id"]),
|
||||
username=str(row["username"]),
|
||||
password=str(row["password"]),
|
||||
user_type=str(row["user_type"]),
|
||||
nickname=str(row["nickname"]),
|
||||
phone=str(row["phone"]),
|
||||
email=str(row["email"]),
|
||||
status=AdminUserStatusCode(int(row["status"])),
|
||||
login_ip=str(row["login_ip"]),
|
||||
login_time=str(row["login_time"]),
|
||||
remark=str(row["remark"]),
|
||||
)
|
||||
|
||||
def verify_password(self, password: str) -> bool:
|
||||
return verify_password(password, self.password)
|
||||
|
||||
def to_public_dict(self) -> dict:
|
||||
return {
|
||||
"id": self.id,
|
||||
"username": self.username,
|
||||
"user_type": self.user_type,
|
||||
"nickname": self.nickname,
|
||||
"phone": self.phone,
|
||||
"email": self.email,
|
||||
"status": int(self.status),
|
||||
"login_ip": self.login_ip,
|
||||
"login_time": self.login_time,
|
||||
"remark": self.remark,
|
||||
}
|
||||
1
app/request/__init__.py
Normal file
1
app/request/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Request schemas."""
|
||||
1
app/request/admin/__init__.py
Normal file
1
app/request/admin/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Admin request schemas."""
|
||||
6
app/request/admin/login_request.py
Normal file
6
app/request/admin/login_request.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
username: str = Field(min_length=1)
|
||||
password: str = Field(min_length=1)
|
||||
1
app/service/__init__.py
Normal file
1
app/service/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Service layer."""
|
||||
1
app/service/admin/__init__.py
Normal file
1
app/service/admin/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Admin services."""
|
||||
15
app/service/admin/base_admin_service.py
Normal file
15
app/service/admin/base_admin_service.py
Normal file
@@ -0,0 +1,15 @@
|
||||
from app.common.context import current_admin_id
|
||||
from app.exception.err_exception import ErrException
|
||||
from app.lib.response.admin_return import AdminReturn
|
||||
|
||||
|
||||
class BaseAdminService:
|
||||
def __init__(self, admin_return: AdminReturn) -> None:
|
||||
self.admin_return = admin_return
|
||||
|
||||
@property
|
||||
def admin_id(self) -> int:
|
||||
admin_id = current_admin_id.get()
|
||||
if admin_id <= 0:
|
||||
raise ErrException("账户不存在")
|
||||
return admin_id
|
||||
1
app/service/admin/login/__init__.py
Normal file
1
app/service/admin/login/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Admin login services."""
|
||||
56
app/service/admin/login/login_service.py
Normal file
56
app/service/admin/login/login_service.py
Normal file
@@ -0,0 +1,56 @@
|
||||
import asyncio
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from app.common.repository.admin_user_repository import AdminUserRepository
|
||||
from app.constants.model.admin_user.admin_user_status_code import AdminUserStatusCode
|
||||
from app.exception.err_exception import ErrException
|
||||
from app.lib.response.admin_return import AdminReturn
|
||||
from app.request.admin.login_request import LoginRequest
|
||||
from app.service.admin.base_admin_service import BaseAdminService
|
||||
from app.service.base_token_service import BaseTokenService
|
||||
|
||||
|
||||
class LoginService(BaseAdminService):
|
||||
def __init__(
|
||||
self,
|
||||
user_repository: AdminUserRepository,
|
||||
token_service: BaseTokenService,
|
||||
admin_return: AdminReturn,
|
||||
) -> None:
|
||||
super().__init__(admin_return)
|
||||
self.user_repository = user_repository
|
||||
self.token_service = token_service
|
||||
|
||||
async def handle(self, payload: LoginRequest, request: Request) -> dict:
|
||||
admin_info = await self.user_repository.find_by_username(payload.username)
|
||||
if admin_info is None:
|
||||
raise ErrException("后台管理员不存在")
|
||||
|
||||
password_valid = await asyncio.to_thread(
|
||||
admin_info.verify_password,
|
||||
payload.password,
|
||||
)
|
||||
if not password_valid:
|
||||
raise ErrException("密码错误")
|
||||
|
||||
if admin_info.status == AdminUserStatusCode.DISABLE:
|
||||
raise ErrException("用户已禁用")
|
||||
|
||||
await self.user_repository.record_login(admin_info.id, self._client_ip(request))
|
||||
jwt = self.token_service.get_jwt("admin")
|
||||
return self.admin_return.success(
|
||||
"success",
|
||||
{
|
||||
"access_token": jwt.builder_access_token(str(admin_info.id)),
|
||||
"refresh_token": jwt.builder_refresh_token(str(admin_info.id)),
|
||||
"expire_at": int(jwt.get_config("ttl", 0)),
|
||||
},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _client_ip(request: Request) -> str:
|
||||
forwarded = request.headers.get("x-forwarded-for")
|
||||
if forwarded:
|
||||
return forwarded.split(",", 1)[0].strip()
|
||||
return request.client.host if request.client else ""
|
||||
22
app/service/admin/login/refresh_service.py
Normal file
22
app/service/admin/login/refresh_service.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from app.lib.jwt.token import JwtToken
|
||||
from app.lib.response.admin_return import AdminReturn
|
||||
from app.service.admin.base_admin_service import BaseAdminService
|
||||
from app.service.base_token_service import BaseTokenService
|
||||
|
||||
|
||||
class RefreshService(BaseAdminService):
|
||||
def __init__(self, token_service: BaseTokenService, admin_return: AdminReturn) -> None:
|
||||
super().__init__(admin_return)
|
||||
self.token_service = token_service
|
||||
|
||||
async def handle(self, token: JwtToken) -> dict:
|
||||
return self.admin_return.success("success", await self.refresh_token(token))
|
||||
|
||||
async def refresh_token(self, token: JwtToken) -> dict[str, int | str]:
|
||||
jwt = self.token_service.get_jwt("admin")
|
||||
await jwt.add_blacklist(token)
|
||||
return {
|
||||
"access_token": jwt.builder_access_token(token.jwt_id),
|
||||
"refresh_token": jwt.builder_refresh_token(token.jwt_id),
|
||||
"expire_at": int(jwt.get_config("ttl", 0)),
|
||||
}
|
||||
1
app/service/admin/profile/__init__.py
Normal file
1
app/service/admin/profile/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Admin profile services."""
|
||||
20
app/service/admin/profile/current_user_service.py
Normal file
20
app/service/admin/profile/current_user_service.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from app.common.repository.admin_user_repository import AdminUserRepository
|
||||
from app.exception.err_exception import ErrException
|
||||
from app.lib.response.admin_return import AdminReturn
|
||||
from app.service.admin.base_admin_service import BaseAdminService
|
||||
|
||||
|
||||
class CurrentUserService(BaseAdminService):
|
||||
def __init__(
|
||||
self,
|
||||
user_repository: AdminUserRepository,
|
||||
admin_return: AdminReturn,
|
||||
) -> None:
|
||||
super().__init__(admin_return)
|
||||
self.user_repository = user_repository
|
||||
|
||||
async def handle(self) -> dict:
|
||||
admin_user = await self.user_repository.find_by_id(self.admin_id)
|
||||
if admin_user is None:
|
||||
raise ErrException("账户不存在")
|
||||
return self.admin_return.success("success", admin_user.to_public_dict())
|
||||
17
app/service/base_token_service.py
Normal file
17
app/service/base_token_service.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from app.constants.result_code import ResultCode
|
||||
from app.exception.err_exception import ErrException
|
||||
from app.lib.jwt.factory import JwtFactory
|
||||
from app.lib.jwt.jwt import Jwt
|
||||
from app.lib.jwt.token import JwtToken
|
||||
|
||||
|
||||
class BaseTokenService:
|
||||
def __init__(self, jwt_factory: JwtFactory) -> None:
|
||||
self.jwt_factory = jwt_factory
|
||||
|
||||
def get_jwt(self, scene: str = "admin") -> Jwt:
|
||||
return self.jwt_factory.get(scene)
|
||||
|
||||
async def check_jwt(self, jwt: Jwt, token: JwtToken) -> None:
|
||||
if await jwt.has_blacklist(token):
|
||||
raise ErrException("token已过期", ResultCode.JWT_EXPIRED)
|
||||
14
pyproject.toml
Normal file
14
pyproject.toml
Normal file
@@ -0,0 +1,14 @@
|
||||
[project]
|
||||
name = "py-server"
|
||||
version = "0.1.0"
|
||||
description = "Async API server scaffold inspired by the Hyperf app layering."
|
||||
requires-python = ">=3.14"
|
||||
dependencies = [
|
||||
"fastapi>=0.136.0",
|
||||
"pydantic-settings>=2.14.0",
|
||||
"uvicorn[standard]>=0.49.0",
|
||||
]
|
||||
|
||||
[tool.pyright]
|
||||
venvPath = "."
|
||||
venv = ".venv"
|
||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Test package."""
|
||||
88
tests/test_admin_login_flow.py
Normal file
88
tests/test_admin_login_flow.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
DB_PATH = Path(tempfile.gettempdir()) / "py_server_admin_login_test.db"
|
||||
DB_PATH.unlink(missing_ok=True)
|
||||
os.environ["DATABASE_PATH"] = str(DB_PATH)
|
||||
os.environ["JWT_ADMIN_SECRET"] = "test_admin_secret"
|
||||
os.environ["ADMIN_SEED_USERNAME"] = "admin"
|
||||
os.environ["ADMIN_SEED_PASSWORD"] = "admin"
|
||||
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app.core.dependencies import bootstrap_database
|
||||
from app.main import app
|
||||
|
||||
|
||||
class AdminLoginFlowTest(unittest.IsolatedAsyncioTestCase):
|
||||
async def asyncSetUp(self) -> None:
|
||||
await bootstrap_database()
|
||||
self.client = AsyncClient(
|
||||
transport=ASGITransport(app=app),
|
||||
base_url="http://testserver",
|
||||
)
|
||||
|
||||
async def asyncTearDown(self) -> None:
|
||||
await self.client.aclose()
|
||||
|
||||
async def test_login_access_and_refresh_flow(self) -> None:
|
||||
login_response = await self.client.post(
|
||||
"/admin/login/login",
|
||||
json={"username": "admin", "password": "admin"},
|
||||
)
|
||||
login_payload = login_response.json()
|
||||
|
||||
self.assertEqual(login_payload["code"], 0)
|
||||
self.assertIn("access_token", login_payload["data"])
|
||||
self.assertIn("refresh_token", login_payload["data"])
|
||||
self.assertEqual(login_payload["data"]["expire_at"], 3600)
|
||||
|
||||
access_token = login_payload["data"]["access_token"]
|
||||
refresh_token = login_payload["data"]["refresh_token"]
|
||||
|
||||
current_response = await self.client.get(
|
||||
"/admin/profile/current",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
current_payload = current_response.json()
|
||||
|
||||
self.assertEqual(current_payload["code"], 0)
|
||||
self.assertEqual(current_payload["data"]["username"], "admin")
|
||||
|
||||
refresh_response = await self.client.post(
|
||||
"/admin/login/refresh",
|
||||
headers={"Authorization": f"Bearer {refresh_token}"},
|
||||
)
|
||||
refresh_payload = refresh_response.json()
|
||||
|
||||
self.assertEqual(refresh_payload["code"], 0)
|
||||
self.assertNotEqual(refresh_payload["data"]["refresh_token"], refresh_token)
|
||||
|
||||
reused_response = await self.client.post(
|
||||
"/admin/login/refresh",
|
||||
headers={"Authorization": f"Bearer {refresh_token}"},
|
||||
)
|
||||
reused_payload = reused_response.json()
|
||||
|
||||
self.assertEqual(reused_payload["code"], 10001)
|
||||
|
||||
async def test_refresh_endpoint_rejects_access_token(self) -> None:
|
||||
login_response = await self.client.post(
|
||||
"/admin/login/login",
|
||||
json={"username": "admin", "password": "admin"},
|
||||
)
|
||||
access_token = login_response.json()["data"]["access_token"]
|
||||
|
||||
refresh_response = await self.client.post(
|
||||
"/admin/login/refresh",
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
refresh_payload = refresh_response.json()
|
||||
|
||||
self.assertEqual(refresh_payload["code"], 10002)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user