Initial FastAPI admin auth scaffold
This commit is contained in:
11
.env.example
Normal file
11
.env.example
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
APP_NAME=py_server
|
||||||
|
APP_ENV=local
|
||||||
|
DATABASE_PATH=storage/app.db
|
||||||
|
|
||||||
|
JWT_ADMIN_SECRET=change_me_admin_secret
|
||||||
|
ADMIN_JWT_TTL=3600
|
||||||
|
ADMIN_JWT_REFRESH_TTL=7200
|
||||||
|
JWT_BLACKLIST_TTL=7201
|
||||||
|
|
||||||
|
ADMIN_SEED_USERNAME=admin
|
||||||
|
ADMIN_SEED_PASSWORD=admin
|
||||||
7
.gitignore
vendored
Normal file
7
.gitignore
vendored
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
.env
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
.pytest_cache/
|
||||||
|
.ruff_cache/
|
||||||
|
.idea/
|
||||||
|
storage/*.db
|
||||||
145
README.md
Normal file
145
README.md
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
# py_server
|
||||||
|
|
||||||
|
这是一个 FastAPI 异步 API 项目骨架,分层方式参考了 Hyperf 项目里的 `server/app`。
|
||||||
|
|
||||||
|
## 分层对应
|
||||||
|
|
||||||
|
| Hyperf | FastAPI 当前项目 | 作用 |
|
||||||
|
| --- | --- | --- |
|
||||||
|
| `Controller/Admin` | `app/controller/admin` | 接收请求,调用 service |
|
||||||
|
| `Service/Admin` | `app/service/admin` | 业务逻辑 |
|
||||||
|
| `Common/Repository` | `app/common/repository` | 数据访问 |
|
||||||
|
| `Model` | `app/model` | 数据模型 |
|
||||||
|
| `Middleware/Admin` | `app/middleware/admin` | admin token 和权限校验 |
|
||||||
|
| `Lib/Jwt` | `app/lib/jwt` | JWT 签发、解析、黑名单 |
|
||||||
|
| `Request/Admin` | `app/request/admin` | 请求参数校验 |
|
||||||
|
| `Constants` | `app/constants` | 状态码和业务常量 |
|
||||||
|
| `Lib/Return` | `app/lib/response` | 统一返回结构 |
|
||||||
|
|
||||||
|
> Python 里 `return` 是关键字,所以参考 Hyperf 的 `Lib/Return` 在这里命名为 `lib/response`。
|
||||||
|
|
||||||
|
## 启动
|
||||||
|
|
||||||
|
```bash
|
||||||
|
.venv/bin/uvicorn app.main:app --reload
|
||||||
|
```
|
||||||
|
|
||||||
|
默认启动时会自动创建 SQLite 开发库,并创建一个 admin 用户:
|
||||||
|
|
||||||
|
- username: `admin`
|
||||||
|
- password: `admin`
|
||||||
|
|
||||||
|
本地配置可以复制:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cp .env.example .env
|
||||||
|
```
|
||||||
|
|
||||||
|
## Admin 登录接口
|
||||||
|
|
||||||
|
### 登录
|
||||||
|
|
||||||
|
```http
|
||||||
|
POST /admin/login/login
|
||||||
|
Content-Type: application/json
|
||||||
|
|
||||||
|
{
|
||||||
|
"username": "admin",
|
||||||
|
"password": "admin"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
返回:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"code": 0,
|
||||||
|
"message": "success",
|
||||||
|
"data": {
|
||||||
|
"access_token": "...",
|
||||||
|
"refresh_token": "...",
|
||||||
|
"expire_at": 3600
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 刷新 token
|
||||||
|
|
||||||
|
```http
|
||||||
|
POST /admin/login/refresh
|
||||||
|
Authorization: Bearer <refresh_token>
|
||||||
|
```
|
||||||
|
|
||||||
|
刷新成功后会签发新的 `access_token` 和 `refresh_token`,并把旧的 `refresh_token` 加入黑名单,防止重复刷新。
|
||||||
|
|
||||||
|
### 当前 admin 用户
|
||||||
|
|
||||||
|
```http
|
||||||
|
GET /admin/profile/current
|
||||||
|
Authorization: Bearer <access_token>
|
||||||
|
```
|
||||||
|
|
||||||
|
## JWT 逻辑
|
||||||
|
|
||||||
|
当前实现和参考 Hyperf 项目保持同样思路:
|
||||||
|
|
||||||
|
1. 登录成功后同时签发 `access_token` 和 `refresh_token`。
|
||||||
|
2. 普通 admin 接口只接受 `access_token`。
|
||||||
|
3. `/admin/login/refresh` 只接受 `refresh_token`。
|
||||||
|
4. refresh 成功后,把旧 refresh token 加入黑名单。
|
||||||
|
5. token 会校验签名、过期时间、issuer 和 token 类型。
|
||||||
|
|
||||||
|
token 读取顺序和 Hyperf 中间件一致:
|
||||||
|
|
||||||
|
1. `Authorization: Bearer <token>`
|
||||||
|
2. `token` header
|
||||||
|
3. query string: `?token=...`
|
||||||
|
|
||||||
|
## FastAPI 依赖注入
|
||||||
|
|
||||||
|
Controller 里这段:
|
||||||
|
|
||||||
|
```python
|
||||||
|
service: LoginService = Depends(get_login_service)
|
||||||
|
```
|
||||||
|
|
||||||
|
可以理解为 FastAPI 版的 Hyperf 容器注入。
|
||||||
|
|
||||||
|
请求进来时,FastAPI 会先调用 `get_login_service()`,把创建好的 `LoginService` 传给 `service` 参数。
|
||||||
|
|
||||||
|
`get_login_service()` 内部会继续组装:
|
||||||
|
|
||||||
|
- `AdminUserRepository`
|
||||||
|
- `BaseTokenService`
|
||||||
|
- `AdminReturn`
|
||||||
|
|
||||||
|
所以 controller 只负责接收请求和调用 service。
|
||||||
|
|
||||||
|
## `@lru_cache` 的作用
|
||||||
|
|
||||||
|
`app/core/dependencies.py` 里的:
|
||||||
|
|
||||||
|
```python
|
||||||
|
@lru_cache
|
||||||
|
def get_database() -> Database:
|
||||||
|
return Database(get_settings().database_path)
|
||||||
|
```
|
||||||
|
|
||||||
|
表示第一次调用时创建对象,后面再次调用时直接复用第一次创建的对象。
|
||||||
|
|
||||||
|
在这个项目里,它的作用类似 Hyperf 容器里的共享服务/单例服务。比如 JWT 黑名单必须复用同一个对象,否则旧 refresh token 加入黑名单后,下一次请求就查不到了。
|
||||||
|
|
||||||
|
## 测试
|
||||||
|
|
||||||
|
```bash
|
||||||
|
.venv/bin/python -m compileall app tests
|
||||||
|
.venv/bin/python -m unittest tests.test_admin_login_flow
|
||||||
|
```
|
||||||
|
|
||||||
|
测试覆盖:
|
||||||
|
|
||||||
|
- admin 登录成功
|
||||||
|
- access token 访问当前用户成功
|
||||||
|
- refresh token 换新 token 成功
|
||||||
|
- 旧 refresh token 复用失败
|
||||||
|
- access token 调用 refresh 接口失败
|
||||||
1
app/__init__.py
Normal file
1
app/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Application package."""
|
||||||
1
app/common/__init__.py
Normal file
1
app/common/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Common helpers."""
|
||||||
3
app/common/context.py
Normal file
3
app/common/context.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from contextvars import ContextVar
|
||||||
|
|
||||||
|
current_admin_id: ContextVar[int] = ContextVar("current_admin_id", default=0)
|
||||||
1
app/common/repository/__init__.py
Normal file
1
app/common/repository/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Repository layer."""
|
||||||
79
app/common/repository/admin_user_repository.py
Normal file
79
app/common/repository/admin_user_repository.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
from app.common.repository.base_repository import BaseRepository
|
||||||
|
from app.common.security.password_hasher import hash_password
|
||||||
|
from app.constants.model.admin_user.admin_user_status_code import AdminUserStatusCode
|
||||||
|
from app.model.admin_user import AdminUser
|
||||||
|
|
||||||
|
|
||||||
|
class AdminUserRepository(BaseRepository):
|
||||||
|
async def find_by_username(self, username: str) -> AdminUser | None:
|
||||||
|
row = await self.database.fetchone(
|
||||||
|
"""
|
||||||
|
SELECT id, username, password, user_type, nickname, phone, email,
|
||||||
|
status, login_ip, login_time, remark
|
||||||
|
FROM admin_user
|
||||||
|
WHERE username = ?
|
||||||
|
LIMIT 1
|
||||||
|
""",
|
||||||
|
(username,),
|
||||||
|
)
|
||||||
|
return AdminUser.from_row(row) if row else None
|
||||||
|
|
||||||
|
async def find_by_id(self, admin_id: int) -> AdminUser | None:
|
||||||
|
row = await self.database.fetchone(
|
||||||
|
"""
|
||||||
|
SELECT id, username, password, user_type, nickname, phone, email,
|
||||||
|
status, login_ip, login_time, remark
|
||||||
|
FROM admin_user
|
||||||
|
WHERE id = ?
|
||||||
|
LIMIT 1
|
||||||
|
""",
|
||||||
|
(admin_id,),
|
||||||
|
)
|
||||||
|
return AdminUser.from_row(row) if row else None
|
||||||
|
|
||||||
|
async def record_login(self, admin_id: int, login_ip: str) -> None:
|
||||||
|
now = self._now()
|
||||||
|
await self.database.execute(
|
||||||
|
"""
|
||||||
|
UPDATE admin_user
|
||||||
|
SET login_ip = ?, login_time = ?, updated_at = ?
|
||||||
|
WHERE id = ?
|
||||||
|
""",
|
||||||
|
(login_ip, now, now, admin_id),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def ensure_seed_admin(self, username: str, password: str) -> None:
|
||||||
|
exists = await self.find_by_username(username)
|
||||||
|
if exists is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
now = self._now()
|
||||||
|
await self.database.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO admin_user (
|
||||||
|
username, password, user_type, nickname, phone, email, status,
|
||||||
|
login_ip, login_time, created_at, updated_at, remark
|
||||||
|
)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||||
|
""",
|
||||||
|
(
|
||||||
|
username,
|
||||||
|
hash_password(password),
|
||||||
|
"SuperAdmin",
|
||||||
|
"Super Admin",
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
int(AdminUserStatusCode.NORMAL),
|
||||||
|
"",
|
||||||
|
"",
|
||||||
|
now,
|
||||||
|
now,
|
||||||
|
"seeded by application startup",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _now() -> str:
|
||||||
|
return datetime.now(UTC).strftime("%Y-%m-%d %H:%M:%S")
|
||||||
6
app/common/repository/base_repository.py
Normal file
6
app/common/repository/base_repository.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from app.core.database import Database
|
||||||
|
|
||||||
|
|
||||||
|
class BaseRepository:
|
||||||
|
def __init__(self, database: Database) -> None:
|
||||||
|
self.database = database
|
||||||
1
app/common/security/__init__.py
Normal file
1
app/common/security/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Security helpers."""
|
||||||
40
app/common/security/password_hasher.py
Normal file
40
app/common/security/password_hasher.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
ALGORITHM = "pbkdf2_sha256"
|
||||||
|
ITERATIONS = 390_000
|
||||||
|
|
||||||
|
|
||||||
|
def _b64encode(raw: bytes) -> str:
|
||||||
|
return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=")
|
||||||
|
|
||||||
|
|
||||||
|
def _b64decode(value: str) -> bytes:
|
||||||
|
padding = "=" * (-len(value) % 4)
|
||||||
|
return base64.urlsafe_b64decode(value + padding)
|
||||||
|
|
||||||
|
|
||||||
|
def hash_password(password: str, iterations: int = ITERATIONS) -> str:
|
||||||
|
salt = secrets.token_bytes(16)
|
||||||
|
digest = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, iterations)
|
||||||
|
return f"{ALGORITHM}${iterations}${_b64encode(salt)}${_b64encode(digest)}"
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(password: str, stored_hash: str) -> bool:
|
||||||
|
try:
|
||||||
|
algorithm, iterations, salt, expected = stored_hash.split("$", 3)
|
||||||
|
except ValueError:
|
||||||
|
return hmac.compare_digest(password, stored_hash)
|
||||||
|
|
||||||
|
if algorithm != ALGORITHM:
|
||||||
|
return False
|
||||||
|
|
||||||
|
digest = hashlib.pbkdf2_hmac(
|
||||||
|
"sha256",
|
||||||
|
password.encode("utf-8"),
|
||||||
|
_b64decode(salt),
|
||||||
|
int(iterations),
|
||||||
|
)
|
||||||
|
return hmac.compare_digest(_b64encode(digest), expected)
|
||||||
1
app/constants/__init__.py
Normal file
1
app/constants/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Application constants."""
|
||||||
6
app/constants/admin_code.py
Normal file
6
app/constants/admin_code.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from enum import IntEnum
|
||||||
|
|
||||||
|
|
||||||
|
class AdminCode(IntEnum):
|
||||||
|
DISABLED = 30001
|
||||||
|
FORBIDDEN = 30002
|
||||||
1
app/constants/model/__init__.py
Normal file
1
app/constants/model/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Model constants."""
|
||||||
1
app/constants/model/admin_user/__init__.py
Normal file
1
app/constants/model/admin_user/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Admin user constants."""
|
||||||
12
app/constants/model/admin_user/admin_user_status_code.py
Normal file
12
app/constants/model/admin_user/admin_user_status_code.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from enum import IntEnum
|
||||||
|
|
||||||
|
|
||||||
|
class AdminUserStatusCode(IntEnum):
|
||||||
|
NORMAL = 1
|
||||||
|
DISABLE = 2
|
||||||
|
|
||||||
|
def is_normal(self) -> bool:
|
||||||
|
return self is AdminUserStatusCode.NORMAL
|
||||||
|
|
||||||
|
def is_disable(self) -> bool:
|
||||||
|
return self is AdminUserStatusCode.DISABLE
|
||||||
12
app/constants/result_code.py
Normal file
12
app/constants/result_code.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
from enum import IntEnum
|
||||||
|
|
||||||
|
|
||||||
|
class ResultCode(IntEnum):
|
||||||
|
SUCCESS = 0
|
||||||
|
ERROR = 1
|
||||||
|
JWT_EXPIRED = 10001
|
||||||
|
JWT_ERROR = 10002
|
||||||
|
OLD_PASSWORD_ERROR = 10003
|
||||||
|
ACCOUNT_DEACTIVATING = 20001
|
||||||
|
ACCOUNT_DEACTIVATED = 20002
|
||||||
|
ACCOUNT_CANNOT_DEACTIVATE = 20003
|
||||||
1
app/controller/__init__.py
Normal file
1
app/controller/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Controller layer."""
|
||||||
1
app/controller/admin/__init__.py
Normal file
1
app/controller/admin/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Admin controllers."""
|
||||||
32
app/controller/admin/login_controller.py
Normal file
32
app/controller/admin/login_controller.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
from fastapi import APIRouter, Depends, Request
|
||||||
|
|
||||||
|
from app.core.dependencies import get_login_service, get_refresh_service
|
||||||
|
from app.lib.jwt.token import JwtToken
|
||||||
|
from app.middleware.admin.refresh_admin_token_middleware import RefreshAdminTokenMiddleware
|
||||||
|
from app.request.admin.login_request import LoginRequest
|
||||||
|
from app.service.admin.login.login_service import LoginService
|
||||||
|
from app.service.admin.login.refresh_service import RefreshService
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/admin/login", tags=["admin-login"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login")
|
||||||
|
async def login(
|
||||||
|
payload: LoginRequest,
|
||||||
|
request: Request,
|
||||||
|
# FastAPI 的 Depends 类似 Hyperf 里的容器注入。
|
||||||
|
# 请求进入这个接口时,框架会先调用 get_login_service(),
|
||||||
|
# 把组装好的 LoginService 传进 service 参数。
|
||||||
|
service: LoginService = Depends(get_login_service),
|
||||||
|
) -> dict:
|
||||||
|
return await service.handle(payload, request)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh")
|
||||||
|
async def refresh(
|
||||||
|
# 这里把 RefreshAdminTokenMiddleware 当成依赖使用。
|
||||||
|
# FastAPI 会先执行 refresh token 校验,通过后才进入 controller。
|
||||||
|
token: JwtToken = Depends(RefreshAdminTokenMiddleware()),
|
||||||
|
service: RefreshService = Depends(get_refresh_service),
|
||||||
|
) -> dict:
|
||||||
|
return await service.handle(token)
|
||||||
21
app/controller/admin/profile_controller.py
Normal file
21
app/controller/admin/profile_controller.py
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
|
||||||
|
from app.core.dependencies import get_current_user_service
|
||||||
|
from app.middleware.admin.admin_token_middleware import AdminTokenMiddleware
|
||||||
|
from app.middleware.admin.permission_middleware import PermissionMiddleware
|
||||||
|
from app.service.admin.profile.current_user_service import CurrentUserService
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/admin/profile", tags=["admin-profile"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/current",
|
||||||
|
dependencies=[
|
||||||
|
Depends(AdminTokenMiddleware()),
|
||||||
|
Depends(PermissionMiddleware()),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def current(
|
||||||
|
service: CurrentUserService = Depends(get_current_user_service),
|
||||||
|
) -> dict:
|
||||||
|
return await service.handle()
|
||||||
1
app/controller/api/__init__.py
Normal file
1
app/controller/api/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Frontend API controllers."""
|
||||||
10
app/controller/api/health_controller.py
Normal file
10
app/controller/api/health_controller.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from app.lib.response.admin_return import AdminReturn
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api", tags=["api"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/health")
|
||||||
|
async def health() -> dict:
|
||||||
|
return AdminReturn().success("success", {"status": "ok"})
|
||||||
1
app/core/__init__.py
Normal file
1
app/core/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Core application wiring."""
|
||||||
39
app/core/config.py
Normal file
39
app/core/config.py
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
model_config = SettingsConfigDict(
|
||||||
|
env_file=".env",
|
||||||
|
env_file_encoding="utf-8",
|
||||||
|
extra="ignore",
|
||||||
|
populate_by_name=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
app_name: str = Field(default="py_server", alias="APP_NAME")
|
||||||
|
app_env: str = Field(default="local", alias="APP_ENV")
|
||||||
|
database_path: Path = Field(default=Path("storage/app.db"), alias="DATABASE_PATH")
|
||||||
|
|
||||||
|
admin_jwt_secret: str = Field(
|
||||||
|
default="dev_admin_secret_change_me",
|
||||||
|
alias="JWT_ADMIN_SECRET",
|
||||||
|
)
|
||||||
|
admin_jwt_ttl: int = Field(default=3600, alias="ADMIN_JWT_TTL")
|
||||||
|
admin_jwt_refresh_ttl: int = Field(default=7200, alias="ADMIN_JWT_REFRESH_TTL")
|
||||||
|
jwt_blacklist_ttl: int = Field(default=7201, alias="JWT_BLACKLIST_TTL")
|
||||||
|
|
||||||
|
api_jwt_secret: str = Field(default="dev_api_secret_change_me", alias="JWT_SECRET")
|
||||||
|
api_jwt_ttl: int = Field(default=3600, alias="JWT_TTL")
|
||||||
|
api_jwt_refresh_ttl: int = Field(default=7200, alias="JWT_REFRESH_TTL")
|
||||||
|
|
||||||
|
admin_seed_username: str = Field(default="admin", alias="ADMIN_SEED_USERNAME")
|
||||||
|
admin_seed_password: str = Field(default="admin", alias="ADMIN_SEED_PASSWORD")
|
||||||
|
cors_allow_origins: list[str] = Field(default=["*"], alias="CORS_ALLOW_ORIGINS")
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def get_settings() -> Settings:
|
||||||
|
return Settings()
|
||||||
87
app/core/database.py
Normal file
87
app/core/database.py
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
import asyncio
|
||||||
|
import sqlite3
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class Database:
|
||||||
|
def __init__(self, path: Path) -> None:
|
||||||
|
self.path = path
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
await asyncio.to_thread(self._initialize)
|
||||||
|
|
||||||
|
async def execute(self, sql: str, params: tuple[Any, ...] = ()) -> int:
|
||||||
|
return await asyncio.to_thread(self._execute, sql, params)
|
||||||
|
|
||||||
|
async def fetchone(
|
||||||
|
self,
|
||||||
|
sql: str,
|
||||||
|
params: tuple[Any, ...] = (),
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
return await asyncio.to_thread(self._fetchone, sql, params)
|
||||||
|
|
||||||
|
async def fetchall(
|
||||||
|
self,
|
||||||
|
sql: str,
|
||||||
|
params: tuple[Any, ...] = (),
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
return await asyncio.to_thread(self._fetchall, sql, params)
|
||||||
|
|
||||||
|
def _connect(self) -> sqlite3.Connection:
|
||||||
|
self.path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
conn = sqlite3.connect(self.path)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
conn.execute("PRAGMA foreign_keys = ON")
|
||||||
|
return conn
|
||||||
|
|
||||||
|
def _initialize(self) -> None:
|
||||||
|
conn = self._connect()
|
||||||
|
try:
|
||||||
|
conn.execute(
|
||||||
|
"""
|
||||||
|
CREATE TABLE IF NOT EXISTS admin_user (
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
username TEXT NOT NULL UNIQUE,
|
||||||
|
password TEXT NOT NULL,
|
||||||
|
user_type TEXT NOT NULL DEFAULT 'admin',
|
||||||
|
nickname TEXT NOT NULL DEFAULT '',
|
||||||
|
phone TEXT NOT NULL DEFAULT '',
|
||||||
|
email TEXT NOT NULL DEFAULT '',
|
||||||
|
status INTEGER NOT NULL DEFAULT 1,
|
||||||
|
login_ip TEXT NOT NULL DEFAULT '',
|
||||||
|
login_time TEXT NOT NULL DEFAULT '',
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
updated_at TEXT NOT NULL,
|
||||||
|
remark TEXT NOT NULL DEFAULT ''
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def _execute(self, sql: str, params: tuple[Any, ...]) -> int:
|
||||||
|
conn = self._connect()
|
||||||
|
try:
|
||||||
|
cursor = conn.execute(sql, params)
|
||||||
|
conn.commit()
|
||||||
|
return int(cursor.lastrowid or 0)
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def _fetchone(self, sql: str, params: tuple[Any, ...]) -> dict[str, Any] | None:
|
||||||
|
conn = self._connect()
|
||||||
|
try:
|
||||||
|
row = conn.execute(sql, params).fetchone()
|
||||||
|
return dict(row) if row else None
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def _fetchall(self, sql: str, params: tuple[Any, ...]) -> list[dict[str, Any]]:
|
||||||
|
conn = self._connect()
|
||||||
|
try:
|
||||||
|
rows = conn.execute(sql, params).fetchall()
|
||||||
|
return [dict(row) for row in rows]
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
70
app/core/dependencies.py
Normal file
70
app/core/dependencies.py
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
from functools import lru_cache
|
||||||
|
|
||||||
|
from app.common.repository.admin_user_repository import AdminUserRepository
|
||||||
|
from app.core.config import Settings, get_settings
|
||||||
|
from app.core.database import Database
|
||||||
|
from app.lib.jwt.blacklist import InMemoryTokenBlacklist
|
||||||
|
from app.lib.jwt.factory import JwtFactory
|
||||||
|
from app.lib.response.admin_return import AdminReturn
|
||||||
|
from app.service.admin.login.login_service import LoginService
|
||||||
|
from app.service.admin.login.refresh_service import RefreshService
|
||||||
|
from app.service.admin.profile.current_user_service import CurrentUserService
|
||||||
|
from app.service.base_token_service import BaseTokenService
|
||||||
|
|
||||||
|
|
||||||
|
# lru_cache 会缓存函数第一次创建出来的对象。
|
||||||
|
# 这里用它把 Database、JwtFactory、TokenService 等依赖做成应用级单例,
|
||||||
|
# 类似 Hyperf 从容器里反复 get 同一个共享服务。
|
||||||
|
@lru_cache
|
||||||
|
def get_database() -> Database:
|
||||||
|
return Database(get_settings().database_path)
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def get_token_blacklist() -> InMemoryTokenBlacklist:
|
||||||
|
return InMemoryTokenBlacklist()
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def get_jwt_factory() -> JwtFactory:
|
||||||
|
return JwtFactory(get_settings(), get_token_blacklist())
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def get_token_service() -> BaseTokenService:
|
||||||
|
return BaseTokenService(get_jwt_factory())
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def get_admin_return() -> AdminReturn:
|
||||||
|
return AdminReturn()
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache
|
||||||
|
def get_admin_user_repository() -> AdminUserRepository:
|
||||||
|
return AdminUserRepository(get_database())
|
||||||
|
|
||||||
|
|
||||||
|
def get_login_service() -> LoginService:
|
||||||
|
return LoginService(
|
||||||
|
get_admin_user_repository(),
|
||||||
|
get_token_service(),
|
||||||
|
get_admin_return(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_refresh_service() -> RefreshService:
|
||||||
|
return RefreshService(get_token_service(), get_admin_return())
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user_service() -> CurrentUserService:
|
||||||
|
return CurrentUserService(get_admin_user_repository(), get_admin_return())
|
||||||
|
|
||||||
|
|
||||||
|
async def bootstrap_database(settings: Settings | None = None) -> None:
|
||||||
|
settings = settings or get_settings()
|
||||||
|
await get_database().initialize()
|
||||||
|
await get_admin_user_repository().ensure_seed_admin(
|
||||||
|
settings.admin_seed_username,
|
||||||
|
settings.admin_seed_password,
|
||||||
|
)
|
||||||
1
app/exception/__init__.py
Normal file
1
app/exception/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Exception layer."""
|
||||||
14
app/exception/err_exception.py
Normal file
14
app/exception/err_exception.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from app.constants.result_code import ResultCode
|
||||||
|
|
||||||
|
|
||||||
|
class ErrException(Exception):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message: str = "failed",
|
||||||
|
code: int | ResultCode = ResultCode.ERROR,
|
||||||
|
data: dict | None = None,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(message)
|
||||||
|
self.message = message
|
||||||
|
self.code = int(code)
|
||||||
|
self.data = data or {}
|
||||||
34
app/exception/handler.py
Normal file
34
app/exception/handler.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
from fastapi import FastAPI, Request
|
||||||
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
from app.constants.result_code import ResultCode
|
||||||
|
from app.exception.err_exception import ErrException
|
||||||
|
from app.lib.response.admin_return import AdminReturn
|
||||||
|
|
||||||
|
|
||||||
|
async def err_exception_handler(request: Request, exc: ErrException) -> JSONResponse:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=200,
|
||||||
|
content=AdminReturn().error(exc.message, exc.code, exc.data),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def validation_exception_handler(
|
||||||
|
request: Request,
|
||||||
|
exc: RequestValidationError,
|
||||||
|
) -> JSONResponse:
|
||||||
|
first_error = exc.errors()[0] if exc.errors() else {}
|
||||||
|
location = ".".join(str(item) for item in first_error.get("loc", []))
|
||||||
|
message = first_error.get("msg", "参数错误")
|
||||||
|
if location:
|
||||||
|
message = f"{location}: {message}"
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=200,
|
||||||
|
content=AdminReturn().error(message, ResultCode.ERROR),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def register_exception_handlers(app: FastAPI) -> None:
|
||||||
|
app.add_exception_handler(ErrException, err_exception_handler)
|
||||||
|
app.add_exception_handler(RequestValidationError, validation_exception_handler)
|
||||||
1
app/lib/__init__.py
Normal file
1
app/lib/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Library layer."""
|
||||||
1
app/lib/jwt/__init__.py
Normal file
1
app/lib/jwt/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""JWT helpers."""
|
||||||
25
app/lib/jwt/blacklist.py
Normal file
25
app/lib/jwt/blacklist.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
import asyncio
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
class InMemoryTokenBlacklist:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._items: dict[str, float] = {}
|
||||||
|
self._lock = asyncio.Lock()
|
||||||
|
|
||||||
|
async def add(self, token: str, ttl: int) -> None:
|
||||||
|
now = time.time()
|
||||||
|
async with self._lock:
|
||||||
|
self._cleanup(now)
|
||||||
|
self._items[token] = now + ttl
|
||||||
|
|
||||||
|
async def has(self, token: str) -> bool:
|
||||||
|
now = time.time()
|
||||||
|
async with self._lock:
|
||||||
|
self._cleanup(now)
|
||||||
|
return token in self._items
|
||||||
|
|
||||||
|
def _cleanup(self, now: float) -> None:
|
||||||
|
expired = [token for token, expires_at in self._items.items() if expires_at <= now]
|
||||||
|
for token in expired:
|
||||||
|
self._items.pop(token, None)
|
||||||
10
app/lib/jwt/exceptions.py
Normal file
10
app/lib/jwt/exceptions.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
class JwtError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class JwtExpiredError(JwtError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class JwtInvalidError(JwtError):
|
||||||
|
pass
|
||||||
34
app/lib/jwt/factory.py
Normal file
34
app/lib/jwt/factory.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
from app.core.config import Settings
|
||||||
|
from app.lib.jwt.blacklist import InMemoryTokenBlacklist
|
||||||
|
from app.lib.jwt.jwt import Jwt, JwtSceneConfig
|
||||||
|
|
||||||
|
|
||||||
|
class JwtFactory:
|
||||||
|
def __init__(self, settings: Settings, blacklist: InMemoryTokenBlacklist) -> None:
|
||||||
|
self.settings = settings
|
||||||
|
self.blacklist = blacklist
|
||||||
|
|
||||||
|
def get(self, scene: str = "default") -> Jwt:
|
||||||
|
if scene == "admin":
|
||||||
|
return Jwt(
|
||||||
|
JwtSceneConfig(
|
||||||
|
secret=self.settings.admin_jwt_secret,
|
||||||
|
issuer=f"{self.settings.app_name}_admin",
|
||||||
|
ttl=self.settings.admin_jwt_ttl,
|
||||||
|
refresh_ttl=self.settings.admin_jwt_refresh_ttl,
|
||||||
|
blacklist_ttl=self.settings.jwt_blacklist_ttl,
|
||||||
|
),
|
||||||
|
self.blacklist,
|
||||||
|
)
|
||||||
|
|
||||||
|
scene_name = "default" if scene == "default" else scene
|
||||||
|
return Jwt(
|
||||||
|
JwtSceneConfig(
|
||||||
|
secret=self.settings.api_jwt_secret,
|
||||||
|
issuer=f"{self.settings.app_name}_{scene_name}",
|
||||||
|
ttl=self.settings.api_jwt_ttl,
|
||||||
|
refresh_ttl=self.settings.api_jwt_refresh_ttl,
|
||||||
|
blacklist_ttl=self.settings.jwt_blacklist_ttl,
|
||||||
|
),
|
||||||
|
self.blacklist,
|
||||||
|
)
|
||||||
151
app/lib/jwt/jwt.py
Normal file
151
app/lib/jwt/jwt.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import json
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.lib.jwt.blacklist import InMemoryTokenBlacklist
|
||||||
|
from app.lib.jwt.exceptions import JwtExpiredError, JwtInvalidError
|
||||||
|
from app.lib.jwt.token import JwtToken
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class JwtSceneConfig:
|
||||||
|
secret: str
|
||||||
|
issuer: str
|
||||||
|
ttl: int
|
||||||
|
refresh_ttl: int
|
||||||
|
blacklist_ttl: int
|
||||||
|
|
||||||
|
|
||||||
|
class Jwt:
|
||||||
|
def __init__(self, config: JwtSceneConfig, blacklist: InMemoryTokenBlacklist) -> None:
|
||||||
|
self.config = config
|
||||||
|
self.blacklist = blacklist
|
||||||
|
|
||||||
|
@property
|
||||||
|
def issuer(self) -> str:
|
||||||
|
return self.config.issuer
|
||||||
|
|
||||||
|
def builder_access_token(self, admin_id: str) -> str:
|
||||||
|
return self._build_token(admin_id, token_type="access", ttl=self.config.ttl)
|
||||||
|
|
||||||
|
def builder_refresh_token(self, admin_id: str) -> str:
|
||||||
|
return self._build_token(admin_id, token_type="refresh", ttl=self.config.refresh_ttl)
|
||||||
|
|
||||||
|
async def parser_access_token(self, raw_token: str) -> JwtToken:
|
||||||
|
token = self._parse(raw_token)
|
||||||
|
if token.is_refresh:
|
||||||
|
raise JwtInvalidError("Token is a refresh token")
|
||||||
|
self._validate_issuer(token)
|
||||||
|
return token
|
||||||
|
|
||||||
|
async def parser_refresh_token(self, raw_token: str) -> JwtToken:
|
||||||
|
token = self._parse(raw_token)
|
||||||
|
if not token.is_refresh:
|
||||||
|
raise JwtInvalidError("Token is not a refresh token")
|
||||||
|
self._validate_issuer(token)
|
||||||
|
return token
|
||||||
|
|
||||||
|
async def add_blacklist(self, token: JwtToken) -> None:
|
||||||
|
await self.blacklist.add(token.raw, self.config.blacklist_ttl)
|
||||||
|
|
||||||
|
async def has_blacklist(self, token: JwtToken) -> bool:
|
||||||
|
return await self.blacklist.has(token.raw)
|
||||||
|
|
||||||
|
def get_config(self, key: str, default: Any = None) -> Any:
|
||||||
|
return getattr(self.config, key, default)
|
||||||
|
|
||||||
|
def _build_token(self, admin_id: str, token_type: str, ttl: int) -> str:
|
||||||
|
now = int(time.time())
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"iss": self.config.issuer,
|
||||||
|
"jti": str(admin_id),
|
||||||
|
"iat": now,
|
||||||
|
"nbf": now,
|
||||||
|
"exp": now + ttl,
|
||||||
|
"token_type": token_type,
|
||||||
|
"nonce": secrets.token_urlsafe(16),
|
||||||
|
}
|
||||||
|
if token_type == "refresh":
|
||||||
|
payload["sub"] = "refresh"
|
||||||
|
|
||||||
|
header = {"typ": "JWT", "alg": "HS256"}
|
||||||
|
signing_input = ".".join(
|
||||||
|
[
|
||||||
|
self._b64url_json(header),
|
||||||
|
self._b64url_json(payload),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
signature = hmac.new(
|
||||||
|
self.config.secret.encode("utf-8"),
|
||||||
|
signing_input.encode("ascii"),
|
||||||
|
hashlib.sha256,
|
||||||
|
).digest()
|
||||||
|
return f"{signing_input}.{self._b64url_encode(signature)}"
|
||||||
|
|
||||||
|
def _parse(self, raw_token: str) -> JwtToken:
|
||||||
|
try:
|
||||||
|
header_b64, payload_b64, signature_b64 = raw_token.split(".", 2)
|
||||||
|
except ValueError as exc:
|
||||||
|
raise JwtInvalidError("Malformed token") from exc
|
||||||
|
|
||||||
|
signing_input = f"{header_b64}.{payload_b64}".encode("ascii")
|
||||||
|
expected = hmac.new(
|
||||||
|
self.config.secret.encode("utf-8"),
|
||||||
|
signing_input,
|
||||||
|
hashlib.sha256,
|
||||||
|
).digest()
|
||||||
|
if not hmac.compare_digest(self._b64url_encode(expected), signature_b64):
|
||||||
|
raise JwtInvalidError("Invalid signature")
|
||||||
|
|
||||||
|
try:
|
||||||
|
header = json.loads(self._b64url_decode(header_b64))
|
||||||
|
payload = json.loads(self._b64url_decode(payload_b64))
|
||||||
|
except (json.JSONDecodeError, UnicodeDecodeError) as exc:
|
||||||
|
raise JwtInvalidError("Invalid token payload") from exc
|
||||||
|
|
||||||
|
if header.get("alg") != "HS256":
|
||||||
|
raise JwtInvalidError("Unsupported algorithm")
|
||||||
|
if not isinstance(payload, dict):
|
||||||
|
raise JwtInvalidError("Invalid claims")
|
||||||
|
|
||||||
|
token = JwtToken(raw=raw_token, claims=payload)
|
||||||
|
self._validate_time(token)
|
||||||
|
return token
|
||||||
|
|
||||||
|
def _validate_time(self, token: JwtToken) -> None:
|
||||||
|
now = int(time.time())
|
||||||
|
exp = token.claims.get("exp")
|
||||||
|
nbf = token.claims.get("nbf")
|
||||||
|
iat = token.claims.get("iat")
|
||||||
|
|
||||||
|
if not isinstance(exp, int):
|
||||||
|
raise JwtInvalidError("Missing exp")
|
||||||
|
if exp <= now:
|
||||||
|
raise JwtExpiredError("Token expired")
|
||||||
|
if isinstance(nbf, int) and nbf > now:
|
||||||
|
raise JwtInvalidError("Token not active")
|
||||||
|
if isinstance(iat, int) and iat > now + 60:
|
||||||
|
raise JwtInvalidError("Token issued in future")
|
||||||
|
|
||||||
|
def _validate_issuer(self, token: JwtToken) -> None:
|
||||||
|
if token.claims.get("iss") != self.config.issuer:
|
||||||
|
raise JwtInvalidError("Invalid issuer")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _b64url_encode(raw: bytes) -> str:
|
||||||
|
return base64.urlsafe_b64encode(raw).decode("ascii").rstrip("=")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _b64url_json(cls, value: dict[str, Any]) -> str:
|
||||||
|
raw = json.dumps(value, separators=(",", ":"), sort_keys=True).encode("utf-8")
|
||||||
|
return cls._b64url_encode(raw)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _b64url_decode(value: str) -> str:
|
||||||
|
padding = "=" * (-len(value) % 4)
|
||||||
|
return base64.urlsafe_b64decode(value + padding).decode("utf-8")
|
||||||
23
app/lib/jwt/token.py
Normal file
23
app/lib/jwt/token.py
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True, slots=True)
|
||||||
|
class JwtToken:
|
||||||
|
raw: str
|
||||||
|
claims: dict[str, Any]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def jwt_id(self) -> str:
|
||||||
|
return str(self.claims.get("jti", ""))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def admin_id(self) -> int:
|
||||||
|
try:
|
||||||
|
return int(self.jwt_id)
|
||||||
|
except ValueError:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_refresh(self) -> bool:
|
||||||
|
return self.claims.get("sub") == "refresh" or self.claims.get("token_type") == "refresh"
|
||||||
1
app/lib/response/__init__.py
Normal file
1
app/lib/response/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Response helpers."""
|
||||||
6
app/lib/response/admin_return.py
Normal file
6
app/lib/response/admin_return.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from app.lib.response.common_return import CommonReturn
|
||||||
|
|
||||||
|
|
||||||
|
class AdminReturn(CommonReturn):
|
||||||
|
def after_success(self, response: dict) -> dict:
|
||||||
|
return response
|
||||||
34
app/lib/response/common_return.py
Normal file
34
app/lib/response/common_return.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
from app.constants.result_code import ResultCode
|
||||||
|
|
||||||
|
|
||||||
|
class CommonReturn:
|
||||||
|
def success(
|
||||||
|
self,
|
||||||
|
message: str = "success",
|
||||||
|
data: dict | list | None = None,
|
||||||
|
code: int | ResultCode = ResultCode.SUCCESS,
|
||||||
|
) -> dict:
|
||||||
|
return self.after_success(
|
||||||
|
{
|
||||||
|
"code": int(code),
|
||||||
|
"message": message,
|
||||||
|
"data": data if data is not None else {},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def error(
|
||||||
|
self,
|
||||||
|
message: str = "failed",
|
||||||
|
code: int | ResultCode = ResultCode.ERROR,
|
||||||
|
data: dict | None = None,
|
||||||
|
) -> dict:
|
||||||
|
return self.after_success(
|
||||||
|
{
|
||||||
|
"code": int(code),
|
||||||
|
"message": message,
|
||||||
|
"data": data or {},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def after_success(self, response: dict) -> dict:
|
||||||
|
raise NotImplementedError
|
||||||
44
app/main.py
Normal file
44
app/main.py
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import AsyncIterator
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from app.controller.admin.login_controller import router as admin_login_router
|
||||||
|
from app.controller.admin.profile_controller import router as admin_profile_router
|
||||||
|
from app.controller.api.health_controller import router as api_health_router
|
||||||
|
from app.core.config import get_settings
|
||||||
|
from app.core.dependencies import bootstrap_database
|
||||||
|
from app.exception.handler import register_exception_handlers
|
||||||
|
from app.lib.response.admin_return import AdminReturn
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI) -> AsyncIterator[None]:
|
||||||
|
await bootstrap_database()
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
settings = get_settings()
|
||||||
|
application = FastAPI(title=settings.app_name, lifespan=lifespan)
|
||||||
|
application.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=settings.cors_allow_origins,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
register_exception_handlers(application)
|
||||||
|
application.include_router(admin_login_router)
|
||||||
|
application.include_router(admin_profile_router)
|
||||||
|
application.include_router(api_health_router)
|
||||||
|
|
||||||
|
@application.get("/")
|
||||||
|
async def index() -> dict:
|
||||||
|
return AdminReturn().success("success", {"name": settings.app_name})
|
||||||
|
|
||||||
|
return application
|
||||||
|
|
||||||
|
|
||||||
|
app = create_app()
|
||||||
1
app/middleware/__init__.py
Normal file
1
app/middleware/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Middleware and route dependencies."""
|
||||||
1
app/middleware/admin/__init__.py
Normal file
1
app/middleware/admin/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Admin middleware."""
|
||||||
16
app/middleware/admin/admin_token_middleware.py
Normal file
16
app/middleware/admin/admin_token_middleware.py
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from app.common.context import current_admin_id
|
||||||
|
from app.lib.jwt.jwt import Jwt
|
||||||
|
from app.lib.jwt.token import JwtToken
|
||||||
|
from app.middleware.token.abstract_token_middleware import AbstractTokenMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
class AdminTokenMiddleware(AbstractTokenMiddleware):
|
||||||
|
async def parser_token(self, jwt: Jwt, raw_token: str) -> JwtToken:
|
||||||
|
return await jwt.parser_access_token(raw_token)
|
||||||
|
|
||||||
|
def set_context(self, request: Request, token: JwtToken) -> None:
|
||||||
|
admin_id = token.admin_id
|
||||||
|
current_admin_id.set(admin_id)
|
||||||
|
request.state.current_admin_id = admin_id
|
||||||
34
app/middleware/admin/permission_middleware.py
Normal file
34
app/middleware/admin/permission_middleware.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
from fastapi import Depends, Request
|
||||||
|
|
||||||
|
from app.common.context import current_admin_id
|
||||||
|
from app.common.repository.admin_user_repository import AdminUserRepository
|
||||||
|
from app.constants.admin_code import AdminCode
|
||||||
|
from app.constants.model.admin_user.admin_user_status_code import AdminUserStatusCode
|
||||||
|
from app.core.dependencies import get_admin_user_repository
|
||||||
|
from app.exception.err_exception import ErrException
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionMiddleware:
|
||||||
|
def __init__(self, permissions: Iterable[str] | None = None) -> None:
|
||||||
|
self.permissions = tuple(permissions or ())
|
||||||
|
|
||||||
|
async def __call__(
|
||||||
|
self,
|
||||||
|
request: Request,
|
||||||
|
user_repository: AdminUserRepository = Depends(get_admin_user_repository),
|
||||||
|
) -> None:
|
||||||
|
admin_id = getattr(request.state, "current_admin_id", current_admin_id.get())
|
||||||
|
if admin_id <= 0:
|
||||||
|
raise ErrException("账户不存在")
|
||||||
|
|
||||||
|
admin_user = await user_repository.find_by_id(admin_id)
|
||||||
|
if admin_user is None:
|
||||||
|
raise ErrException("账户不存在")
|
||||||
|
if admin_user.status == AdminUserStatusCode.DISABLE:
|
||||||
|
raise ErrException("账号已禁用", AdminCode.DISABLED)
|
||||||
|
|
||||||
|
request.state.current_admin_user = admin_user
|
||||||
|
if self.permissions and admin_user.user_type != "SuperAdmin":
|
||||||
|
raise ErrException("暂无权限", AdminCode.FORBIDDEN)
|
||||||
13
app/middleware/admin/refresh_admin_token_middleware.py
Normal file
13
app/middleware/admin/refresh_admin_token_middleware.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from app.lib.jwt.jwt import Jwt
|
||||||
|
from app.lib.jwt.token import JwtToken
|
||||||
|
from app.middleware.token.abstract_token_middleware import AbstractTokenMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshAdminTokenMiddleware(AbstractTokenMiddleware):
|
||||||
|
async def parser_token(self, jwt: Jwt, raw_token: str) -> JwtToken:
|
||||||
|
return await jwt.parser_refresh_token(raw_token)
|
||||||
|
|
||||||
|
def set_context(self, request: Request, token: JwtToken) -> None:
|
||||||
|
return None
|
||||||
1
app/middleware/token/__init__.py
Normal file
1
app/middleware/token/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Token middleware."""
|
||||||
72
app/middleware/token/abstract_token_middleware.py
Normal file
72
app/middleware/token/abstract_token_middleware.py
Normal file
@@ -0,0 +1,72 @@
|
|||||||
|
from fastapi import Depends, Request
|
||||||
|
|
||||||
|
from app.constants.result_code import ResultCode
|
||||||
|
from app.core.dependencies import get_token_service
|
||||||
|
from app.exception.err_exception import ErrException
|
||||||
|
from app.lib.jwt.exceptions import JwtError, JwtExpiredError
|
||||||
|
from app.lib.jwt.jwt import Jwt
|
||||||
|
from app.lib.jwt.token import JwtToken
|
||||||
|
from app.service.base_token_service import BaseTokenService
|
||||||
|
|
||||||
|
|
||||||
|
class AbstractTokenMiddleware:
|
||||||
|
scene = "admin"
|
||||||
|
|
||||||
|
async def __call__(
|
||||||
|
self,
|
||||||
|
request: Request,
|
||||||
|
token_service: BaseTokenService = Depends(get_token_service),
|
||||||
|
) -> JwtToken:
|
||||||
|
raw_token = self.get_token(request)
|
||||||
|
jwt = token_service.get_jwt(self.scene)
|
||||||
|
|
||||||
|
try:
|
||||||
|
token = await self.parser_token(jwt, raw_token)
|
||||||
|
await token_service.check_jwt(jwt, token)
|
||||||
|
self.check_issuer(jwt, token)
|
||||||
|
except JwtExpiredError as exc:
|
||||||
|
raise ErrException(
|
||||||
|
"token过期",
|
||||||
|
ResultCode.JWT_EXPIRED,
|
||||||
|
{"err_msg": str(exc)},
|
||||||
|
) from exc
|
||||||
|
except JwtError as exc:
|
||||||
|
raise ErrException(
|
||||||
|
"token错误",
|
||||||
|
ResultCode.JWT_ERROR,
|
||||||
|
{"err_msg": str(exc)},
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
self.set_context(request, token)
|
||||||
|
request.state.token = token
|
||||||
|
return token
|
||||||
|
|
||||||
|
async def parser_token(self, jwt: Jwt, raw_token: str) -> JwtToken:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def set_context(self, request: Request, token: JwtToken) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_issuer(jwt: Jwt, token: JwtToken) -> None:
|
||||||
|
if token.claims.get("iss") != jwt.issuer:
|
||||||
|
raise ErrException("token错误", ResultCode.JWT_ERROR)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_token(request: Request) -> str:
|
||||||
|
authorization = request.headers.get("authorization")
|
||||||
|
if authorization:
|
||||||
|
scheme, _, token = authorization.partition(" ")
|
||||||
|
if scheme.lower() == "bearer" and token:
|
||||||
|
return token.strip()
|
||||||
|
return authorization.replace("Bearer ", "", 1).strip()
|
||||||
|
|
||||||
|
token_header = request.headers.get("token")
|
||||||
|
if token_header:
|
||||||
|
return token_header.strip()
|
||||||
|
|
||||||
|
token_query = request.query_params.get("token")
|
||||||
|
if token_query:
|
||||||
|
return token_query.strip()
|
||||||
|
|
||||||
|
raise ErrException("token缺失", ResultCode.JWT_ERROR)
|
||||||
1
app/model/__init__.py
Normal file
1
app/model/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Model layer."""
|
||||||
53
app/model/admin_user.py
Normal file
53
app/model/admin_user.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.common.security.password_hasher import verify_password
|
||||||
|
from app.constants.model.admin_user.admin_user_status_code import AdminUserStatusCode
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(slots=True)
|
||||||
|
class AdminUser:
|
||||||
|
id: int
|
||||||
|
username: str
|
||||||
|
password: str
|
||||||
|
user_type: str
|
||||||
|
nickname: str
|
||||||
|
phone: str
|
||||||
|
email: str
|
||||||
|
status: AdminUserStatusCode
|
||||||
|
login_ip: str
|
||||||
|
login_time: str
|
||||||
|
remark: str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_row(cls, row: dict[str, Any]) -> "AdminUser":
|
||||||
|
return cls(
|
||||||
|
id=int(row["id"]),
|
||||||
|
username=str(row["username"]),
|
||||||
|
password=str(row["password"]),
|
||||||
|
user_type=str(row["user_type"]),
|
||||||
|
nickname=str(row["nickname"]),
|
||||||
|
phone=str(row["phone"]),
|
||||||
|
email=str(row["email"]),
|
||||||
|
status=AdminUserStatusCode(int(row["status"])),
|
||||||
|
login_ip=str(row["login_ip"]),
|
||||||
|
login_time=str(row["login_time"]),
|
||||||
|
remark=str(row["remark"]),
|
||||||
|
)
|
||||||
|
|
||||||
|
def verify_password(self, password: str) -> bool:
|
||||||
|
return verify_password(password, self.password)
|
||||||
|
|
||||||
|
def to_public_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"username": self.username,
|
||||||
|
"user_type": self.user_type,
|
||||||
|
"nickname": self.nickname,
|
||||||
|
"phone": self.phone,
|
||||||
|
"email": self.email,
|
||||||
|
"status": int(self.status),
|
||||||
|
"login_ip": self.login_ip,
|
||||||
|
"login_time": self.login_time,
|
||||||
|
"remark": self.remark,
|
||||||
|
}
|
||||||
1
app/request/__init__.py
Normal file
1
app/request/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Request schemas."""
|
||||||
1
app/request/admin/__init__.py
Normal file
1
app/request/admin/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Admin request schemas."""
|
||||||
6
app/request/admin/login_request.py
Normal file
6
app/request/admin/login_request.py
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class LoginRequest(BaseModel):
|
||||||
|
username: str = Field(min_length=1)
|
||||||
|
password: str = Field(min_length=1)
|
||||||
1
app/service/__init__.py
Normal file
1
app/service/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Service layer."""
|
||||||
1
app/service/admin/__init__.py
Normal file
1
app/service/admin/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Admin services."""
|
||||||
15
app/service/admin/base_admin_service.py
Normal file
15
app/service/admin/base_admin_service.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
from app.common.context import current_admin_id
|
||||||
|
from app.exception.err_exception import ErrException
|
||||||
|
from app.lib.response.admin_return import AdminReturn
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAdminService:
|
||||||
|
def __init__(self, admin_return: AdminReturn) -> None:
|
||||||
|
self.admin_return = admin_return
|
||||||
|
|
||||||
|
@property
|
||||||
|
def admin_id(self) -> int:
|
||||||
|
admin_id = current_admin_id.get()
|
||||||
|
if admin_id <= 0:
|
||||||
|
raise ErrException("账户不存在")
|
||||||
|
return admin_id
|
||||||
1
app/service/admin/login/__init__.py
Normal file
1
app/service/admin/login/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Admin login services."""
|
||||||
56
app/service/admin/login/login_service.py
Normal file
56
app/service/admin/login/login_service.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
import asyncio
|
||||||
|
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
from app.common.repository.admin_user_repository import AdminUserRepository
|
||||||
|
from app.constants.model.admin_user.admin_user_status_code import AdminUserStatusCode
|
||||||
|
from app.exception.err_exception import ErrException
|
||||||
|
from app.lib.response.admin_return import AdminReturn
|
||||||
|
from app.request.admin.login_request import LoginRequest
|
||||||
|
from app.service.admin.base_admin_service import BaseAdminService
|
||||||
|
from app.service.base_token_service import BaseTokenService
|
||||||
|
|
||||||
|
|
||||||
|
class LoginService(BaseAdminService):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
user_repository: AdminUserRepository,
|
||||||
|
token_service: BaseTokenService,
|
||||||
|
admin_return: AdminReturn,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(admin_return)
|
||||||
|
self.user_repository = user_repository
|
||||||
|
self.token_service = token_service
|
||||||
|
|
||||||
|
async def handle(self, payload: LoginRequest, request: Request) -> dict:
|
||||||
|
admin_info = await self.user_repository.find_by_username(payload.username)
|
||||||
|
if admin_info is None:
|
||||||
|
raise ErrException("后台管理员不存在")
|
||||||
|
|
||||||
|
password_valid = await asyncio.to_thread(
|
||||||
|
admin_info.verify_password,
|
||||||
|
payload.password,
|
||||||
|
)
|
||||||
|
if not password_valid:
|
||||||
|
raise ErrException("密码错误")
|
||||||
|
|
||||||
|
if admin_info.status == AdminUserStatusCode.DISABLE:
|
||||||
|
raise ErrException("用户已禁用")
|
||||||
|
|
||||||
|
await self.user_repository.record_login(admin_info.id, self._client_ip(request))
|
||||||
|
jwt = self.token_service.get_jwt("admin")
|
||||||
|
return self.admin_return.success(
|
||||||
|
"success",
|
||||||
|
{
|
||||||
|
"access_token": jwt.builder_access_token(str(admin_info.id)),
|
||||||
|
"refresh_token": jwt.builder_refresh_token(str(admin_info.id)),
|
||||||
|
"expire_at": int(jwt.get_config("ttl", 0)),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _client_ip(request: Request) -> str:
|
||||||
|
forwarded = request.headers.get("x-forwarded-for")
|
||||||
|
if forwarded:
|
||||||
|
return forwarded.split(",", 1)[0].strip()
|
||||||
|
return request.client.host if request.client else ""
|
||||||
22
app/service/admin/login/refresh_service.py
Normal file
22
app/service/admin/login/refresh_service.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
from app.lib.jwt.token import JwtToken
|
||||||
|
from app.lib.response.admin_return import AdminReturn
|
||||||
|
from app.service.admin.base_admin_service import BaseAdminService
|
||||||
|
from app.service.base_token_service import BaseTokenService
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshService(BaseAdminService):
|
||||||
|
def __init__(self, token_service: BaseTokenService, admin_return: AdminReturn) -> None:
|
||||||
|
super().__init__(admin_return)
|
||||||
|
self.token_service = token_service
|
||||||
|
|
||||||
|
async def handle(self, token: JwtToken) -> dict:
|
||||||
|
return self.admin_return.success("success", await self.refresh_token(token))
|
||||||
|
|
||||||
|
async def refresh_token(self, token: JwtToken) -> dict[str, int | str]:
|
||||||
|
jwt = self.token_service.get_jwt("admin")
|
||||||
|
await jwt.add_blacklist(token)
|
||||||
|
return {
|
||||||
|
"access_token": jwt.builder_access_token(token.jwt_id),
|
||||||
|
"refresh_token": jwt.builder_refresh_token(token.jwt_id),
|
||||||
|
"expire_at": int(jwt.get_config("ttl", 0)),
|
||||||
|
}
|
||||||
1
app/service/admin/profile/__init__.py
Normal file
1
app/service/admin/profile/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Admin profile services."""
|
||||||
20
app/service/admin/profile/current_user_service.py
Normal file
20
app/service/admin/profile/current_user_service.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from app.common.repository.admin_user_repository import AdminUserRepository
|
||||||
|
from app.exception.err_exception import ErrException
|
||||||
|
from app.lib.response.admin_return import AdminReturn
|
||||||
|
from app.service.admin.base_admin_service import BaseAdminService
|
||||||
|
|
||||||
|
|
||||||
|
class CurrentUserService(BaseAdminService):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
user_repository: AdminUserRepository,
|
||||||
|
admin_return: AdminReturn,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(admin_return)
|
||||||
|
self.user_repository = user_repository
|
||||||
|
|
||||||
|
async def handle(self) -> dict:
|
||||||
|
admin_user = await self.user_repository.find_by_id(self.admin_id)
|
||||||
|
if admin_user is None:
|
||||||
|
raise ErrException("账户不存在")
|
||||||
|
return self.admin_return.success("success", admin_user.to_public_dict())
|
||||||
17
app/service/base_token_service.py
Normal file
17
app/service/base_token_service.py
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
from app.constants.result_code import ResultCode
|
||||||
|
from app.exception.err_exception import ErrException
|
||||||
|
from app.lib.jwt.factory import JwtFactory
|
||||||
|
from app.lib.jwt.jwt import Jwt
|
||||||
|
from app.lib.jwt.token import JwtToken
|
||||||
|
|
||||||
|
|
||||||
|
class BaseTokenService:
|
||||||
|
def __init__(self, jwt_factory: JwtFactory) -> None:
|
||||||
|
self.jwt_factory = jwt_factory
|
||||||
|
|
||||||
|
def get_jwt(self, scene: str = "admin") -> Jwt:
|
||||||
|
return self.jwt_factory.get(scene)
|
||||||
|
|
||||||
|
async def check_jwt(self, jwt: Jwt, token: JwtToken) -> None:
|
||||||
|
if await jwt.has_blacklist(token):
|
||||||
|
raise ErrException("token已过期", ResultCode.JWT_EXPIRED)
|
||||||
14
pyproject.toml
Normal file
14
pyproject.toml
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
[project]
|
||||||
|
name = "py-server"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Async API server scaffold inspired by the Hyperf app layering."
|
||||||
|
requires-python = ">=3.14"
|
||||||
|
dependencies = [
|
||||||
|
"fastapi>=0.136.0",
|
||||||
|
"pydantic-settings>=2.14.0",
|
||||||
|
"uvicorn[standard]>=0.49.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.pyright]
|
||||||
|
venvPath = "."
|
||||||
|
venv = ".venv"
|
||||||
1
tests/__init__.py
Normal file
1
tests/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Test package."""
|
||||||
88
tests/test_admin_login_flow.py
Normal file
88
tests/test_admin_login_flow.py
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
DB_PATH = Path(tempfile.gettempdir()) / "py_server_admin_login_test.db"
|
||||||
|
DB_PATH.unlink(missing_ok=True)
|
||||||
|
os.environ["DATABASE_PATH"] = str(DB_PATH)
|
||||||
|
os.environ["JWT_ADMIN_SECRET"] = "test_admin_secret"
|
||||||
|
os.environ["ADMIN_SEED_USERNAME"] = "admin"
|
||||||
|
os.environ["ADMIN_SEED_PASSWORD"] = "admin"
|
||||||
|
|
||||||
|
from httpx import ASGITransport, AsyncClient
|
||||||
|
|
||||||
|
from app.core.dependencies import bootstrap_database
|
||||||
|
from app.main import app
|
||||||
|
|
||||||
|
|
||||||
|
class AdminLoginFlowTest(unittest.IsolatedAsyncioTestCase):
|
||||||
|
async def asyncSetUp(self) -> None:
|
||||||
|
await bootstrap_database()
|
||||||
|
self.client = AsyncClient(
|
||||||
|
transport=ASGITransport(app=app),
|
||||||
|
base_url="http://testserver",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def asyncTearDown(self) -> None:
|
||||||
|
await self.client.aclose()
|
||||||
|
|
||||||
|
async def test_login_access_and_refresh_flow(self) -> None:
|
||||||
|
login_response = await self.client.post(
|
||||||
|
"/admin/login/login",
|
||||||
|
json={"username": "admin", "password": "admin"},
|
||||||
|
)
|
||||||
|
login_payload = login_response.json()
|
||||||
|
|
||||||
|
self.assertEqual(login_payload["code"], 0)
|
||||||
|
self.assertIn("access_token", login_payload["data"])
|
||||||
|
self.assertIn("refresh_token", login_payload["data"])
|
||||||
|
self.assertEqual(login_payload["data"]["expire_at"], 3600)
|
||||||
|
|
||||||
|
access_token = login_payload["data"]["access_token"]
|
||||||
|
refresh_token = login_payload["data"]["refresh_token"]
|
||||||
|
|
||||||
|
current_response = await self.client.get(
|
||||||
|
"/admin/profile/current",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
current_payload = current_response.json()
|
||||||
|
|
||||||
|
self.assertEqual(current_payload["code"], 0)
|
||||||
|
self.assertEqual(current_payload["data"]["username"], "admin")
|
||||||
|
|
||||||
|
refresh_response = await self.client.post(
|
||||||
|
"/admin/login/refresh",
|
||||||
|
headers={"Authorization": f"Bearer {refresh_token}"},
|
||||||
|
)
|
||||||
|
refresh_payload = refresh_response.json()
|
||||||
|
|
||||||
|
self.assertEqual(refresh_payload["code"], 0)
|
||||||
|
self.assertNotEqual(refresh_payload["data"]["refresh_token"], refresh_token)
|
||||||
|
|
||||||
|
reused_response = await self.client.post(
|
||||||
|
"/admin/login/refresh",
|
||||||
|
headers={"Authorization": f"Bearer {refresh_token}"},
|
||||||
|
)
|
||||||
|
reused_payload = reused_response.json()
|
||||||
|
|
||||||
|
self.assertEqual(reused_payload["code"], 10001)
|
||||||
|
|
||||||
|
async def test_refresh_endpoint_rejects_access_token(self) -> None:
|
||||||
|
login_response = await self.client.post(
|
||||||
|
"/admin/login/login",
|
||||||
|
json={"username": "admin", "password": "admin"},
|
||||||
|
)
|
||||||
|
access_token = login_response.json()["data"]["access_token"]
|
||||||
|
|
||||||
|
refresh_response = await self.client.post(
|
||||||
|
"/admin/login/refresh",
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
refresh_payload = refresh_response.json()
|
||||||
|
|
||||||
|
self.assertEqual(refresh_payload["code"], 10002)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user