#!/usr/bin/env python3
"""
完整的SSO服务器 - 包含认证API
"""
import sys
import os
import socket
import json
import uuid
# Put this file's `src` directory first on the import path so `app.*` modules resolve.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
# Load environment variables from a .env file.
from dotenv import load_dotenv
# Load the config file (intended location: src/app/config/.env).
# NOTE(review): this path first goes one directory UP from this file ("..") before
# "src", while the sys.path entry above uses "<this dir>/src". The two locations
# disagree — confirm where the .env file actually lives.
env_path = os.path.join(os.path.dirname(__file__), "..", "src", "app", "config", ".env")
load_dotenv(dotenv_path=env_path)
from fastapi import FastAPI, HTTPException, Depends, Request, Response, BackgroundTasks
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel
from typing import Optional, Any, Union
import hashlib
import secrets
import requests
from urllib.parse import urlparse
# Maps a Content-Type (MIME type without parameters, lower-case) to a file suffix.
MIME_MAP = {
    'application/pdf': '.pdf',
    'application/vnd.openxmlformats-officedocument.wordprocessingml.document': '.docx',
    'application/msword': '.doc',
    'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet': '.xlsx',
    'application/vnd.ms-excel': '.xls',
    'application/vnd.openxmlformats-officedocument.presentationml.presentation': '.pptx',
    'application/vnd.ms-powerpoint': '.ppt',
    'text/markdown': '.md',
    'text/plain': '.txt',
    'text/html': '.html',
    'image/jpeg': '.jpg',
    'image/png': '.png',
    'application/zip': '.zip',
}


def detect_file_extension(url: str, timeout: float = 5) -> str:
    """Guess a file's extension from its URL, falling back to a HEAD request.

    Strategy:
      1. If the URL path ends in an extension of at most 6 characters
         (including the dot), return it lower-cased.
      2. Otherwise, for http/https URLs only, send a HEAD request (following
         redirects). If the response is successful, map its Content-Type
         through ``MIME_MAP``.

    Args:
        url: The file URL. Empty or None yields "".
        timeout: Seconds to wait for the HEAD request (default 5).

    Returns:
        The extension including the leading dot (e.g. ".pdf"), or "" when it
        cannot be determined.

    NOTE(review): the server fetches a caller-supplied URL here. If ``url``
    comes from end users, this is an SSRF vector; consider restricting hosts.
    """
    if not url:
        return ""
    parsed = urlparse(url)
    # 1. Extension from the path component (query string and fragment excluded).
    ext = os.path.splitext(parsed.path)[1].lower()
    if ext and len(ext) <= 6:
        return ext
    # 2. Only http(s) URLs can be probed; `requests` rejects other schemes anyway.
    if parsed.scheme.lower() not in ("http", "https"):
        return ""
    try:
        response = requests.head(url, allow_redirects=True, timeout=timeout)
    except Exception as e:
        print(f"检测文件后缀失败: {e}")
        return ""
    # An error status (404, 403, ...) describes an error page, not the file itself.
    if not response.ok:
        return ""
    # MIME types are case-insensitive; drop parameters such as "; charset=utf-8".
    content_type = response.headers.get('Content-Type', '').split(';')[0].strip().lower()
    return MIME_MAP.get(content_type, "")
# JWT import fix: make sure a usable JWT library ends up bound to the name `jwt`.
# Resolution order: PyJWT -> python-jose (`jose.jwt`) -> pip-install PyJWT at import time.
# NOTE(review): the two libraries raise different exception classes from `decode`
# (PyJWT: PyJWTError/InvalidTokenError, python-jose: JWTError). Callers must not
# assume one of them.
try:
    # First try PyJWT.
    import jwt as pyjwt
    # Smoke-test `encode`: a different package that also installs a `jwt` module
    # would fail here with AttributeError/TypeError. The token itself is unused.
    test_token = pyjwt.encode({"test": "data"}, "secret", algorithm="HS256")
    jwt = pyjwt
    print("✅ 使用PyJWT库")
except (ImportError, AttributeError, TypeError) as e:
    print(f"PyJWT导入失败: {e}")
    try:
        # Fall back to python-jose; its `jose.jwt` module provides encode/decode.
        from jose import jwt
        print("✅ 使用python-jose库")
    except ImportError as e:
        print(f"python-jose导入失败: {e}")
        # Last resort: try installing PyJWT.
        # NOTE(review): running `pip install` from a server process at import time is
        # an operational and supply-chain risk; prefer declaring PyJWT as a dependency.
        print("尝试安装PyJWT...")
        import subprocess
        import sys
        try:
            subprocess.check_call([sys.executable, "-m", "pip", "install", "PyJWT"])
            import jwt
            print("✅ PyJWT安装成功")
        except Exception as install_error:
            print(f"❌ PyJWT安装失败: {install_error}")
            raise ImportError("无法导入JWT库,请手动安装: pip install PyJWT")
from datetime import datetime, timedelta, timezone, date
import pymysql
from urllib.parse import urlparse
# 导入RBAC API - 移除循环导入
# from rbac_api import get_user_menus, get_all_menus, get_all_roles, get_user_permissions
# 数据模型
class LoginRequest(BaseModel):
    """Request body for POST /api/v1/auth/login."""
    # Matched against either users.username or users.email by the login handler.
    username: str
    # Plain-text password, checked against the stored password hash.
    password: str
    # Accepted in the body but not read by the `login` handler.
    remember_me: bool = False
class TokenResponse(BaseModel):
    """Token payload placed in ApiResponse.data by the login endpoint."""
    access_token: str
    # Optional; the login handler does not set it.
    refresh_token: Optional[str] = None
    token_type: str = "Bearer"
    # Lifetime of access_token in seconds (login passes ACCESS_TOKEN_EXPIRE_MINUTES * 60).
    expires_in: int
    # Space-separated scopes, e.g. "profile email".
    scope: Optional[str] = None
class UserInfo(BaseModel):
    """Description of a user account.

    NOTE(review): the handlers visible in this file build plain dicts instead of
    this model; it may be used elsewhere — confirm before changing fields.
    """
    id: str
    username: str
    email: str
    phone: Optional[str] = None
    avatar_url: Optional[str] = None
    is_active: bool
    is_superuser: bool = False
    # List defaults are safe on pydantic models: each instance gets its own copy.
    roles: list = []
    permissions: list = []
class ApiResponse(BaseModel):
    """Uniform response envelope returned (via model_dump) by the REST endpoints.

    ``code`` is 0 on success. Handlers in this file also use 100001 (missing
    parameters), 200001/200002 (user/credential/token errors) and 500001
    (database or server error).
    """
    code: int
    message: str
    data: Optional[Any] = None
    # ISO-8601 string; handlers pass datetime.now(timezone.utc).isoformat().
    timestamp: str
# --- Document management / document center configuration ---
# Maps the API-level document category ("table_type") to its MySQL table.
TABLE_MAP = dict(
    basis="t_basis_of_preparation",  # preparation basis documents (编制依据)
    work="t_work_of_preparation",    # construction plans (施工方案)
    job="t_job_of_preparation",      # office policies (办公制度)
)
def _parse_database_url(database_url: str) -> dict:
    """Convert a DATABASE_URL such as ``mysql+pymysql://user:pass@host:3306/db``
    into keyword arguments for ``pymysql.connect``.

    Percent-encoded user, password and database components are decoded. This
    lets passwords containing reserved characters (``@``, ``:``, ``/``, ...)
    be written URL-encoded. Missing parts fall back to localhost:3306, user
    ``root``, an empty password and database ``sso_db``. An empty path
    (including a bare "/") also falls back to ``sso_db``.
    """
    from urllib.parse import unquote

    parsed = urlparse(database_url)
    database = unquote(parsed.path[1:]) if parsed.path else ''
    return {
        'host': parsed.hostname or 'localhost',
        'port': parsed.port or 3306,
        'user': unquote(parsed.username) if parsed.username else 'root',
        'password': unquote(parsed.password) if parsed.password else '',
        'database': database or 'sso_db',
        'charset': 'utf8mb4',
    }


def get_db_connection():
    """Open a new pymysql connection from the ``admin_app.DATABASE_URL`` setting.

    Returns:
        A pymysql connection (the caller must close it), or ``None`` when the
        URL is not configured or connecting fails. Failures are printed, not
        raised.
    """
    try:
        # Imported lazily, like the other configuration lookups in this module.
        from app.core.config import config_handler
        database_url = config_handler.get("admin_app", "DATABASE_URL", "")
        if not database_url:
            return None
        return pymysql.connect(**_parse_database_url(database_url))
    except Exception as e:
        print(f"数据库连接失败: {e}")
        return None
# --- Initialise the master table ---
def init_master_table():
    """Create the ``t_document_main`` master table if it does not exist yet.

    Only ``CREATE TABLE IF NOT EXISTS`` is executed, so an existing table is
    left untouched. Columns added to this DDL later are NOT added to an
    already-existing table; the previous docstring claimed otherwise. Returns
    silently when no database connection is available. Errors are printed
    rather than raised, because this function runs at import time.
    """
    conn = get_db_connection()
    if not conn:
        return
    try:
        # Use the cursor as a context manager so it is always closed.
        with conn.cursor() as cursor:
            # 1. Create the master table (if it does not exist).
            cursor.execute("""
            CREATE TABLE IF NOT EXISTS t_document_main (
                id CHAR(36) PRIMARY KEY,
                title VARCHAR(255) NOT NULL,
                standard_no VARCHAR(100),
                issuing_authority VARCHAR(255),
                release_date DATE,
                document_type VARCHAR(100),
                professional_field VARCHAR(100),
                validity VARCHAR(50) DEFAULT '现行',
                created_by VARCHAR(100),
                created_time DATETIME DEFAULT CURRENT_TIMESTAMP,
                updated_time DATETIME DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                conversion_status TINYINT DEFAULT 0, -- 0:待转化, 1:转化中, 2:已完成, 3:失败
                conversion_progress INT DEFAULT 0,
                converted_file_name VARCHAR(255),
                conversion_error TEXT,
                whether_to_enter TINYINT DEFAULT 0, -- 0:未入库, 1:已入库
                source_type ENUM('basis', 'work', 'job') NOT NULL,
                source_id CHAR(36) NOT NULL,
                file_url TEXT,
                file_extension VARCHAR(10),
                content TEXT,
                primary_category_id INT,
                secondary_category_id INT,
                year INT
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;
            """)
        conn.commit()
        print("✅ 主表 t_document_main 初始化成功")
    except Exception as e:
        print(f"❌ 初始化主表失败: {e}")
    finally:
        conn.close()


# Run the initialisation when the module is imported.
init_master_table()
def get_table_name(table_type: Optional[str]) -> str:
    """Return the MySQL table that stores documents of ``table_type``.

    Unknown or missing types (including None) fall back to the preparation-basis
    table ``t_basis_of_preparation``.
    """
    try:
        return TABLE_MAP[table_type]
    except KeyError:
        return "t_basis_of_preparation"
class DocumentAdd(BaseModel):
    """Request body for adding or editing a document-center entry.

    NOTE(review): the handlers that consume this model are not in this chunk;
    the field notes below come from the field names and the t_document_main DDL.
    """
    title: str
    content: str
    primary_category_id: Optional[Any] = None
    secondary_category_id: Optional[Any] = None
    year: Optional[int] = None
    # Document category; presumably a key of TABLE_MAP ("basis", "work", "job").
    table_type: Optional[str] = "basis"
    # Fields needed when editing an existing document.
    id: Optional[str] = None
    source_id: Optional[str] = None
    # Extended fields (attributes specific to the sub-tables).
    standard_no: Optional[str] = None
    issuing_authority: Optional[str] = None
    release_date: Optional[str] = None
    document_type: Optional[str] = None
    professional_field: Optional[str] = None
    validity: Optional[str] = None
    project_name: Optional[str] = None
    project_section: Optional[str] = None
    # File-related fields.
    file_url: Optional[str] = None
    file_extension: Optional[str] = None
class DocumentListRequest(BaseModel):
    """Paging and filter parameters for listing documents.

    NOTE(review): the consuming handler is not in this chunk; semantics below
    are assumptions to confirm.
    """
    # Page number (presumably 1-based) and page size.
    page: int = 1
    size: int = 50
    keyword: Optional[str] = None
    # Optional category filter; presumably a key of TABLE_MAP.
    table_type: Optional[str] = None
    # Presumably filters t_document_main.whether_to_enter (0 = not stored, 1 = stored).
    whether_to_enter: Optional[int] = None
# Import configuration.
from app.core.config import config_handler
# Settings.
# NOTE(review): when admin_app.JWT_SECRET_KEY is not configured, this falls back to a
# hard-coded secret visible in the source, so anyone could forge tokens. Consider
# failing fast instead in production.
JWT_SECRET_KEY = config_handler.get("admin_app", "JWT_SECRET_KEY", "dev-jwt-secret-key-12345")
# Access-token lifetime in minutes (default 30).
ACCESS_TOKEN_EXPIRE_MINUTES = config_handler.get_int("admin_app", "ACCESS_TOKEN_EXPIRE_MINUTES", 30)
def check_port(port):
    """Return True when a TCP socket can bind to ``('localhost', port)``, else False."""
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(('localhost', port))
    except OSError:
        return False
    else:
        return True
    finally:
        # Release the port immediately; this is only a probe.
        probe.close()
def find_available_port(start_port=8000, max_port=8010):
    """Return the first port in the inclusive range [start_port, max_port] that
    ``check_port`` reports as free, or None if every port is taken."""
    candidates = range(start_port, max_port + 1)
    return next((candidate for candidate in candidates if check_port(candidate)), None)
def verify_password_simple(password: str, stored_hash: str) -> bool:
    """Check ``password`` against a hash of the form ``sha256$<salt>$<hexdigest>``.

    The digest is ``sha256(password + salt)``, the format produced by
    ``hash_password_simple``. Any other format, a malformed hash, or a
    non-string argument (e.g. a NULL hash column) yields False instead of
    raising.

    ``hmac.compare_digest`` is used so the comparison time does not reveal how
    many leading characters of the digest matched.
    """
    import hmac

    if not isinstance(stored_hash, str) or not isinstance(password, str):
        return False
    if not stored_hash.startswith("sha256$"):
        return False
    parts = stored_hash.split("$")
    if len(parts) != 3:
        return False
    _, salt, expected_hash = parts
    actual_hash = hashlib.sha256((password + salt).encode()).hexdigest()
    # Compare bytes: compare_digest rejects str arguments containing non-ASCII.
    return hmac.compare_digest(actual_hash.encode(), expected_hash.encode())
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None) -> str:
    """Sign ``data`` as an HS256 JWT with JWT_SECRET_KEY, adding ``iat``/``exp``.

    Args:
        data: Claims to embed. The dict is copied; the caller's dict is not modified.
        expires_delta: Token lifetime. Defaults to ACCESS_TOKEN_EXPIRE_MINUTES.

    Returns:
        The encoded token as ``str``. PyJWT < 2 returns ``bytes`` from
        ``encode``; that value is decoded so the declared return type holds
        with every library the import fallback may bind to ``jwt``.
    """
    if expires_delta is None:
        expires_delta = timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
    # Read the clock once so that exp - iat is exactly the configured lifetime.
    issued_at = datetime.now(timezone.utc)
    to_encode = data.copy()
    to_encode.update({"exp": issued_at + expires_delta, "iat": issued_at})
    encoded_jwt = jwt.encode(to_encode, JWT_SECRET_KEY, algorithm="HS256")
    if isinstance(encoded_jwt, bytes):
        encoded_jwt = encoded_jwt.decode("utf-8")
    return encoded_jwt
def _jwt_decode_errors() -> tuple:
    """Return the exception classes the bound ``jwt`` library raises for bad tokens.

    PyJWT exposes ``PyJWTError``/``InvalidTokenError``; python-jose's
    ``jose.jwt`` module exposes ``JWTError``. If none is found, fall back to
    ``Exception`` so an unreadable token is still treated as invalid.
    """
    candidates = (
        getattr(jwt, "PyJWTError", None),
        getattr(jwt, "InvalidTokenError", None),
        getattr(jwt, "JWTError", None),
    )
    found = tuple(
        cls for cls in candidates
        if isinstance(cls, type) and issubclass(cls, Exception)
    )
    return found or (Exception,)


def verify_token(token: str) -> Optional[dict]:
    """Decode and validate an HS256 token signed with JWT_SECRET_KEY.

    Returns:
        The claims dict. Returns None when the token is malformed, has a bad
        signature, or is expired.

    The previous version caught ``jose.PyJWTError``. ``jose`` is never bound
    in this module, and python-jose names its error ``JWTError``. As a result,
    every invalid token raised NameError instead of returning None.
    """
    try:
        payload = jwt.decode(token, JWT_SECRET_KEY, algorithms=["HS256"])
    except _jwt_decode_errors():
        return None
    return payload
# Create the FastAPI application (Swagger UI at /docs, ReDoc at /redoc).
app = FastAPI(
    title="SSO认证中心",
    version="1.0.0",
    description="OAuth2单点登录认证中心",
    docs_url="/docs",
    redoc_url="/redoc"
)
# Configure CORS.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True is the most
# permissive setup possible; consider an explicit allow-list of trusted front-end origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Bearer-token extractors used as endpoint dependencies. `security` makes the
# Authorization header mandatory; `security_optional` (auto_error=False) lets the
# endpoint run without it.
security = HTTPBearer()
security_optional = HTTPBearer(auto_error=False)
@app.get("/")
async def root():
    """Service banner: name, version and the location of the API docs."""
    banner = {
        "name": "SSO认证中心",
        "version": "1.0.0",
        "docs": "/docs",
    }
    envelope = ApiResponse(
        code=0,
        message="欢迎使用SSO认证中心",
        data=banner,
        timestamp=datetime.now(timezone.utc).isoformat(),
    )
    return envelope.model_dump()
@app.get("/health")
async def health_check():
    """Liveness probe: reports healthy whenever the process can answer."""
    status_payload = {
        "status": "healthy",
        "version": "1.0.0",
        "timestamp": datetime.now(timezone.utc).isoformat(),
    }
    envelope = ApiResponse(
        code=0,
        message="服务正常运行",
        data=status_payload,
        timestamp=datetime.now(timezone.utc).isoformat(),
    )
    return envelope.model_dump()
@app.post("/api/v1/auth/login")
async def login(request: Request, login_data: LoginRequest):
    """Authenticate a user by username or email plus password.

    Returns an ApiResponse dict with one of these codes:
      * 0: success; a TokenResponse is in ``data``.
      * 200001: unknown user or wrong password.
      * 200002: disabled account (reported only after a correct password).
      * 500001: database or server error.

    Security changes in this version:
      * The password is verified BEFORE the account status is reported.
        Previously the "disabled" message revealed that an account existed
        without knowing its password.
      * The password-hash prefix and the token prefix are no longer logged.
      * A NULL password hash is treated as a failed login instead of crashing.
    """
    def _reply(code: int, message: str, data=None) -> dict:
        return ApiResponse(
            code=code,
            message=message,
            data=data,
            timestamp=datetime.now(timezone.utc).isoformat(),
        ).model_dump()

    print(f"🔐 收到登录请求: username={login_data.username}")
    conn = None
    cursor = None
    try:
        print("📊 尝试连接数据库...")
        conn = get_db_connection()
        if not conn:
            print("❌ 数据库连接失败")
            return _reply(500001, "数据库连接失败")
        print("✅ 数据库连接成功")
        cursor = conn.cursor()
        # The login name may be either the username or the email address.
        print(f"🔍 查找用户: {login_data.username}")
        cursor.execute(
            "SELECT id, username, email, password_hash, is_active, is_superuser FROM users WHERE username = %s OR email = %s",
            (login_data.username, login_data.username)
        )
        user_data = cursor.fetchone()
        print(f"👤 用户查询结果: {user_data is not None}")
        if not user_data:
            print("❌ 用户不存在")
            return _reply(200001, "用户名或密码错误")
        user_id, username, email, password_hash, is_active, is_superuser = user_data
        print(f"✅ 找到用户: {username}, 激活状态: {is_active}")

        # Check the password first; a NULL hash can never match.
        password_valid = bool(password_hash) and verify_password_simple(login_data.password, password_hash)
        print(f"🔑 密码验证结果: {password_valid}")
        if not password_valid:
            print("❌ 密码验证失败")
            return _reply(200001, "用户名或密码错误")

        # Reveal the account status only to someone who knows the password.
        if not is_active:
            print("❌ 用户已被禁用")
            return _reply(200002, "用户已被禁用")

        print("🎫 生成访问令牌...")
        token_data = {
            "sub": user_id,
            "username": username,
            "email": email,
            "is_superuser": is_superuser
        }
        access_token = create_access_token(token_data)
        print("✅ 令牌生成成功")
        token_response = TokenResponse(
            access_token=access_token,
            token_type="Bearer",
            expires_in=ACCESS_TOKEN_EXPIRE_MINUTES * 60,
            scope="profile email"
        )
        print("🎉 登录成功")
        return _reply(0, "登录成功", token_response.model_dump())
    except Exception as e:
        print(f"❌ 登录错误详情: {type(e).__name__}: {str(e)}")
        import traceback
        print(f"❌ 错误堆栈: {traceback.format_exc()}")
        return _reply(500001, "服务器内部错误")
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
@app.get("/api/v1/users/profile")
async def get_user_profile(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return the caller's profile (users + user_profiles) and active role names.

    The user id is taken from the ``sub`` claim of the Bearer token. Returns
    an ApiResponse dict with one of these codes:
      * 0: success.
      * 200002: invalid token.
      * 200001: unknown user.
      * 500001: database or server error.

    The cursor and connection are now closed in ``finally``. Previously, any
    exception raised by a query leaked the connection.
    """
    def _reply(code: int, message: str, data=None) -> dict:
        return ApiResponse(
            code=code,
            message=message,
            data=data,
            timestamp=datetime.now(timezone.utc).isoformat(),
        ).model_dump()

    def _iso(value):
        # DATETIME columns may be NULL.
        return value.isoformat() if value else None

    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        user_id = payload.get("sub") if payload else None
        if not user_id:
            return _reply(200002, "无效的访问令牌")

        conn = get_db_connection()
        if not conn:
            return _reply(500001, "数据库连接失败")
        cursor = conn.cursor()

        # User row joined with its (optional) profile row.
        cursor.execute("""
            SELECT u.id, u.username, u.email, u.phone, u.avatar_url, u.is_active, u.is_superuser,
                   u.last_login_at, u.created_at, u.updated_at,
                   p.real_name, p.company, p.department, p.position
            FROM users u
            LEFT JOIN user_profiles p ON u.id = p.user_id
            WHERE u.id = %s
        """, (user_id,))
        user_data = cursor.fetchone()

        # Names of the user's active roles.
        cursor.execute("""
            SELECT r.name
            FROM user_roles ur
            JOIN roles r ON ur.role_id = r.id
            WHERE ur.user_id = %s AND ur.is_active = 1
        """, (user_id,))
        roles = [row[0] for row in cursor.fetchall()]

        if not user_data:
            return _reply(200001, "用户不存在")

        user_info = {
            "id": user_data[0],
            "username": user_data[1],
            "email": user_data[2],
            "phone": user_data[3],
            "avatar_url": user_data[4],
            "is_active": user_data[5],
            "is_superuser": user_data[6],
            "last_login_at": _iso(user_data[7]),
            "created_at": _iso(user_data[8]),
            "updated_at": _iso(user_data[9]),
            "real_name": user_data[10],
            "company": user_data[11],
            "department": user_data[12],
            "position": user_data[13],
            "roles": roles
        }
        return _reply(0, "获取用户资料成功", user_info)
    except Exception as e:
        print(f"获取用户资料错误: {e}")
        return _reply(500001, "服务器内部错误")
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
@app.put("/api/v1/users/profile")
async def update_user_profile(
    request: Request,
    profile_data: dict,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Update the caller's email/phone (``users``) and profile fields (``user_profiles``).

    Only whitelisted keys are written:
      * ``email`` and ``phone`` on ``users``;
      * ``real_name``, ``company``, ``department`` and ``position`` on
        ``user_profiles``. The profile row is updated if it exists, otherwise
        inserted.
    Column names interpolated into the SQL come only from these fixed lists;
    all values are bound parameters.

    Returns an ApiResponse dict with one of these codes:
      * 0: success.
      * 200002: invalid token.
      * 500001: database or server error.

    Changes in this version:
      * A token without ``sub`` is rejected. Previously the statements ran
        with a NULL user id.
      * The transaction is rolled back on error.
      * The cursor and connection are always closed.
    """
    def _reply(code: int, message: str) -> dict:
        return ApiResponse(
            code=code,
            message=message,
            timestamp=datetime.now(timezone.utc).isoformat(),
        ).model_dump()

    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return _reply(200002, "无效的访问令牌")
        user_id = payload.get("sub")
        if not user_id:
            return _reply(200002, "无效的访问令牌")

        conn = get_db_connection()
        if not conn:
            return _reply(500001, "数据库连接失败")
        cursor = conn.cursor()

        # users table: fixed column whitelist, processed in a fixed order.
        user_columns = [column for column in ('email', 'phone') if column in profile_data]
        if user_columns:
            assignments = ', '.join(f'{column} = %s' for column in user_columns)
            values = [profile_data[column] for column in user_columns] + [user_id]
            cursor.execute(f"""
                UPDATE users
                SET {assignments}, updated_at = NOW()
                WHERE id = %s
            """, values)

        # user_profiles table: update the existing row or insert a new one.
        profile_fields = ['real_name', 'company', 'department', 'position']
        profile_updates = {k: v for k, v in profile_data.items() if k in profile_fields}
        if profile_updates:
            cursor.execute("SELECT id FROM user_profiles WHERE user_id = %s", (user_id,))
            if cursor.fetchone():
                assignments = ', '.join(f'{field} = %s' for field in profile_updates)
                values = list(profile_updates.values()) + [user_id]
                cursor.execute(f"""
                    UPDATE user_profiles
                    SET {assignments}, updated_at = NOW()
                    WHERE user_id = %s
                """, values)
            else:
                fields = ['user_id'] + list(profile_updates.keys())
                values = [user_id] + list(profile_updates.values())
                placeholders = ', '.join(['%s'] * len(values))
                cursor.execute(f"""
                    INSERT INTO user_profiles ({', '.join(fields)}, created_at, updated_at)
                    VALUES ({placeholders}, NOW(), NOW())
                """, values)

        conn.commit()
        return _reply(0, "用户资料更新成功")
    except Exception as e:
        if conn is not None:
            try:
                conn.rollback()
            except Exception:
                # Best effort: the original error is what gets reported.
                pass
        print(f"更新用户资料错误: {e}")
        return _reply(500001, "服务器内部错误")
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
@app.put("/api/v1/users/password")
async def change_user_password(
    request: Request,
    password_data: dict,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Change the caller's password after verifying the current one.

    Body: ``{"old_password": str, "new_password": str}``. Returns an
    ApiResponse dict with one of these codes:
      * 0: success.
      * 100001: missing or non-string parameters.
      * 200002: invalid token.
      * 200001: wrong current password.
      * 500001: database or server error.

    Changes in this version:
      * Tokens without ``sub`` are rejected.
      * Non-string passwords are treated as missing parameters instead of
        crashing.
      * A NULL stored hash counts as a wrong password.
      * The transaction is rolled back, and the connection closed, on error.
    """
    def _reply(code: int, message: str) -> dict:
        return ApiResponse(
            code=code,
            message=message,
            timestamp=datetime.now(timezone.utc).isoformat(),
        ).model_dump()

    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return _reply(200002, "无效的访问令牌")
        user_id = payload.get("sub")
        if not user_id:
            return _reply(200002, "无效的访问令牌")

        old_password = password_data.get('old_password')
        new_password = password_data.get('new_password')
        if not (isinstance(old_password, str) and old_password
                and isinstance(new_password, str) and new_password):
            return _reply(100001, "缺少必要参数")

        conn = get_db_connection()
        if not conn:
            return _reply(500001, "数据库连接失败")
        cursor = conn.cursor()

        # Verify the current password.
        cursor.execute("SELECT password_hash FROM users WHERE id = %s", (user_id,))
        result = cursor.fetchone()
        if not result or not result[0] or not verify_password_simple(old_password, result[0]):
            return _reply(200001, "当前密码错误")

        # Store the hash of the new password.
        new_password_hash = hash_password_simple(new_password)
        cursor.execute("""
            UPDATE users
            SET password_hash = %s, updated_at = NOW()
            WHERE id = %s
        """, (new_password_hash, user_id))
        conn.commit()
        return _reply(0, "密码修改成功")
    except Exception as e:
        if conn is not None:
            try:
                conn.rollback()
            except Exception:
                # Best effort: the original error is what gets reported.
                pass
        print(f"修改密码错误: {e}")
        return _reply(500001, "服务器内部错误")
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
def hash_password_simple(password):
    """Hash ``password`` into the ``sha256$<salt>$<hexdigest>`` format.

    Each call generates a fresh random salt of 16 bytes (32 hex characters,
    from ``secrets``). The digest is ``sha256(password + salt)``, the format
    ``verify_password_simple`` checks.

    NOTE(review): a single salted SHA-256 round is cheap to brute-force. A slow
    KDF (hashlib.pbkdf2_hmac / hashlib.scrypt) would be stronger, but the
    verifier must be updated in the same change.
    """
    import hashlib
    import secrets

    salt = secrets.token_hex(16)
    digest = hashlib.sha256((password + salt).encode()).hexdigest()
    return "$".join(("sha256", salt, digest))
@app.post("/api/v1/auth/logout")
async def logout():
    """Log out.

    Stateless: nothing is revoked on the server. An issued access token stays
    valid until it expires, because ``verify_token`` only checks the signature
    and expiry.
    """
    now_iso = datetime.now(timezone.utc).isoformat()
    return ApiResponse(code=0, message="登出成功", timestamp=now_iso).model_dump()
# OAuth2 authorization endpoint
@app.get("/oauth/authorize")
async def oauth_authorize(
    response_type: str,
    client_id: str,
    redirect_uri: str,
    scope: str = "profile",
    state: str = None
):
    """OAuth2 authorization endpoint (authorization-code flow).

    Validates, in order:
      1. ``response_type`` is ``code``.
      2. The client is an active ``apps`` row matched by ``app_key``.
      3. ``redirect_uri`` is one of the app's registered ``redirect_uris``.
      4. Every requested scope is allowed for the app.
    On success, the browser is 302-redirected to ``/oauth/login`` with the same
    parameters. Validation failures return a JSON body
    ``{"error": ..., "redirect_url": ...}``; no redirect is issued for them.

    Fixes in this version:
      * Query parameters are now URL-encoded. Previously a redirect_uri
        containing ``?``/``&`` (or a state with spaces) was spliced raw into
        the login URL, so its query was split into wrong parameters.
      * Error URLs previously always appended ``?``, even when redirect_uri
        already had a query string.
      * Removed the consent-page HTML that followed the unconditional
        ``return``. It could never run, and its string literal was malformed.
        A consent step is still not implemented (TODO).
      * The database connection is closed even if the query fails.
    """
    from urllib.parse import urlencode
    from fastapi.responses import RedirectResponse

    def _with_query(base: str, params: dict) -> str:
        # Append URL-encoded params; None values (e.g. no state) are skipped.
        query = urlencode({k: v for k, v in params.items() if v is not None})
        if not query:
            return base
        return f"{base}{'&' if '?' in base else '?'}{query}"

    def _error(error: str, description: str) -> dict:
        error_url = _with_query(redirect_uri or "", {
            "error": error,
            "error_description": description,
            "state": state or None,
        })
        return {"error": error, "redirect_url": error_url}

    try:
        print(f"🔐 OAuth授权请求: client_id={client_id}, redirect_uri={redirect_uri}, scope={scope}")
        # Check required parameters.
        if not response_type or not client_id or not redirect_uri:
            return _error("invalid_request", "Missing required parameters")
        # Only the authorization-code flow is supported.
        if response_type != "code":
            return _error("unsupported_response_type", "Only authorization code flow is supported")

        conn = get_db_connection()
        if not conn:
            return _error("server_error", "Database connection failed")
        try:
            with conn.cursor() as cursor:
                cursor.execute("""
                    SELECT id, name, redirect_uris, scope, is_active, is_trusted
                    FROM apps
                    WHERE app_key = %s AND is_active = 1
                """, (client_id,))
                app_data = cursor.fetchone()
        finally:
            conn.close()

        if not app_data:
            return _error("invalid_client", "Invalid client_id")
        _app_id, _app_name, redirect_uris_json, app_scope_json, _is_active, _is_trusted = app_data

        # redirect_uri must match a registered URI exactly.
        redirect_uris = json.loads(redirect_uris_json) if redirect_uris_json else []
        if redirect_uri not in redirect_uris:
            return _error("invalid_request", "Invalid redirect_uri")

        # Every requested scope must be registered for the app.
        app_scopes = json.loads(app_scope_json) if app_scope_json else []
        requested_scopes = scope.split() if scope else []
        invalid_scopes = [s for s in requested_scopes if s not in app_scopes]
        if invalid_scopes:
            return _error("invalid_scope", f"Invalid scope: {' '.join(invalid_scopes)}")

        # TODO: detect an existing login session (session/cookie). Until then,
        # always send the user to the login page, which continues the flow.
        login_page_url = _with_query("/oauth/login", {
            "response_type": response_type,
            "client_id": client_id,
            "redirect_uri": redirect_uri,
            "scope": scope,
            "state": state or None,
        })
        print(f"🔐 需要用户登录,重定向到登录页面: {login_page_url}")
        return RedirectResponse(url=login_page_url, status_code=302)
    except Exception as e:
        print(f"❌ OAuth授权错误: {e}")
        return _error("server_error", "Internal server error")
@app.get("/oauth/login")
async def oauth_login_page(
    response_type: str,
    client_id: str,
    redirect_uri: str,
    scope: str = "profile",
    state: str = None
):
    """Render the SSO login page shown during the OAuth flow.

    Looks up the app's display name by ``client_id`` (``apps.app_key``). Unknown
    clients are shown as "未知应用". Returns a JSON error body when the database
    is unavailable or an unexpected error occurs.

    Fixes in this version:
      * ``app_name`` comes from the database and is now HTML-escaped before
        being embedded in the page. Previously this was a stored-XSS vector.
      * The page no longer prints a test account and its password.
      * The connection is closed even if the query fails.
    """
    from html import escape
    from fastapi.responses import HTMLResponse

    try:
        print(f"🔐 显示OAuth登录页面: client_id={client_id}")
        conn = get_db_connection()
        if not conn:
            return {"error": "server_error", "message": "数据库连接失败"}
        try:
            with conn.cursor() as cursor:
                cursor.execute("SELECT name FROM apps WHERE app_key = %s", (client_id,))
                app_data = cursor.fetchone()
        finally:
            conn.close()
        app_name = escape(str(app_data[0])) if app_data else "未知应用"
        # Build the login page HTML.
        login_html = f"""
SSO登录 - {app_name}
{app_name} 请求访问您的账户
"""
        return HTMLResponse(content=login_html)
    except Exception as e:
        print(f"❌ OAuth登录页面错误: {e}")
        return {"error": "server_error", "message": "服务器内部错误"}
# Lifetime of an authorization code, in minutes (RFC 6749 §4.1.2 recommends at most 10).
AUTH_CODE_EXPIRE_MINUTES = 10


def _build_redirect_url(base: str, **params) -> str:
    """Return ``base`` with ``params`` appended as a URL-encoded query string.

    ``None`` values are skipped. ``&`` is used when ``base`` already has a query.
    """
    from urllib.parse import urlencode

    query = urlencode({key: value for key, value in params.items() if value is not None})
    if not query:
        return base
    separator = "&" if "?" in base else "?"
    return f"{base}{separator}{query}"


def _oauth_redirect(redirect_uri: str, **params):
    """302-redirect the browser to ``redirect_uri`` with the given query params."""
    from fastapi.responses import RedirectResponse

    return RedirectResponse(url=_build_redirect_url(redirect_uri, **params), status_code=302)


def _fetch_client(client_id: str):
    """Load the active app registered under ``client_id`` (``apps.app_key``).

    Returns:
        ``(True, client)``, where ``client`` is a dict with ``name``,
        ``is_trusted``, ``app_secret``, ``redirect_uris`` and ``scopes``, or
        None when no active app matches. Returns ``(False, None)`` when no
        database connection is available.
    """
    conn = get_db_connection()
    if not conn:
        return False, None
    try:
        with conn.cursor() as cursor:
            cursor.execute("""
                SELECT name, is_trusted, app_secret, redirect_uris, scope
                FROM apps
                WHERE app_key = %s AND is_active = 1
            """, (client_id,))
            row = cursor.fetchone()
    finally:
        conn.close()
    if row is None:
        return True, None
    name, is_trusted, app_secret, redirect_uris_json, scope_json = row
    return True, {
        "name": name,
        "is_trusted": bool(is_trusted),
        "app_secret": app_secret,
        "redirect_uris": json.loads(redirect_uris_json) if redirect_uris_json else [],
        "scopes": json.loads(scope_json) if scope_json else [],
    }


def _auth_code_key() -> str:
    """Signing key for authorization codes.

    Derived from JWT_SECRET_KEY but different from it. A code therefore never
    passes ``verify_token`` as an access token, and an access token can never
    be redeemed as a code.
    """
    return f"{JWT_SECRET_KEY}:oauth-authorization-code"


def _issue_auth_code(user_id, client_id: str, redirect_uri: str, scope: str) -> str:
    """Create a signed, short-lived code bound to user, client, redirect_uri and scope."""
    now = datetime.now(timezone.utc)
    claims = {
        "sub": str(user_id),
        "client_id": client_id,
        "redirect_uri": redirect_uri,
        "scope": scope or "",
        "iat": now,
        "exp": now + timedelta(minutes=AUTH_CODE_EXPIRE_MINUTES),
        "jti": secrets.token_urlsafe(16),
    }
    code = jwt.encode(claims, _auth_code_key(), algorithm="HS256")
    # PyJWT < 2 returns bytes.
    return code.decode("utf-8") if isinstance(code, bytes) else code


def _decode_auth_code(code: str) -> Optional[dict]:
    """Return the claims of a valid, unexpired authorization code, else None."""
    # PyJWT raises PyJWTError/InvalidTokenError; python-jose raises JWTError.
    decode_errors = tuple(
        err for err in (
            getattr(jwt, "PyJWTError", None),
            getattr(jwt, "InvalidTokenError", None),
            getattr(jwt, "JWTError", None),
        )
        if isinstance(err, type) and issubclass(err, Exception)
    ) or (Exception,)
    try:
        return jwt.decode(code, _auth_code_key(), algorithms=["HS256"])
    except decode_errors:
        return None


@app.get("/oauth/authorize/authenticated")
async def oauth_authorize_authenticated(
    response_type: str,
    client_id: str,
    redirect_uri: str,
    access_token: str,
    scope: str = "profile",
    state: str = None
):
    """Issue an authorization code for a user who has already logged in.

    The client must be an active app, and ``redirect_uri`` must be one of its
    registered URIs. If either check fails, a JSON error is returned and
    nothing is sent to ``redirect_uri``. Once redirect_uri is verified, the
    following problems are reported by redirecting there with ``error``:
    ``response_type`` other than ``code``, scopes the app does not allow, and
    an invalid ``access_token``. On success, the browser is redirected with a
    signed ``code`` (valid AUTH_CODE_EXPIRE_MINUTES) and ``state``.

    Previously the redirect_uri was never checked against the registration,
    so codes could be delivered to any URL. Inactive apps were also accepted,
    and the code was a random string that nothing recorded or verified.

    NOTE(review): the access token travels in the query string and may end up
    in logs or browser history; consider a session cookie. Trusted and
    untrusted apps are both approved without a consent page (as before).
    """
    state = state or None
    redirect_uri_verified = False
    try:
        print(f"🔐 用户已登录,处理授权: client_id={client_id}")
        db_ok, client = _fetch_client(client_id)
        if not db_ok:
            return {"error": "server_error", "error_description": "Database connection failed"}
        if client is None:
            return {"error": "invalid_client", "error_description": "Invalid client"}
        if redirect_uri not in client["redirect_uris"]:
            return {"error": "invalid_request", "error_description": "Invalid redirect_uri"}
        redirect_uri_verified = True

        if response_type != "code":
            return _oauth_redirect(
                redirect_uri, error="unsupported_response_type",
                error_description="Only authorization code flow is supported", state=state,
            )
        requested_scopes = scope.split() if scope else []
        invalid_scopes = [s for s in requested_scopes if s not in client["scopes"]]
        if invalid_scopes:
            return _oauth_redirect(
                redirect_uri, error="invalid_scope",
                error_description=f"Invalid scope: {' '.join(invalid_scopes)}", state=state,
            )

        payload = verify_token(access_token)
        user_id = payload.get("sub") if payload else None
        if not user_id:
            return _oauth_redirect(
                redirect_uri, error="invalid_token",
                error_description="Invalid access token", state=state,
            )
        print(f"✅ 用户已验证: {payload.get('username', '')} ({user_id})")

        auth_code = _issue_auth_code(user_id, client_id, redirect_uri, scope)
        # The code itself is a credential, so it is not logged.
        print(f"✅ 授权码已签发: client_id={client_id}, trusted={client['is_trusted']}")
        return _oauth_redirect(redirect_uri, code=auth_code, state=state)
    except Exception as e:
        print(f"❌ 授权处理错误: {e}")
        if redirect_uri_verified:
            return _oauth_redirect(
                redirect_uri, error="server_error",
                error_description="Authorization failed", state=state,
            )
        return {"error": "server_error", "error_description": "Authorization failed"}


async def oauth_approve(
    client_id: str,
    redirect_uri: str,
    scope: str = "profile",
    state: str = None
):
    """User approves authorization.

    NOTE(review): this function has no route decorator in this file. It does
    not know which user is approving, and it does not check redirect_uri. It
    still returns a random code, which ``/oauth/token`` will not accept (only
    codes issued by ``/oauth/authorize/authenticated`` are redeemable). Wire it
    up with a user and a redirect_uri check before routing it.
    """
    try:
        print(f"✅ 用户同意授权: client_id={client_id}")
        auth_code = secrets.token_urlsafe(32)
        print("🔄 重定向到客户端回调")
        return _oauth_redirect(redirect_uri, code=auth_code, state=state or None)
    except Exception as e:
        print(f"❌ 授权确认错误: {e}")
        return _oauth_redirect(
            redirect_uri, error="server_error",
            error_description="Authorization failed", state=state or None,
        )


@app.get("/oauth/authorize/deny")
async def oauth_deny(
    client_id: str,
    redirect_uri: str,
    state: str = None
):
    """User denies authorization: redirect back with ``error=access_denied``.

    The redirect happens only when ``redirect_uri`` is registered for an active
    client. Otherwise a JSON error is returned, so this endpoint can no longer
    be abused as an open redirect to arbitrary URLs.
    """
    try:
        print(f"❌ 用户拒绝授权: client_id={client_id}")
        db_ok, client = _fetch_client(client_id)
        if not db_ok:
            return {"error": "server_error", "error_description": "Database connection failed"}
        if client is None or redirect_uri not in client["redirect_uris"]:
            return {"error": "invalid_request", "error_description": "Invalid client_id or redirect_uri"}
        return _oauth_redirect(
            redirect_uri, error="access_denied",
            error_description="User denied authorization", state=state or None,
        )
    except Exception as e:
        print(f"❌ 拒绝授权错误: {e}")
        return {"error": "server_error", "error_description": "Authorization failed"}


@app.post("/oauth/token")
async def oauth_token(request: Request):
    """OAuth2 token endpoint (``authorization_code`` grant only).

    Expects form fields ``grant_type``, ``code``, ``redirect_uri``,
    ``client_id`` and ``client_secret``. Returns a JSON token response, or an
    ``{"error", "error_description"}`` body.

    Fixes in this version:
      * Previously the code was ignored, and every request received a token
        for a hard-coded admin user id. Now the code must be a valid,
        unexpired code issued by ``/oauth/authorize/authenticated`` for the
        same client_id and redirect_uri. The token is issued for the user and
        scope recorded in the code.
      * ``client_secret`` is now required. It used to be checked only when
        supplied, so anyone knowing a client_id could obtain tokens (see
        RFC 6749 §4.1.3). It is compared in constant time.

    TODO: codes are stateless and can be replayed until they expire. RFC 6749
    requires single use, so persist used ``jti`` values to enforce it. The
    refresh_token is random and not stored (as before).
    """
    import hmac

    try:
        form_data = await request.form()
        grant_type = form_data.get("grant_type")
        code = form_data.get("code")
        redirect_uri = form_data.get("redirect_uri")
        client_id = form_data.get("client_id")
        client_secret = form_data.get("client_secret")
        print(f"🎫 令牌请求: grant_type={grant_type}, client_id={client_id}")

        if grant_type != "authorization_code":
            return {
                "error": "unsupported_grant_type",
                "error_description": "Only authorization_code grant type is supported"
            }
        if not code or not redirect_uri or not client_id:
            return {
                "error": "invalid_request",
                "error_description": "Missing required parameters"
            }

        db_ok, client = _fetch_client(client_id)
        if not db_ok:
            return {
                "error": "server_error",
                "error_description": "Database connection failed"
            }
        stored_secret = client["app_secret"] if client else None
        if (
            not client_secret
            or not stored_secret
            or not hmac.compare_digest(str(client_secret).encode(), str(stored_secret).encode())
        ):
            return {
                "error": "invalid_client",
                "error_description": "Invalid client credentials"
            }
        if redirect_uri not in client["redirect_uris"]:
            return {
                "error": "invalid_grant",
                "error_description": "Invalid redirect_uri"
            }

        # The code must be genuine, unexpired, and bound to this client and redirect_uri.
        claims = _decode_auth_code(str(code))
        if (
            not claims
            or not claims.get("sub")
            or claims.get("client_id") != client_id
            or claims.get("redirect_uri") != redirect_uri
        ):
            return {
                "error": "invalid_grant",
                "error_description": "Invalid or expired authorization code"
            }

        granted_scope = claims.get("scope", "")
        access_token = create_access_token({
            "sub": claims["sub"],
            "client_id": client_id,
            "scope": granted_scope
        })
        refresh_token = secrets.token_urlsafe(32)
        print(f"✅ 令牌生成成功: client_id={client_id}, sub={claims['sub']}")
        return {
            "access_token": access_token,
            "token_type": "Bearer",
            "expires_in": ACCESS_TOKEN_EXPIRE_MINUTES * 60,
            "refresh_token": refresh_token,
            "scope": granted_scope
        }
    except Exception as e:
        print(f"❌ 令牌生成错误: {e}")
        return {
            "error": "server_error",
            "error_description": "Internal server error"
        }
@app.get("/oauth/userinfo")
async def oauth_userinfo(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """OAuth2 UserInfo endpoint: claims about the token's ``sub``, filtered by scope.

    ``sub`` is always returned. The ``profile`` scope adds username, avatar,
    real name, company, department and position; ``email`` adds the email;
    ``phone`` adds the phone number. Inactive or unknown users and bad tokens
    yield an ``{"error", "error_description"}`` body.

    Changes in this version:
      * The connection is closed even when the query raises.
      * A ``scope`` claim of null no longer crashes the handler.
      * The full user record (personal data) is no longer printed to the log;
        only the ``sub`` is.
    """
    conn = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return {
                "error": "invalid_token",
                "error_description": "Invalid or expired access token"
            }
        user_id = payload.get("sub")
        client_id = payload.get("client_id")
        scope = (payload.get("scope") or "").split()
        print(f"👤 用户信息请求: user_id={user_id}, client_id={client_id}, scope={scope}")

        conn = get_db_connection()
        if not conn:
            return {
                "error": "server_error",
                "error_description": "Database connection failed"
            }
        with conn.cursor() as cursor:
            cursor.execute("""
                SELECT u.id, u.username, u.email, u.phone, u.avatar_url, u.is_active,
                       p.real_name, p.company, p.department, p.position
                FROM users u
                LEFT JOIN user_profiles p ON u.id = p.user_id
                WHERE u.id = %s AND u.is_active = 1
            """, (user_id,))
            user_data = cursor.fetchone()

        if not user_data:
            return {
                "error": "invalid_token",
                "error_description": "User not found or inactive"
            }

        # Build the response, releasing only what the granted scopes cover.
        user_info = {"sub": user_data[0]}
        if "profile" in scope:
            user_info.update({
                "username": user_data[1],
                "avatar_url": user_data[4],
                "real_name": user_data[6],
                "company": user_data[7],
                "department": user_data[8],
                "position": user_data[9]
            })
        if "email" in scope:
            user_info["email"] = user_data[2]
        if "phone" in scope:
            user_info["phone"] = user_data[3]
        print(f"✅ 返回用户信息: sub={user_info['sub']}")
        return user_info
    except Exception as e:
        print(f"❌ 获取用户信息错误: {e}")
        return {
            "error": "server_error",
            "error_description": "Internal server error"
        }
    finally:
        if conn is not None:
            conn.close()
@app.get("/api/v1/apps")
async def get_apps(
    page: int = 1,
    page_size: int = 20,
    keyword: str = "",
    status: str = "",
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """List the apps visible to the caller, paginated and newest first.

    Users holding an active ``super_admin``, ``admin`` or ``app_manager`` role
    see every app; everyone else sees only the apps they created. ``keyword``
    is matched against name/description with LIKE. ``status`` may be
    "active" or "inactive".

    Changes in this version:
      * ``page``/``page_size`` below 1 are clamped to 1. Previously they
        produced a negative OFFSET/LIMIT, which MySQL rejects, so the request
        failed with 500001.
      * Tokens without ``sub`` are rejected.
      * The connection is closed even when a query raises.
      * The loop variable no longer shadows the module-level FastAPI ``app``.

    NOTE: ``today_requests`` and ``active_users`` are random placeholder
    values, not real statistics.
    """
    def _reply(code: int, message: str, data=None) -> dict:
        return ApiResponse(
            code=code,
            message=message,
            data=data,
            timestamp=datetime.now(timezone.utc).isoformat(),
        ).model_dump()

    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        user_id = payload.get("sub") if payload else None
        if not user_id:
            return _reply(200002, "无效的访问令牌")
        page = max(page, 1)
        page_size = max(page_size, 1)

        conn = get_db_connection()
        if not conn:
            return _reply(500001, "数据库连接失败")
        cursor = conn.cursor()

        # Does the caller hold a role that may see every app?
        cursor.execute("""
            SELECT COUNT(*) FROM user_roles ur
            JOIN roles r ON ur.role_id = r.id
            WHERE ur.user_id = %s AND r.name IN ('super_admin', 'admin', 'app_manager') AND ur.is_active = 1
        """, (user_id,))
        is_app_manager = cursor.fetchone()[0] > 0

        # Build the WHERE clause; all values are bound parameters.
        where_conditions = []
        params = []
        if not is_app_manager:
            where_conditions.append("created_by = %s")
            params.append(user_id)
        if keyword:
            where_conditions.append("(name LIKE %s OR description LIKE %s)")
            params.extend([f"%{keyword}%", f"%{keyword}%"])
        if status == "active":
            where_conditions.append("is_active = 1")
        elif status == "inactive":
            where_conditions.append("is_active = 0")
        where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"

        cursor.execute(f"SELECT COUNT(*) FROM apps WHERE {where_clause}", params)
        total = cursor.fetchone()[0]

        offset = (page - 1) * page_size
        cursor.execute(f"""
            SELECT id, name, app_key, description, icon_url, redirect_uris, scope,
                   is_active, is_trusted, access_token_expires, refresh_token_expires,
                   created_at, updated_at
            FROM apps
            WHERE {where_clause}
            ORDER BY created_at DESC
            LIMIT %s OFFSET %s
        """, params + [page_size, offset])

        items = []
        for row in cursor.fetchall():
            items.append({
                "id": row[0],
                "name": row[1],
                "app_key": row[2],
                "description": row[3],
                "icon_url": row[4],
                "redirect_uris": json.loads(row[5]) if row[5] else [],
                "scope": json.loads(row[6]) if row[6] else [],
                "is_active": bool(row[7]),
                "is_trusted": bool(row[8]),
                "access_token_expires": row[9],
                "refresh_token_expires": row[10],
                "created_at": row[11].isoformat() if row[11] else None,
                "updated_at": row[12].isoformat() if row[12] else None,
                # Placeholder statistics (random).
                "today_requests": secrets.randbelow(1000),
                "active_users": secrets.randbelow(100)
            })

        return _reply(0, "获取应用列表成功", {
            "items": items,
            "total": total,
            "page": page,
            "page_size": page_size
        })
    except Exception as e:
        print(f"获取应用列表错误: {e}")
        return _reply(500001, "服务器内部错误")
    finally:
        if cursor is not None:
            cursor.close()
        if conn is not None:
            conn.close()
@app.get("/api/v1/apps/{app_id}")
async def get_app_detail(
    app_id: str,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Return the full record of one application, including its secret.

    Only the application's creator can read it: the lookup is filtered by
    ``created_by = <token subject>``.

    Args:
        app_id: Primary key of the application.
        credentials: Bearer token extracted by the ``security`` dependency.

    Returns:
        ``ApiResponse`` dict. ``code=0`` with the app record in ``data`` on
        success; ``200002`` for an invalid token; ``200001`` when the app does
        not exist or belongs to another user; ``500001`` on database/server
        errors.
    """
    from contextlib import closing

    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=200002,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        user_id = payload.get("sub")

        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500001,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        # closing() releases the cursor and connection even if the query
        # raises; previously an exception here leaked the connection.
        with closing(conn), closing(conn.cursor()) as cursor:
            cursor.execute("""
                SELECT id, name, app_key, app_secret, description, icon_url,
                       redirect_uris, scope, is_active, is_trusted,
                       access_token_expires, refresh_token_expires,
                       created_at, updated_at
                FROM apps
                WHERE id = %s AND created_by = %s
            """, (app_id, user_id))
            app_data = cursor.fetchone()

        if not app_data:
            return ApiResponse(
                code=200001,
                message="应用不存在或无权限",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        app_detail = {
            "id": app_data[0],
            "name": app_data[1],
            "app_key": app_data[2],
            "app_secret": app_data[3],
            "description": app_data[4],
            "icon_url": app_data[5],
            # redirect_uris / scope are stored as JSON text columns
            "redirect_uris": json.loads(app_data[6]) if app_data[6] else [],
            "scope": json.loads(app_data[7]) if app_data[7] else [],
            "is_active": bool(app_data[8]),
            "is_trusted": bool(app_data[9]),
            "access_token_expires": app_data[10],
            "refresh_token_expires": app_data[11],
            "created_at": app_data[12].isoformat() if app_data[12] else None,
            "updated_at": app_data[13].isoformat() if app_data[13] else None
        }
        return ApiResponse(
            code=0,
            message="获取应用详情成功",
            data=app_detail,
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"获取应用详情错误: {e}")
        return ApiResponse(
            code=500001,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
@app.post("/api/v1/apps")
async def create_app(
    request: Request,
    app_data: dict,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Create an application owned by the caller and return its credentials.

    Required body fields: ``name`` (non-empty string) and ``redirect_uris``
    (non-empty list). Optional fields: ``description``, ``icon_url``,
    ``scope`` (list, defaults to ``['profile']``), ``is_trusted`` (default
    False), ``access_token_expires`` (default 7200) and
    ``refresh_token_expires`` (default 2592000).

    Args:
        request: Incoming request; currently unused, kept for interface
            compatibility.
        app_data: JSON body describing the application.
        credentials: Bearer token extracted by the ``security`` dependency.

    Returns:
        ``ApiResponse`` dict. On success (``code=0``) ``data`` holds the new
        app including the generated ``app_key`` and ``app_secret``. Error
        codes: ``200002`` invalid token, ``100001`` missing/invalid
        parameters, ``500001`` database/server error.
    """
    from contextlib import closing

    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=200002,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        user_id = payload.get("sub")

        raw_name = app_data.get('name')
        name = raw_name.strip() if isinstance(raw_name, str) else ''
        redirect_uris = app_data.get('redirect_uris')
        # redirect_uris must be a non-empty list (the same rule update_app
        # enforces); a bare string would otherwise be stored as a JSON string
        # and later handed to clients that expect a list.
        if not name or not isinstance(redirect_uris, list) or not redirect_uris:
            return ApiResponse(
                code=100001,
                message="缺少必要参数",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        scope = app_data.get('scope', ['profile'])
        if not scope or not isinstance(scope, list):
            scope = ['profile']
        description = app_data.get('description', '')
        icon_url = app_data.get('icon_url', '')
        # NOTE(review): any authenticated user can self-declare is_trusted;
        # confirm whether this flag should be restricted to administrators.
        is_trusted = app_data.get('is_trusted', False)
        access_token_expires = app_data.get('access_token_expires', 7200)
        refresh_token_expires = app_data.get('refresh_token_expires', 2592000)

        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500001,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        app_id = str(uuid.uuid4())
        app_key = generate_random_string(32)
        app_secret = generate_random_string(64)

        # closing() releases the cursor and connection even if the INSERT
        # raises (nothing is committed in that case).
        with closing(conn), closing(conn.cursor()) as cursor:
            cursor.execute("""
                INSERT INTO apps (
                    id, name, app_key, app_secret, description, icon_url,
                    redirect_uris, scope, is_active, is_trusted,
                    access_token_expires, refresh_token_expires, created_by,
                    created_at, updated_at
                ) VALUES (
                    %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), NOW()
                )
            """, (
                app_id,
                name,
                app_key,
                app_secret,
                description,
                icon_url,
                json.dumps(redirect_uris),
                json.dumps(scope),
                True,
                is_trusted,
                access_token_expires,
                refresh_token_expires,
                user_id
            ))
            conn.commit()

        app_info = {
            "id": app_id,
            "name": name,
            "app_key": app_key,
            "app_secret": app_secret,
            "description": description,
            "redirect_uris": redirect_uris,
            "scope": scope,
            "is_active": True,
            "is_trusted": is_trusted
        }
        return ApiResponse(
            code=0,
            message="应用创建成功",
            data=app_info,
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"创建应用错误: {e}")
        return ApiResponse(
            code=500001,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
@app.put("/api/v1/apps/{app_id}/status")
async def toggle_app_status(
    app_id: str,
    status_data: dict,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Enable or disable an application owned by the caller.

    Body: ``{"is_active": true|false}`` (the integers 0/1 are also accepted).

    Args:
        app_id: Primary key of the application.
        status_data: JSON body carrying ``is_active``.
        credentials: Bearer token extracted by the ``security`` dependency.

    Returns:
        ``ApiResponse`` dict: ``code=0`` on success; ``200002`` invalid token;
        ``100001`` missing/invalid ``is_active``; ``200001`` app not found or
        not owned by the caller; ``500001`` database/server error.
    """
    from contextlib import closing

    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=200002,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        user_id = payload.get("sub")

        is_active = status_data.get('is_active')
        if is_active is None:
            return ApiResponse(
                code=100001,
                message="缺少必要参数",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        # Only accept real booleans or 0/1: a string such as "false" is truthy,
        # so it used to be reported as "启用" while the DB coerced it to 0.
        if not isinstance(is_active, (bool, int)) or is_active not in (0, 1):
            return ApiResponse(
                code=100001,
                message="参数错误: is_active 必须为布尔值",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        is_active = bool(is_active)

        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500001,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        with closing(conn), closing(conn.cursor()) as cursor:
            # Ownership check (also fetches the name for logging)
            cursor.execute("""
                SELECT id, name FROM apps
                WHERE id = %s AND created_by = %s
            """, (app_id, user_id))
            app_data = cursor.fetchone()
            if not app_data:
                return ApiResponse(
                    code=200001,
                    message="应用不存在或无权限",
                    timestamp=datetime.now(timezone.utc).isoformat()
                ).model_dump()

            # Keep the owner filter on the UPDATE too, so the write itself is
            # scoped to the caller's apps.
            cursor.execute("""
                UPDATE apps
                SET is_active = %s, updated_at = NOW()
                WHERE id = %s AND created_by = %s
            """, (is_active, app_id, user_id))
            conn.commit()

        action = "启用" if is_active else "禁用"
        print(f"✅ 应用状态已更新: {app_data[1]} -> {action}")
        return ApiResponse(
            code=0,
            message=f"应用已{action}",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"切换应用状态错误: {e}")
        return ApiResponse(
            code=500001,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
@app.put("/api/v1/apps/{app_id}")
async def update_app(
    app_id: str,
    app_data: dict,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Update an application owned by the caller and return the new record.

    Body fields: ``name`` (required, non-empty), ``redirect_uris`` (required,
    non-empty list), ``description``, ``icon_url``, ``scope`` (list; falls
    back to ``['profile', 'email']``), ``is_trusted``, ``access_token_expires``
    (default 7200), ``refresh_token_expires`` (default 2592000).

    Args:
        app_id: Primary key of the application.
        app_data: JSON body with the new values.
        credentials: Bearer token extracted by the ``security`` dependency.

    Returns:
        ``ApiResponse`` dict: ``code=0`` with the updated app (without its
        secret) on success; ``200002`` invalid token; ``100001`` invalid
        body; ``200001`` app not found / not owned, or name already used by
        another of the caller's apps; ``500001`` database/server error.
    """
    from contextlib import closing

    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=200002,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        user_id = payload.get("sub")

        # A JSON null / non-string name used to raise AttributeError (-> 500).
        raw_name = app_data.get('name')
        name = raw_name.strip() if isinstance(raw_name, str) else ''
        if not name:
            return ApiResponse(
                code=100001,
                message="应用名称不能为空",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        # Validate the body before touching the database.
        description = (app_data.get('description') or '').strip()
        icon_url = (app_data.get('icon_url') or '').strip()
        redirect_uris = app_data.get('redirect_uris', [])
        scope = app_data.get('scope', ['profile', 'email'])
        # NOTE(review): any owner can set is_trusted on their own app; confirm
        # whether this flag should be restricted to administrators.
        is_trusted = app_data.get('is_trusted', False)
        access_token_expires = app_data.get('access_token_expires', 7200)
        refresh_token_expires = app_data.get('refresh_token_expires', 2592000)

        if not redirect_uris or not isinstance(redirect_uris, list):
            return ApiResponse(
                code=100001,
                message="至少需要一个回调URL",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        if not scope or not isinstance(scope, list):
            scope = ['profile', 'email']

        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500001,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        with closing(conn), closing(conn.cursor()) as cursor:
            # The app must exist and belong to the caller
            cursor.execute("""
                SELECT id, name FROM apps
                WHERE id = %s AND created_by = %s
            """, (app_id, user_id))
            if not cursor.fetchone():
                return ApiResponse(
                    code=200001,
                    message="应用不存在或无权限",
                    timestamp=datetime.now(timezone.utc).isoformat()
                ).model_dump()

            # Names are unique per owner
            cursor.execute("""
                SELECT id FROM apps
                WHERE name = %s AND created_by = %s AND id != %s
            """, (name, user_id, app_id))
            if cursor.fetchone():
                return ApiResponse(
                    code=200001,
                    message="应用名称已存在",
                    timestamp=datetime.now(timezone.utc).isoformat()
                ).model_dump()

            cursor.execute("""
                UPDATE apps
                SET name = %s, description = %s, icon_url = %s,
                    redirect_uris = %s, scope = %s, is_trusted = %s,
                    access_token_expires = %s, refresh_token_expires = %s,
                    updated_at = NOW()
                WHERE id = %s AND created_by = %s
            """, (
                name, description, icon_url,
                json.dumps(redirect_uris), json.dumps(scope), is_trusted,
                access_token_expires, refresh_token_expires, app_id, user_id
            ))
            conn.commit()

            # Re-read the row so the response reflects what was stored
            cursor.execute("""
                SELECT id, name, app_key, description, icon_url,
                       redirect_uris, scope, is_active, is_trusted,
                       access_token_expires, refresh_token_expires,
                       created_at, updated_at
                FROM apps
                WHERE id = %s
            """, (app_id,))
            app_info = cursor.fetchone()

        if not app_info:
            return ApiResponse(
                code=500001,
                message="获取更新后的应用信息失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        app_result = {
            "id": app_info[0],
            "name": app_info[1],
            "app_key": app_info[2],
            "description": app_info[3],
            "icon_url": app_info[4],
            "redirect_uris": json.loads(app_info[5]) if app_info[5] else [],
            "scope": json.loads(app_info[6]) if app_info[6] else [],
            "is_active": bool(app_info[7]),
            "is_trusted": bool(app_info[8]),
            "access_token_expires": app_info[9],
            "refresh_token_expires": app_info[10],
            "created_at": app_info[11].isoformat() if app_info[11] else None,
            "updated_at": app_info[12].isoformat() if app_info[12] else None
        }
        print(f"✅ 应用已更新: {name}")
        return ApiResponse(
            code=0,
            message="应用更新成功",
            data=app_result,
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"更新应用错误: {e}")
        return ApiResponse(
            code=500001,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
@app.delete("/api/v1/apps/{app_id}")
async def delete_app(
    app_id: str,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Delete an application owned by the caller.

    Args:
        app_id: Primary key of the application.
        credentials: Bearer token extracted by the ``security`` dependency.

    Returns:
        ``ApiResponse`` dict: ``code=0`` on success; ``200002`` invalid token;
        ``200001`` app not found or not owned by the caller; ``500001``
        database/server error.
    """
    from contextlib import closing

    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=200002,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        user_id = payload.get("sub")

        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500001,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        # closing() guarantees cleanup even when a statement raises.
        with closing(conn), closing(conn.cursor()) as cursor:
            cursor.execute("""
                SELECT id, name FROM apps
                WHERE id = %s AND created_by = %s
            """, (app_id, user_id))
            app_data = cursor.fetchone()
            if not app_data:
                return ApiResponse(
                    code=200001,
                    message="应用不存在或无权限",
                    timestamp=datetime.now(timezone.utc).isoformat()
                ).model_dump()

            # NOTE(review): removal of dependent rows relies on the schema
            # (e.g. ON DELETE CASCADE); confirm against the DDL.
            cursor.execute("DELETE FROM apps WHERE id = %s", (app_id,))
            conn.commit()

        print(f"✅ 应用已删除: {app_data[1]}")
        return ApiResponse(
            code=0,
            message="应用已删除",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"删除应用错误: {e}")
        return ApiResponse(
            code=500001,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
@app.post("/api/v1/apps/{app_id}/reset-secret")
async def reset_app_secret(
    app_id: str,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Replace an application's secret with a fresh 64-character value.

    Args:
        app_id: Primary key of the application.
        credentials: Bearer token extracted by the ``security`` dependency.

    Returns:
        ``ApiResponse`` dict: ``code=0`` with ``{"app_secret": <new>}`` on
        success; ``200002`` invalid token; ``200001`` app not found or not
        owned by the caller; ``500001`` database/server error.
    """
    from contextlib import closing

    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=200002,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        user_id = payload.get("sub")

        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500001,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        with closing(conn), closing(conn.cursor()) as cursor:
            cursor.execute("""
                SELECT id, name FROM apps
                WHERE id = %s AND created_by = %s
            """, (app_id, user_id))
            app_data = cursor.fetchone()
            if not app_data:
                return ApiResponse(
                    code=200001,
                    message="应用不存在或无权限",
                    timestamp=datetime.now(timezone.utc).isoformat()
                ).model_dump()

            new_secret = generate_random_string(64)
            cursor.execute("""
                UPDATE apps
                SET app_secret = %s, updated_at = NOW()
                WHERE id = %s
            """, (new_secret, app_id))
            conn.commit()

        print(f"✅ 应用密钥已重置: {app_data[1]}")
        return ApiResponse(
            code=0,
            message="应用密钥已重置",
            data={"app_secret": new_secret},
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"重置应用密钥错误: {e}")
        return ApiResponse(
            code=500001,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
def generate_random_string(length=32):
    """Return a cryptographically random alphanumeric string.

    Characters are drawn uniformly from ASCII letters and digits using the
    :mod:`secrets` CSPRNG, so the result is suitable for keys and secrets.

    Args:
        length: Number of characters to generate (default 32). Zero or a
            negative value yields an empty string.

    Returns:
        A string of ``length`` characters.
    """
    import secrets
    import string
    alphabet = string.ascii_letters + string.digits
    return ''.join(secrets.choice(alphabet) for _ in range(length))
def generate_captcha():
    """Generate a four-character captcha and render it as an image data URI.

    The answer is four characters from upper-case ASCII letters and digits,
    drawn with :mod:`secrets`. With Pillow installed it is rendered as a
    120x40 PNG with a few random noise lines and points; when Pillow is not
    installed, or PNG rendering fails, an SVG image of the same text is
    returned instead.

    Returns:
        tuple[str, str]: ``(captcha_text, data_uri)`` where ``data_uri`` is a
        ``data:image/png;base64,...`` or ``data:image/svg+xml;base64,...``
        string.
    """
    # These must be imported before attempting Pillow: the fallback paths
    # need them, and importing them after ``from PIL import ...`` made the
    # ImportError branch fail with NameError.
    import base64
    import io
    import random
    import secrets
    import string

    width, height = 120, 40
    alphabet = string.ascii_uppercase + string.digits
    # The captcha answer is security-relevant, so use a CSPRNG for it.
    captcha_text = ''.join(secrets.choice(alphabet) for _ in range(4))

    def svg_data_uri(text):
        """Render ``text`` as a minimal SVG captcha and return it as a data URI."""
        svg = (
            f'<svg width="{width}" height="{height}" xmlns="http://www.w3.org/2000/svg">'
            f'<rect width="{width}" height="{height}" fill="#f0f0f0" stroke="#ccc"/>'
            f'<text x="{width // 2}" y="25" font-family="Arial" font-size="18" '
            f'text-anchor="middle" fill="#333">{text}</text>'
            '</svg>'
        )
        encoded = base64.b64encode(svg.encode('utf-8')).decode('utf-8')
        return f"data:image/svg+xml;base64,{encoded}"

    try:
        from PIL import Image, ImageDraw, ImageFont
    except ImportError:
        # Pillow not installed: serve an SVG captcha instead
        return captcha_text, svg_data_uri(captcha_text)

    try:
        image = Image.new('RGB', (width, height), color='white')
        draw = ImageDraw.Draw(image)

        # ImageFont.truetype raises OSError when the font file is not found.
        try:
            font = ImageFont.truetype("arial.ttf", 20)
        except OSError:
            try:
                font = ImageFont.truetype("C:/Windows/Fonts/arial.ttf", 20)
            except OSError:
                font = ImageFont.load_default()

        # Center the text (height approximated as 20 px)
        text_width = draw.textlength(captcha_text, font=font)
        text_height = 20
        x = (width - text_width) // 2
        y = (height - text_height) // 2

        colors = ['#FF6B6B', '#4ECDC4', '#45B7D1', '#96CEB4', '#FFEAA7', '#DDA0DD']
        draw.text((x, y), captcha_text, fill=random.choice(colors), font=font)

        # Noise only needs to look random, so the plain `random` module is fine here.
        for _ in range(3):
            start = (random.randint(0, width - 1), random.randint(0, height - 1))
            end = (random.randint(0, width - 1), random.randint(0, height - 1))
            draw.line([start, end], fill=random.choice(colors), width=1)
        for _ in range(20):
            point = (random.randint(0, width - 1), random.randint(0, height - 1))
            draw.point(point, fill=random.choice(colors))

        buffer = io.BytesIO()
        image.save(buffer, format='PNG')
        image_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
        return captcha_text, f"data:image/png;base64,{image_base64}"
    except Exception as e:
        print(f"生成验证码图片失败: {e}")
        # Fall back to an SVG of the same random answer; the old fallback
        # always answered "1234", which made the captcha trivially bypassable.
        return captcha_text, svg_data_uri(captcha_text)
# RBAC权限管理API
@app.get("/api/v1/user/menus")
async def api_get_user_menus(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return the sidebar menu tree visible to the current user.

    A user holding an active ``super_admin`` role gets every active menu.
    Any other user gets the active menus linked through ``role_menus`` to
    any of their active roles. Menus whose title contains "文档处理中心" are
    hidden, and only ``menu_type == "menu"`` entries are kept before the
    tree is built with ``build_menu_tree``.

    Args:
        credentials: Bearer token extracted by the ``security`` dependency.

    Returns:
        ``ApiResponse`` dict: ``code=0`` with a list of root menus (each with
        nested ``children``); ``401`` invalid token; ``500`` database/server
        error.
    """
    from contextlib import closing

    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=401,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        user_id = payload.get("sub")

        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        with closing(conn), closing(conn.cursor()) as cursor:
            cursor.execute("""
                SELECT COUNT(*) FROM user_roles ur
                JOIN roles r ON ur.role_id = r.id
                WHERE ur.user_id = %s AND r.name = 'super_admin' AND ur.is_active = 1
            """, (user_id,))
            is_super_admin = cursor.fetchone()[0] > 0

            if is_super_admin:
                # Super admins see every active menu
                cursor.execute("""
                    SELECT m.id, m.parent_id, m.name, m.title, m.path,
                           m.component, m.icon, m.sort_order, m.menu_type,
                           m.is_hidden, m.is_active
                    FROM menus m
                    WHERE m.is_active = 1
                    ORDER BY m.sort_order
                """)
            else:
                # Everyone else: menus granted through any active role
                # (GROUP BY de-duplicates menus reachable via several roles)
                cursor.execute("""
                    SELECT m.id, m.parent_id, m.name, m.title, m.path,
                           m.component, m.icon, m.sort_order, m.menu_type,
                           m.is_hidden, m.is_active
                    FROM menus m
                    JOIN role_menus rm ON m.id = rm.menu_id
                    JOIN user_roles ur ON rm.role_id = ur.role_id
                    WHERE ur.user_id = %s
                      AND ur.is_active = 1
                      AND m.is_active = 1
                    GROUP BY m.id, m.parent_id, m.name, m.title, m.path,
                             m.component, m.icon, m.sort_order, m.menu_type,
                             m.is_hidden, m.is_active
                    ORDER BY m.sort_order
                """, (user_id,))
            rows = cursor.fetchall()

        menus = []
        for row in rows:
            # Hide only the unwanted "文档处理中心" entry; the existing
            # "文档管理中心" (/admin/documents) menu is kept.
            if "文档处理中心" in str(row[3]):
                continue
            menus.append({
                "id": row[0],
                "parent_id": row[1],
                "name": row[2],
                "title": row[3],
                "path": row[4],
                "component": row[5],
                "icon": row[6],
                "sort_order": row[7],
                "menu_type": row[8],
                "is_hidden": bool(row[9]),
                "is_active": bool(row[10]),
                "children": []
            })

        # The sidebar only shows "menu" entries; "button" entries are dropped
        # before building the tree.
        sidebar_menus = [m for m in menus if m.get("menu_type") == "menu"]
        menu_tree = build_menu_tree(sidebar_menus)

        return ApiResponse(
            code=0,
            message="获取用户菜单成功",
            data=menu_tree,
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"获取用户菜单错误: {e}")
        return ApiResponse(
            code=500,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
def build_menu_tree(menus):
    """Link flat menu dicts into a parent/child forest.

    Each dict must carry ``id``, ``parent_id`` and a ``children`` list.
    Entries with ``parent_id`` of ``None`` become roots. Any other entry is
    appended, in input order, to its parent's ``children`` list; this
    mutates the input dicts. Entries whose parent is not in ``menus`` are
    left out of the tree.

    Returns:
        The list of root menu dicts, in input order.
    """
    by_id = {}
    for item in menus:
        by_id[item["id"]] = item

    roots = []
    for item in menus:
        parent_key = item["parent_id"]
        if parent_key is None:
            roots.append(item)
            continue
        parent = by_id.get(parent_key)
        if parent:
            parent["children"].append(item)
    return roots
@app.get("/api/v1/admin/menus")
async def api_get_all_menus(
    page: int = 1,
    page_size: int = 1000,  # large default so the admin UI receives every menu
    keyword: Optional[str] = None,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Return a paginated list of all menus for administrators.

    Access is granted when the token payload has ``is_superuser`` set.
    ``keyword`` filters on menu title or name with a LIKE match. Root menus
    are ordered before children, then by ``sort_order``, with ``menu`` rows
    before other types, then by creation time.

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Rows per page; values below 1 are treated as 1.
        keyword: Optional substring to search for.
        credentials: Bearer token extracted by the ``security`` dependency.

    Returns:
        ``ApiResponse`` dict with ``items``, ``total``, ``page`` and
        ``page_size`` on success; ``401`` invalid token; ``403`` not a
        superuser; ``500`` database/server error.
    """
    from contextlib import closing

    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=401,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        # Simplified permission check: relies on the is_superuser token claim
        is_superuser = payload.get("is_superuser", False)
        if not is_superuser:
            return ApiResponse(
                code=403,
                message="权限不足",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        # Non-positive values would produce a negative LIMIT/OFFSET (SQL error).
        page = max(page, 1)
        page_size = max(page_size, 1)

        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        where_conditions = []
        params = []
        if keyword:
            where_conditions.append("(m.title LIKE %s OR m.name LIKE %s)")
            params.extend([f"%{keyword}%", f"%{keyword}%"])
        # Only fixed SQL fragments are interpolated; user input stays in params.
        where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"

        with closing(conn), closing(conn.cursor()) as cursor:
            cursor.execute(f"SELECT COUNT(*) FROM menus m WHERE {where_clause}", params)
            total = cursor.fetchone()[0]

            cursor.execute(f"""
                SELECT m.id, m.parent_id, m.name, m.title, m.path, m.component,
                       m.icon, m.sort_order, m.menu_type, m.is_hidden, m.is_active,
                       m.description, m.created_at, m.updated_at,
                       pm.title as parent_title
                FROM menus m
                LEFT JOIN menus pm ON m.parent_id = pm.id
                WHERE {where_clause}
                ORDER BY
                    CASE WHEN m.parent_id IS NULL THEN 0 ELSE 1 END,
                    m.sort_order,
                    CASE WHEN m.menu_type = 'menu' THEN 0 ELSE 1 END,
                    m.created_at
                LIMIT %s OFFSET %s
            """, params + [page_size, (page - 1) * page_size])
            rows = cursor.fetchall()

        menus = [
            {
                "id": row[0],
                "parent_id": row[1],
                "name": row[2],
                "title": row[3],
                "path": row[4],
                "component": row[5],
                "icon": row[6],
                "sort_order": row[7],
                "menu_type": row[8],
                "is_hidden": bool(row[9]),
                "is_active": bool(row[10]),
                "description": row[11],
                "created_at": row[12].isoformat() if row[12] else None,
                "updated_at": row[13].isoformat() if row[13] else None,
                "parent_title": row[14]
            }
            for row in rows
        ]

        return ApiResponse(
            code=0,
            message="获取菜单列表成功",
            data={
                "items": menus,
                "total": total,
                "page": page,
                "page_size": page_size
            },
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"获取菜单列表错误: {e}")
        return ApiResponse(
            code=500,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
@app.get("/api/v1/admin/roles")
async def api_get_all_roles(
    page: int = 1,
    page_size: int = 20,
    keyword: Optional[str] = None,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Return a paginated list of roles with their active-member counts.

    Access is granted when the token payload has ``is_superuser`` set.
    ``keyword`` filters on display name or name with a LIKE match. System
    roles are listed first, then roles by creation time.

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Rows per page; values below 1 are treated as 1.
        keyword: Optional substring to search for.
        credentials: Bearer token extracted by the ``security`` dependency.

    Returns:
        ``ApiResponse`` dict with ``items``, ``total``, ``page`` and
        ``page_size`` on success; ``401`` invalid token; ``403`` not a
        superuser; ``500`` database/server error.
    """
    from contextlib import closing

    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=401,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        # Simplified permission check: relies on the is_superuser token claim
        is_superuser = payload.get("is_superuser", False)
        if not is_superuser:
            return ApiResponse(
                code=403,
                message="权限不足",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        # Non-positive values would produce a negative LIMIT/OFFSET (SQL error).
        page = max(page, 1)
        page_size = max(page_size, 1)

        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        where_conditions = []
        params = []
        if keyword:
            where_conditions.append("(r.display_name LIKE %s OR r.name LIKE %s)")
            params.extend([f"%{keyword}%", f"%{keyword}%"])
        where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"

        offset = (page - 1) * page_size
        with closing(conn), closing(conn.cursor()) as cursor:
            cursor.execute(f"SELECT COUNT(*) FROM roles r WHERE {where_clause}", params)
            total = cursor.fetchone()[0]

            # LEFT JOIN on active assignments so roles without members count 0
            cursor.execute(f"""
                SELECT r.id, r.name, r.display_name, r.description, r.is_active,
                       r.is_system, r.created_at, r.updated_at,
                       COUNT(ur.user_id) as user_count
                FROM roles r
                LEFT JOIN user_roles ur ON r.id = ur.role_id AND ur.is_active = 1
                WHERE {where_clause}
                GROUP BY r.id
                ORDER BY r.is_system DESC, r.created_at
                LIMIT %s OFFSET %s
            """, params + [page_size, offset])
            rows = cursor.fetchall()

        roles = [
            {
                "id": row[0],
                "name": row[1],
                "display_name": row[2],
                "description": row[3],
                "is_active": bool(row[4]),
                "is_system": bool(row[5]),
                "created_at": row[6].isoformat() if row[6] else None,
                "updated_at": row[7].isoformat() if row[7] else None,
                "user_count": row[8]
            }
            for row in rows
        ]

        return ApiResponse(
            code=0,
            message="获取角色列表成功",
            data={
                "items": roles,
                "total": total,
                "page": page,
                "page_size": page_size
            },
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"获取角色列表错误: {e}")
        return ApiResponse(
            code=500,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
@app.get("/api/v1/user/permissions")
async def api_get_user_permissions(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return the distinct active permissions granted to the current user.

    Permissions are collected through ``role_permissions`` from all of the
    user's active role assignments.

    Args:
        credentials: Bearer token extracted by the ``security`` dependency.

    Returns:
        ``ApiResponse`` dict: ``code=0`` with a list of
        ``{"name", "resource", "action"}`` dicts; ``401`` invalid token;
        ``500`` database/server error.
    """
    from contextlib import closing

    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(
                code=401,
                message="无效的访问令牌",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()
        user_id = payload.get("sub")

        conn = get_db_connection()
        if not conn:
            return ApiResponse(
                code=500,
                message="数据库连接失败",
                timestamp=datetime.now(timezone.utc).isoformat()
            ).model_dump()

        with closing(conn), closing(conn.cursor()) as cursor:
            cursor.execute("""
                SELECT DISTINCT p.name, p.resource, p.action
                FROM permissions p
                JOIN role_permissions rp ON p.id = rp.permission_id
                JOIN user_roles ur ON rp.role_id = ur.role_id
                WHERE ur.user_id = %s
                  AND ur.is_active = 1
                  AND p.is_active = 1
            """, (user_id,))
            rows = cursor.fetchall()

        permissions = [
            {"name": row[0], "resource": row[1], "action": row[2]}
            for row in rows
        ]

        return ApiResponse(
            code=0,
            message="获取用户权限成功",
            data=permissions,
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"获取用户权限错误: {e}")
        return ApiResponse(
            code=500,
            message="服务器内部错误",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
# 用户管理API
@app.get("/api/v1/admin/users")
async def get_users(
    page: int = 1,
    page_size: int = 20,
    keyword: Optional[str] = None,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Return a paginated user list with profile fields and role names.

    Access is granted when the token payload has ``is_superuser`` set.
    ``keyword`` matches username, email or real name with a LIKE match.
    Role display names of active assignments are aggregated with
    ``GROUP_CONCAT`` and split on commas.

    Args:
        page: 1-based page number; values below 1 are treated as 1.
        page_size: Rows per page; values below 1 are treated as 1.
        keyword: Optional substring to search for.
        credentials: Bearer token extracted by the ``security`` dependency.

    Returns:
        ``ApiResponse`` dict with ``items``, ``total``, ``page`` and
        ``page_size`` on success; ``401`` invalid token; ``403`` not a
        superuser; ``500`` database/server error.
    """
    from contextlib import closing

    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        is_superuser = payload.get("is_superuser", False)
        if not is_superuser:
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        # Non-positive values would produce a negative LIMIT/OFFSET (SQL error).
        page = max(page, 1)
        page_size = max(page_size, 1)

        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        where_conditions = []
        params = []
        if keyword:
            where_conditions.append("(u.username LIKE %s OR u.email LIKE %s OR up.real_name LIKE %s)")
            params.extend([f"%{keyword}%", f"%{keyword}%", f"%{keyword}%"])
        where_clause = " AND ".join(where_conditions) if where_conditions else "1=1"

        offset = (page - 1) * page_size
        with closing(conn), closing(conn.cursor()) as cursor:
            cursor.execute(f"SELECT COUNT(*) FROM users u LEFT JOIN user_profiles up ON u.id = up.user_id WHERE {where_clause}", params)
            total = cursor.fetchone()[0]

            cursor.execute(f"""
                SELECT u.id, u.username, u.email, u.phone, u.is_active, u.is_superuser,
                       u.last_login_at, u.created_at, up.real_name, up.company, up.department,
                       GROUP_CONCAT(r.display_name) as roles
                FROM users u
                LEFT JOIN user_profiles up ON u.id = up.user_id
                LEFT JOIN user_roles ur ON u.id = ur.user_id AND ur.is_active = 1
                LEFT JOIN roles r ON ur.role_id = r.id
                WHERE {where_clause}
                GROUP BY u.id, u.username, u.email, u.phone, u.is_active, u.is_superuser,
                         u.last_login_at, u.created_at, up.real_name, up.company, up.department
                ORDER BY u.created_at DESC
                LIMIT %s OFFSET %s
            """, params + [page_size, offset])
            rows = cursor.fetchall()

        users = [
            {
                "id": row[0],
                "username": row[1],
                "email": row[2],
                "phone": row[3],
                "is_active": bool(row[4]),
                "is_superuser": bool(row[5]),
                "last_login_at": row[6].isoformat() if row[6] else None,
                "created_at": row[7].isoformat() if row[7] else None,
                "real_name": row[8],
                "company": row[9],
                "department": row[10],
                # NOTE(review): splitting on ',' breaks display names that
                # contain commas; consider a GROUP_CONCAT SEPARATOR.
                "roles": row[11].split(',') if row[11] else []
            }
            for row in rows
        ]

        return ApiResponse(
            code=0,
            message="获取用户列表成功",
            data={"items": users, "total": total, "page": page, "page_size": page_size},
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"获取用户列表错误: {e}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
@app.post("/api/v1/admin/users")
async def create_user(
    user_data: dict,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Create a user, optionally with a profile and role assignments.

    Access is granted when the token payload has ``is_superuser`` set.
    Required body fields: ``username``, ``email``, ``password``. Optional:
    ``phone``, ``is_active`` (default True), ``is_superuser`` (default False),
    ``real_name`` / ``company`` / ``department`` (stored in
    ``user_profiles``), and ``role_ids`` (list of role ids).

    Args:
        user_data: JSON body describing the new user.
        credentials: Bearer token extracted by the ``security`` dependency.

    Returns:
        ``ApiResponse`` dict: ``code=0`` on success; ``401`` invalid token;
        ``403`` not a superuser; ``400`` missing fields, invalid
        ``role_ids`` or duplicate username/email; ``500`` database/server
        error.
    """
    from contextlib import closing

    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        is_superuser = payload.get("is_superuser", False)
        if not is_superuser:
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        # Missing fields used to raise KeyError and surface as a 500.
        if any(not user_data.get(field) for field in ('username', 'email', 'password')):
            return ApiResponse(code=400, message="缺少必要参数", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        role_ids = user_data.get('role_ids') or []
        # A string here would otherwise be iterated character by character.
        if not isinstance(role_ids, list):
            return ApiResponse(code=400, message="role_ids 必须为数组", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        with closing(conn), closing(conn.cursor()) as cursor:
            cursor.execute("SELECT id FROM users WHERE username = %s OR email = %s",
                           (user_data['username'], user_data['email']))
            if cursor.fetchone():
                return ApiResponse(code=400, message="用户名或邮箱已存在", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

            user_id = str(uuid.uuid4())
            password_hash = hash_password_simple(user_data['password'])

            cursor.execute("""
                INSERT INTO users (id, username, email, phone, password_hash, is_active, is_superuser, created_at, updated_at)
                VALUES (%s, %s, %s, %s, %s, %s, %s, NOW(), NOW())
            """, (user_id, user_data['username'], user_data['email'], user_data.get('phone'),
                  password_hash, user_data.get('is_active', True), user_data.get('is_superuser', False)))

            # A profile row is created only when a profile field was supplied
            if any(key in user_data for key in ['real_name', 'company', 'department']):
                profile_id = str(uuid.uuid4())
                cursor.execute("""
                    INSERT INTO user_profiles (id, user_id, real_name, company, department, created_at, updated_at)
                    VALUES (%s, %s, %s, %s, %s, NOW(), NOW())
                """, (profile_id, user_id, user_data.get('real_name'), user_data.get('company'), user_data.get('department')))

            # NOTE(review): is_active is not set here, so it relies on the
            # user_roles column default; queries elsewhere filter on is_active = 1.
            for role_id in role_ids:
                role_assignment_id = str(uuid.uuid4())
                cursor.execute("""
                    INSERT INTO user_roles (id, user_id, role_id, assigned_by, created_at)
                    VALUES (%s, %s, %s, %s, NOW())
                """, (role_assignment_id, user_id, role_id, payload.get("sub")))

            # Single commit: user, profile and roles are written together
            conn.commit()

        return ApiResponse(code=0, message="用户创建成功", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"创建用户错误: {e}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
@app.put("/api/v1/admin/users/{user_id}")
async def update_user(
    user_id: str,
    user_data: dict,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Partially update a user's account, profile and role assignments.

    Access is granted when the token payload has ``is_superuser`` set. Only
    the keys present in the body are changed:

    * ``email``, ``phone``, ``is_active``, ``is_superuser`` go to ``users``.
    * ``real_name``, ``company``, ``department`` go to ``user_profiles``; the
      profile row is created if it does not exist.
    * ``role_ids`` (list) replaces all of the user's role assignments.

    Args:
        user_id: Id of the user to update.
        user_data: JSON body with the fields to change.
        credentials: Bearer token extracted by the ``security`` dependency.

    Returns:
        ``ApiResponse`` dict: ``code=0`` on success; ``401`` invalid token;
        ``403`` not a superuser; ``400`` invalid ``role_ids``; ``404`` unknown
        user; ``500`` database/server error.
    """
    from contextlib import closing

    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        is_superuser = payload.get("is_superuser", False)
        if not is_superuser:
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        # Validate role_ids before any write; None or a string used to fail
        # only after the existing roles had been deleted.
        replace_roles = 'role_ids' in user_data
        role_ids = user_data.get('role_ids')
        if replace_roles and not isinstance(role_ids, list):
            return ApiResponse(code=400, message="role_ids 必须为数组", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        with closing(conn), closing(conn.cursor()) as cursor:
            # Without this check, an unknown id reported success and could
            # create an orphan profile row.
            cursor.execute("SELECT id FROM users WHERE id = %s", (user_id,))
            if not cursor.fetchone():
                return ApiResponse(code=404, message="用户不存在", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

            # Column names come from fixed whitelists, so the f-strings below
            # never interpolate user input.
            user_updates = {k: user_data[k] for k in ['email', 'phone', 'is_active', 'is_superuser'] if k in user_data}
            if user_updates:
                set_clause = ', '.join(f'{field} = %s' for field in user_updates)
                cursor.execute(f"""
                    UPDATE users
                    SET {set_clause}, updated_at = NOW()
                    WHERE id = %s
                """, list(user_updates.values()) + [user_id])

            profile_fields = ['real_name', 'company', 'department']
            profile_updates = {k: v for k, v in user_data.items() if k in profile_fields}
            if profile_updates:
                cursor.execute("SELECT id FROM user_profiles WHERE user_id = %s", (user_id,))
                if cursor.fetchone():
                    set_clause = ', '.join(f'{field} = %s' for field in profile_updates)
                    cursor.execute(f"""
                        UPDATE user_profiles
                        SET {set_clause}, updated_at = NOW()
                        WHERE user_id = %s
                    """, list(profile_updates.values()) + [user_id])
                else:
                    profile_id = str(uuid.uuid4())
                    fields = ['id', 'user_id'] + list(profile_updates.keys())
                    values = [profile_id, user_id] + list(profile_updates.values())
                    placeholders = ', '.join(['%s'] * len(values))
                    cursor.execute(f"""
                        INSERT INTO user_profiles ({', '.join(fields)}, created_at, updated_at)
                        VALUES ({placeholders}, NOW(), NOW())
                    """, values)

            if replace_roles:
                # Replace the whole role set: delete, then re-insert
                cursor.execute("DELETE FROM user_roles WHERE user_id = %s", (user_id,))
                for role_id in role_ids:
                    assignment_id = str(uuid.uuid4())
                    cursor.execute("""
                        INSERT INTO user_roles (id, user_id, role_id, assigned_by, created_at)
                        VALUES (%s, %s, %s, %s, NOW())
                    """, (assignment_id, user_id, role_id, payload.get("sub")))

            conn.commit()

        return ApiResponse(code=0, message="用户更新成功", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"更新用户错误: {e}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
@app.delete("/api/v1/admin/users/{user_id}")
async def delete_user(
    user_id: str,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Delete a user together with their role assignments and profile.

    Access is granted when the token payload has ``is_superuser`` set. A
    caller cannot delete themselves, and users holding an active
    ``super_admin`` role cannot be deleted.

    Args:
        user_id: Id of the user to delete.
        credentials: Bearer token extracted by the ``security`` dependency.

    Returns:
        ``ApiResponse`` dict: ``code=0`` on success; ``401`` invalid token;
        ``403`` not a superuser; ``400`` self-deletion or super-admin target;
        ``500`` database/server error.
    """
    from contextlib import closing

    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        is_superuser = payload.get("is_superuser", False)
        if not is_superuser:
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        if user_id == payload.get("sub"):
            return ApiResponse(code=400, message="不能删除自己", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        with closing(conn), closing(conn.cursor()) as cursor:
            cursor.execute("""
                SELECT COUNT(*) FROM user_roles ur
                JOIN roles r ON ur.role_id = r.id
                WHERE ur.user_id = %s AND r.name = 'super_admin' AND ur.is_active = 1
            """, (user_id,))
            if cursor.fetchone()[0] > 0:
                return ApiResponse(code=400, message="不能删除超级管理员", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

            # Child rows first, then the user; committed as one unit
            cursor.execute("DELETE FROM user_roles WHERE user_id = %s", (user_id,))
            cursor.execute("DELETE FROM user_profiles WHERE user_id = %s", (user_id,))
            cursor.execute("DELETE FROM users WHERE id = %s", (user_id,))
            conn.commit()

        return ApiResponse(code=0, message="用户删除成功", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"删除用户错误: {e}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
# Role management API
@app.post("/api/v1/admin/roles")
async def create_role(
    role_data: dict,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Create a new, non-system role (superuser only).

    ``role_data`` must contain non-empty ``name`` and ``display_name``.
    ``description`` is optional and ``is_active`` defaults to True. Role names
    must be unique.

    Returns:
        ``ApiResponse`` dict: code 0 on success, 400 for missing fields or a
        duplicate name, 401/403/500 otherwise.
    """
    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if not payload.get("is_superuser", False):
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        name = role_data.get('name')
        display_name = role_data.get('display_name')
        # Reject missing required fields with a 400 instead of a KeyError -> 500.
        if not name or not display_name:
            return ApiResponse(code=400, message="角色名称和显示名称不能为空", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = conn.cursor()
        # Role names are unique.
        cursor.execute("SELECT id FROM roles WHERE name = %s", (name,))
        if cursor.fetchone():
            return ApiResponse(code=400, message="角色名已存在", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        role_id = str(uuid.uuid4())
        # Roles created through the API are never system roles (is_system = False).
        cursor.execute("""
            INSERT INTO roles (id, name, display_name, description, is_active, is_system, created_at, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s, NOW(), NOW())
        """, (role_id, name, display_name, role_data.get('description'),
              role_data.get('is_active', True), False))
        conn.commit()
        return ApiResponse(code=0, message="角色创建成功", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"创建角色错误: {e}")
        if conn:
            try:
                conn.rollback()
            except Exception as rollback_error:
                print(f"回滚失败: {rollback_error}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
@app.put("/api/v1/admin/roles/{role_id}")
async def update_role(
    role_id: str,
    role_data: dict,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Update ``display_name``, ``description`` and/or ``is_active`` of a role (superuser only).

    Other keys in ``role_data`` are ignored. System roles (``is_system``) cannot
    be modified. ``updated_at`` is refreshed whenever at least one field is given.

    Returns:
        ``ApiResponse`` dict: code 0 on success, 404 for an unknown role, 400 for
        a system role, 401/403/500 otherwise.
    """
    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if not payload.get("is_superuser", False):
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = conn.cursor()
        cursor.execute("SELECT is_system FROM roles WHERE id = %s", (role_id,))
        role = cursor.fetchone()
        if not role:
            return ApiResponse(code=404, message="角色不存在", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if role[0]:  # is_system: built-in roles are read-only
            return ApiResponse(code=400, message="不能修改系统角色", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        # Column names come from this fixed whitelist, so building SET with an f-string is safe.
        changes = [(field, role_data[field]) for field in ('display_name', 'description', 'is_active') if field in role_data]
        if changes:
            set_clause = ', '.join(f'{field} = %s' for field, _ in changes)
            cursor.execute(f"""
                UPDATE roles
                SET {set_clause}, updated_at = NOW()
                WHERE id = %s
            """, [value for _, value in changes] + [role_id])
        conn.commit()
        return ApiResponse(code=0, message="角色更新成功", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"更新角色错误: {e}")
        if conn:
            try:
                conn.rollback()
            except Exception as rollback_error:
                print(f"回滚失败: {rollback_error}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
@app.delete("/api/v1/admin/roles/{role_id}")
async def delete_role(
    role_id: str,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Delete a role and its permission/menu links (superuser only).

    System roles and roles still referenced by any ``user_roles`` row cannot be
    deleted. The three DELETE statements are committed together and rolled back
    on error.

    Returns:
        ``ApiResponse`` dict: code 0 on success, 404 for an unknown role, 400 for
        a system or in-use role, 401/403/500 otherwise.
    """
    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if not payload.get("is_superuser", False):
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = conn.cursor()
        cursor.execute("SELECT is_system FROM roles WHERE id = %s", (role_id,))
        role = cursor.fetchone()
        if not role:
            return ApiResponse(code=404, message="角色不存在", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if role[0]:  # is_system: built-in roles cannot be removed
            return ApiResponse(code=400, message="不能删除系统角色", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        # Refuse while any user is still assigned this role.
        cursor.execute("SELECT COUNT(*) FROM user_roles WHERE role_id = %s", (role_id,))
        if cursor.fetchone()[0] > 0:
            return ApiResponse(code=400, message="该角色正在被使用,无法删除", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        # Remove link rows first, then the role itself.
        cursor.execute("DELETE FROM role_permissions WHERE role_id = %s", (role_id,))
        cursor.execute("DELETE FROM role_menus WHERE role_id = %s", (role_id,))
        cursor.execute("DELETE FROM roles WHERE id = %s", (role_id,))
        conn.commit()
        return ApiResponse(code=0, message="角色删除成功", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"删除角色错误: {e}")
        if conn:
            try:
                conn.rollback()
            except Exception as rollback_error:
                print(f"回滚失败: {rollback_error}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
# Role-menu permission management API
@app.get("/api/v1/admin/roles/{role_id}/menus")
async def get_role_menus(
    role_id: str,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Return the menus a role can access (superuser only).

    The ``super_admin`` role is reported as owning every active menu. Any other
    role gets the active menus linked to it through ``role_menus``. Both lists
    are ordered by ``sort_order``.

    Returns:
        ``ApiResponse`` dict whose ``data`` holds ``role_id``, ``role_name``,
        ``menu_ids``, ``menu_details`` and ``total``; 401/403/404/500 on failure.
    """
    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if not payload.get("is_superuser", False):
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = conn.cursor()
        cursor.execute("SELECT id, name FROM roles WHERE id = %s", (role_id,))
        role = cursor.fetchone()
        if not role:
            return ApiResponse(code=404, message="角色不存在", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        role_name = role[1]
        if role_name == "super_admin":
            # The super administrator implicitly owns every active menu.
            cursor.execute("""
                SELECT id, name, title, parent_id, menu_type
                FROM menus
                WHERE is_active = 1
                ORDER BY sort_order
            """)
        else:
            # Other roles only see the active menus explicitly assigned to them.
            cursor.execute("""
                SELECT m.id, m.name, m.title, m.parent_id, m.menu_type
                FROM role_menus rm
                JOIN menus m ON rm.menu_id = m.id
                WHERE rm.role_id = %s AND m.is_active = 1
                ORDER BY m.sort_order
            """, (role_id,))
        menu_details = [
            {"id": row[0], "name": row[1], "title": row[2], "parent_id": row[3], "menu_type": row[4]}
            for row in cursor.fetchall()
        ]
        menu_ids = [menu["id"] for menu in menu_details]
        return ApiResponse(
            code=0,
            message="获取角色菜单权限成功",
            data={
                "role_id": role_id,
                "role_name": role_name,
                "menu_ids": menu_ids,
                "menu_details": menu_details,
                "total": len(menu_ids)
            },
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"获取角色菜单权限错误: {e}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    finally:
        # Release DB resources on every path, including errors.
        if cursor:
            cursor.close()
        if conn:
            conn.close()
@app.put("/api/v1/admin/roles/{role_id}/menus")
async def update_role_menus(
    role_id: str,
    request: Request,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Replace the full set of menus assigned to a role (superuser only).

    Expects a JSON object body ``{"menu_ids": [...]}``. Every ID must be a
    string or integer that refers to an active menu. Duplicate IDs are collapsed,
    keeping the first occurrence. The role's existing ``role_menus`` rows are
    deleted and the new ones inserted in one transaction. The ``super_admin``
    role cannot be edited.

    Returns:
        ``ApiResponse`` dict. On success ``data`` holds ``role_id``,
        ``role_name``, the de-duplicated ``menu_ids`` and ``updated_count``.
        Otherwise code 400 (bad body / invalid IDs / super_admin),
        401/403/404/500.
    """
    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if not payload.get("is_superuser", False):
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        # Parse and validate the body before touching the database.
        try:
            body = await request.json()
        except ValueError:  # includes json.JSONDecodeError / UnicodeDecodeError
            return ApiResponse(code=400, message="请求体不是有效的JSON", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if not isinstance(body, dict):
            return ApiResponse(code=400, message="请求体格式错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        menu_ids = body.get("menu_ids", [])
        # IDs are bound as SQL parameters and used as set members, so they must be
        # plain scalars (bool is excluded even though it subclasses int).
        if not isinstance(menu_ids, list) or any(
            isinstance(menu_id, bool) or not isinstance(menu_id, (str, int)) for menu_id in menu_ids
        ):
            return ApiResponse(code=400, message="菜单ID列表格式错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        # Drop duplicates (order-preserving) so each menu is inserted only once.
        menu_ids = list(dict.fromkeys(menu_ids))
        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = conn.cursor()
        cursor.execute("SELECT id, name FROM roles WHERE id = %s", (role_id,))
        role = cursor.fetchone()
        if not role:
            return ApiResponse(code=404, message="角色不存在", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        role_name = role[1]
        if role_name == "super_admin":
            # super_admin always owns every menu; its assignment is not editable.
            return ApiResponse(code=400, message="超级管理员角色拥有全部权限,无需修改", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        # Every requested ID must reference an active menu.
        if menu_ids:
            placeholders = ','.join(['%s'] * len(menu_ids))
            cursor.execute(f"""
                SELECT id FROM menus
                WHERE id IN ({placeholders}) AND is_active = 1
            """, menu_ids)
            valid_menu_ids = {row[0] for row in cursor.fetchall()}
            invalid_menu_ids = [menu_id for menu_id in menu_ids if menu_id not in valid_menu_ids]
            if invalid_menu_ids:
                # str() each ID: integer IDs would make str.join raise TypeError.
                invalid_text = ', '.join(str(menu_id) for menu_id in invalid_menu_ids)
                return ApiResponse(code=400, message=f"无效的菜单ID: {invalid_text}", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        # Replace the role's menu set atomically.
        cursor.execute("START TRANSACTION")
        cursor.execute("DELETE FROM role_menus WHERE role_id = %s", (role_id,))
        if menu_ids:
            cursor.executemany("""
                INSERT INTO role_menus (role_id, menu_id, created_at)
                VALUES (%s, %s, NOW())
            """, [(role_id, menu_id) for menu_id in menu_ids])
        conn.commit()
        return ApiResponse(
            code=0,
            message="角色菜单权限更新成功",
            data={
                "role_id": role_id,
                "role_name": role_name,
                "menu_ids": menu_ids,
                "updated_count": len(menu_ids)
            },
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"更新角色菜单权限错误: {e}")
        if conn:
            try:
                conn.rollback()
            except Exception as rollback_error:
                print(f"回滚失败: {rollback_error}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
# Menu management API
@app.post("/api/v1/admin/menus")
async def create_menu(
    menu_data: dict,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Create a menu entry (superuser only).

    ``menu_data`` must contain non-empty ``name`` (unique key) and ``title``.
    Optional keys and their defaults:
    - ``parent_id``, ``path``, ``component``, ``icon``, ``description``: None
    - ``sort_order``: 0
    - ``menu_type``: ``'menu'``
    - ``is_hidden``: False
    - ``is_active``: True

    Returns:
        ``ApiResponse`` dict: code 0 on success, 400 for missing fields or a
        duplicate name, 401/403/500 otherwise.
    """
    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if not payload.get("is_superuser", False):
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        name = menu_data.get('name')
        title = menu_data.get('title')
        # Reject missing required fields with a 400 instead of a KeyError -> 500.
        if not name or not title:
            return ApiResponse(code=400, message="菜单标识和菜单标题不能为空", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = conn.cursor()
        # Menu names (identifiers) are unique.
        cursor.execute("SELECT id FROM menus WHERE name = %s", (name,))
        if cursor.fetchone():
            return ApiResponse(code=400, message="菜单标识已存在", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        menu_id = str(uuid.uuid4())
        cursor.execute("""
            INSERT INTO menus (id, parent_id, name, title, path, component, icon,
                sort_order, menu_type, is_hidden, is_active, description, created_at, updated_at)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, NOW(), NOW())
        """, (
            menu_id, menu_data.get('parent_id'), name, title,
            menu_data.get('path'), menu_data.get('component'), menu_data.get('icon'),
            menu_data.get('sort_order', 0), menu_data.get('menu_type', 'menu'),
            menu_data.get('is_hidden', False), menu_data.get('is_active', True),
            menu_data.get('description')
        ))
        conn.commit()
        return ApiResponse(code=0, message="菜单创建成功", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"创建菜单错误: {e}")
        if conn:
            try:
                conn.rollback()
            except Exception as rollback_error:
                print(f"回滚失败: {rollback_error}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
@app.put("/api/v1/admin/menus/{menu_id}")
async def update_menu(
    menu_id: str,
    menu_data: dict,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Update selected columns of a menu (superuser only).

    Only these keys of ``menu_data`` are applied: ``parent_id``, ``title``,
    ``path``, ``component``, ``icon``, ``sort_order``, ``menu_type``,
    ``is_hidden``, ``is_active``, ``description``. The menu's ``name`` is not
    editable here. A menu may not be made its own parent.

    Returns:
        ``ApiResponse`` dict: code 0 on success, 400 for a self-parent,
        401/403/500 otherwise.
    """
    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if not payload.get("is_superuser", False):
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        new_parent = menu_data.get('parent_id')
        # A menu that is its own parent creates a cycle in the menu tree.
        if new_parent is not None and str(new_parent) == menu_id:
            return ApiResponse(code=400, message="不能将菜单设为自己的父菜单", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = conn.cursor()
        # Column names come from this fixed whitelist, so building SET with an f-string is safe.
        editable = ('parent_id', 'title', 'path', 'component', 'icon', 'sort_order',
                    'menu_type', 'is_hidden', 'is_active', 'description')
        changes = [(field, menu_data[field]) for field in editable if field in menu_data]
        if changes:
            set_clause = ', '.join(f'{field} = %s' for field, _ in changes)
            cursor.execute(f"""
                UPDATE menus
                SET {set_clause}, updated_at = NOW()
                WHERE id = %s
            """, [value for _, value in changes] + [menu_id])
        conn.commit()
        return ApiResponse(code=0, message="菜单更新成功", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"更新菜单错误: {e}")
        if conn:
            try:
                conn.rollback()
            except Exception as rollback_error:
                print(f"回滚失败: {rollback_error}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
@app.delete("/api/v1/admin/menus/{menu_id}")
async def delete_menu(
    menu_id: str,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Delete a menu and its role assignments (superuser only).

    A menu that still has child menus (rows whose ``parent_id`` is this menu)
    cannot be deleted. Both DELETE statements are committed together and rolled
    back on error.

    Returns:
        ``ApiResponse`` dict: code 0 on success, 400 when children exist,
        401/403/500 otherwise.
    """
    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if not payload.get("is_superuser", False):
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = conn.cursor()
        # Refuse to orphan child menus.
        cursor.execute("SELECT COUNT(*) FROM menus WHERE parent_id = %s", (menu_id,))
        if cursor.fetchone()[0] > 0:
            return ApiResponse(code=400, message="该菜单下有子菜单,无法删除", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        # Remove role links first, then the menu itself.
        cursor.execute("DELETE FROM role_menus WHERE menu_id = %s", (menu_id,))
        cursor.execute("DELETE FROM menus WHERE id = %s", (menu_id,))
        conn.commit()
        return ApiResponse(code=0, message="菜单删除成功", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"删除菜单错误: {e}")
        if conn:
            try:
                conn.rollback()
            except Exception as rollback_error:
                print(f"回滚失败: {rollback_error}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
# Fetch all roles (for drop-down selectors)
@app.get("/api/v1/roles/all")
async def get_all_roles_simple(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """List active roles in a compact form for drop-down selectors.

    Any valid token is accepted (no superuser check). System roles come first,
    then roles are ordered by ``display_name``.

    Returns:
        ``ApiResponse`` dict whose ``data`` is a list of
        ``{id, name, display_name, is_system, is_active}``; 401/500 on failure.
    """
    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = conn.cursor()
        cursor.execute("""
            SELECT id, name, display_name, is_system, is_active
            FROM roles
            WHERE is_active = 1
            ORDER BY is_system DESC, display_name
        """)
        roles = [
            {
                "id": row[0],
                "name": row[1],
                "display_name": row[2],
                "is_system": bool(row[3]),
                "is_active": bool(row[4])
            }
            for row in cursor.fetchall()
        ]
        return ApiResponse(code=0, message="获取角色列表成功", data=roles, timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"获取角色列表错误: {e}")
        return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    finally:
        # Release DB resources on every path, including errors.
        if cursor:
            cursor.close()
        if conn:
            conn.close()
import httpx
from fastapi.responses import HTMLResponse
class BatchEnterRequest(BaseModel):
    # Body of POST /api/v1/documents/batch-enter.
    ids: list[Union[int, str]]  # t_document_main ids, numeric or string
    table_type: Union[str, None] = None  # accepted; the batch-enter handler does not read it
class BatchDeleteRequest(BaseModel):
    # Body of POST /api/v1/documents/batch-delete.
    ids: list[Union[int, str]]  # t_document_main ids, numeric or string
    table_type: Union[str, None] = None  # accepted; the batch-delete handler does not read it
class ConvertRequest(BaseModel):
    # Body of POST /api/v1/documents/convert.
    id: Union[int, str]  # t_document_main id of the document to convert
    table_type: Union[str, None] = None  # source type; looked up from t_document_main when omitted
# --- Document management center API ---
@app.get("/api/v1/documents/proxy-view")
async def proxy_view(url: str, token: Optional[str] = None, credentials: Optional[HTTPAuthorizationCredentials] = Depends(security_optional)):
    """Fetch an external document server-side and relay it to the caller (superuser only).

    The access token comes from the ``Authorization`` header when present,
    otherwise from the ``token`` query parameter. PDF and image responses are
    returned unchanged with ``Content-Disposition: inline``. Everything else is
    treated as HTML and gets a small base stylesheet injected so it displays
    well inside an iframe.

    NOTE(review): proxied HTML is served from this origin, so any scripts in it
    run with this site's privileges; consider a sandboxed iframe or a CSP.

    Returns:
        The proxied content, an ``ApiResponse`` dict for auth/URL errors, or an
        HTML error page (status 500) when fetching fails.
    """
    import re
    from html import escape

    try:
        # Prefer the header token; fall back to the query parameter.
        actual_token = credentials.credentials if credentials else token
        if not actual_token:
            return ApiResponse(code=401, message="未提供认证令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        payload = verify_token(actual_token)
        if not payload or not payload.get("is_superuser"):
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        # Only web URLs may be proxied.
        if urlparse(url).scheme.lower() not in ("http", "https"):
            return ApiResponse(code=400, message="仅支持 http/https 地址", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        # Generous timeout so large files can still be downloaded.
        async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
            }
            response = await client.get(url, headers=headers)
            response.raise_for_status()
            content_type = response.headers.get("content-type", "").lower()
            # Never send an empty Content-Type header back to the browser.
            media_type = content_type or "application/octet-stream"
            is_binary = (
                "application/pdf" in content_type
                or content_type.startswith("image/")
                or any(ext in url.lower() for ext in [".pdf", ".png", ".jpg", ".jpeg", ".gif"])
            )
            if is_binary:
                return Response(
                    content=response.content,
                    media_type=media_type,
                    headers={"Content-Disposition": "inline"}
                )
            # Default: treat the body as HTML.
            try:
                content = response.text
            except Exception:
                # Body could not be decoded as text: return the raw bytes instead.
                return Response(content=response.content, media_type=media_type)
            base_style = (
                "<style>"
                "body { margin: 0; padding: 16px; word-wrap: break-word; }"
                "img, table { max-width: 100%; }"
                "</style>"
            )
            # Insert the style right after the opening <head> tag (any case, attributes
            # allowed; <header> does not match). Pages without one get it prepended.
            head_tag = re.search(r"<head(?:\s[^>]*)?>", content, re.IGNORECASE)
            if head_tag:
                content = content[:head_tag.end()] + base_style + content[head_tag.end():]
            else:
                content = base_style + content
            return HTMLResponse(content=content)
    except Exception as e:
        # url and the exception text are attacker-influenced: escape them before
        # embedding in HTML to prevent reflected markup/script injection.
        error_msg = (
            "<div style='padding: 20px; font-family: sans-serif;'>"
            "<h3>无法加载内容</h3>"
            f"<p>错误原因: {escape(str(e))}</p>"
            f"<p>URL: {escape(url)}</p>"
            "</div>"
        )
        return HTMLResponse(content=error_msg, status_code=500)
@app.post("/api/v1/documents/batch-enter")
async def batch_enter_knowledge_base(req: BatchEnterRequest, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Mark documents as entered into the knowledge base (superuser only).

    Sets ``whether_to_enter = 1`` on the requested ``t_document_main`` rows that
    are not yet entered (``whether_to_enter = 0``). It then mirrors the flag into
    each document's sub-table row, located via ``source_type`` /
    ``get_table_name()`` and ``source_id``. The mirroring is best-effort:
    sub-table failures are logged, not raised.

    Returns:
        ``ApiResponse`` dict whose message reports how many rows were updated
        and how many were skipped; 400 for an empty id list, 403/500 otherwise.
    """
    conn = None
    cursor = None
    try:
        payload = verify_token(credentials.credentials)
        if not payload or not payload.get("is_superuser"):
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        # An empty list would render "IN ()", which is invalid SQL.
        if not req.ids:
            return ApiResponse(code=400, message="未指定要入库的文档 ID", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = conn.cursor()
        placeholders = ', '.join(['%s'] * len(req.ids))
        username = payload.get("username", "admin")
        # 1. Flag the main-table rows; already-entered rows are left untouched.
        cursor.execute(f"""
            UPDATE t_document_main
            SET whether_to_enter = 1, updated_time = NOW()
            WHERE id IN ({placeholders}) AND whether_to_enter = 0
        """, req.ids)
        affected_rows = cursor.rowcount
        # 2. Best-effort: keep the per-type sub-tables in sync.
        try:
            cursor.execute(f"SELECT id, source_type, source_id FROM t_document_main WHERE id IN ({placeholders})", req.ids)
            for _doc_id, s_type, s_id in cursor.fetchall():
                if not (s_type and s_id):
                    continue
                sub_table = get_table_name(s_type)
                if not sub_table:
                    continue
                # The sub-table's primary key "id" holds the main row's source_id.
                sub_sql = f"UPDATE {sub_table} SET whether_to_enter = 1, updated_at = NOW(), updated_by = %s WHERE id = %s"
                try:
                    cursor.execute(sub_sql, (username, s_id))
                except Exception as sub_e:
                    print(f"更新子表 {sub_table} 失败 (可能字段不存在): {sub_e}")
        except Exception as sync_e:
            print(f"同步更新子表失败: {sync_e}")
        conn.commit()
        message = f"成功将 {affected_rows} 条数据加入知识库"
        if affected_rows < len(req.ids):
            message += f"(跳过了 {len(req.ids) - affected_rows} 条已入库数据或未找到数据)"
        return ApiResponse(code=0, message=message, timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"批量操作失败: {e}")
        if conn:
            try:
                conn.rollback()
            except Exception as rollback_error:
                print(f"回滚失败: {rollback_error}")
        return ApiResponse(code=500, message=f"批量操作失败: {str(e)}", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
@app.post("/api/v1/documents/batch-delete")
async def batch_delete_documents(req: BatchDeleteRequest, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Delete documents from t_document_main and, best-effort, their sub-table rows (superuser only).

    For every requested main-table id, the sub-table row identified by
    (``source_type`` -> ``get_table_name()``, ``source_id``) is deleted first.
    Failures there are only logged. The main-table rows are then deleted and
    everything is committed together. Cursor and connection are always closed.
    """
    def _reply(code, message):
        # Build the standard ApiResponse envelope with a fresh UTC timestamp.
        return ApiResponse(code=code, message=message, timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

    connection = None
    cur = None
    try:
        claims = verify_token(credentials.credentials)
        if not (claims and claims.get("is_superuser")):
            return _reply(403, "权限不足")
        connection = get_db_connection()
        if not connection:
            return _reply(500, "数据库连接失败")
        cur = connection.cursor()
        id_list = req.ids
        if not id_list:
            return _reply(400, "未指定要删除的文档 ID")
        in_clause = ', '.join('%s' for _ in id_list)
        # Step 1 (best-effort): remove the linked sub-table rows.
        try:
            cur.execute(f"SELECT source_type, source_id FROM t_document_main WHERE id IN ({in_clause})", id_list)
            for source_type, source_id in cur.fetchall():
                if not (source_type and source_id):
                    continue
                sub_table = get_table_name(source_type)
                if not sub_table:
                    continue
                try:
                    cur.execute(f"DELETE FROM {sub_table} WHERE id = %s", (source_id,))
                except Exception as sub_err:
                    print(f"删除子表 {sub_table} 数据失败: {sub_err}")
        except Exception as sync_err:
            print(f"同步删除子表数据失败: {sync_err}")
        # Step 2: remove the main-table rows and commit both steps together.
        cur.execute(f"DELETE FROM t_document_main WHERE id IN ({in_clause})", id_list)
        deleted = cur.rowcount
        connection.commit()
        return _reply(0, f"成功删除 {deleted} 条文档数据")
    except Exception as e:
        print(f"批量删除失败: {e}")
        return _reply(500, f"批量删除失败: {str(e)}")
    finally:
        if cur:
            cur.close()
        if connection:
            connection.close()
async def simulate_conversion(doc_id: str):
    """Simulate a document conversion by stepping its progress in t_document_main.

    Progress is written as 10 -> 40 -> 75 -> 100, with 2s/3s/2s pauses between
    steps. The row is then marked ``conversion_status = 2`` and
    ``converted_file_name`` is set to ``CONCAT(title, '_已转换.pdf')``. If any
    step fails, the row is marked ``conversion_status = 3`` and the error text
    is stored in ``conversion_error``.

    The pauses use ``asyncio.sleep`` because this coroutine is awaited on the
    event loop: a blocking ``time.sleep`` would stall every other request.
    """
    import asyncio

    conn = None
    cursor = None
    try:
        conn = get_db_connection()
        if not conn:
            print("模拟转换出错: 数据库连接失败")
            return
        cursor = conn.cursor()
        # 1. Start (10%)
        cursor.execute("UPDATE t_document_main SET conversion_status = 1, conversion_progress = 10 WHERE id = %s", (doc_id,))
        conn.commit()
        await asyncio.sleep(2)
        # 2. In progress (40%)
        cursor.execute("UPDATE t_document_main SET conversion_progress = 40 WHERE id = %s", (doc_id,))
        conn.commit()
        await asyncio.sleep(3)
        # 3. In progress (75%)
        cursor.execute("UPDATE t_document_main SET conversion_progress = 75 WHERE id = %s", (doc_id,))
        conn.commit()
        await asyncio.sleep(2)
        # 4. Done (100%)
        cursor.execute("""
            UPDATE t_document_main
            SET conversion_status = 2, conversion_progress = 100,
                converted_file_name = CONCAT(title, '_已转换.pdf')
            WHERE id = %s
        """, (doc_id,))
        conn.commit()
    except Exception as e:
        print(f"模拟转换出错: {e}")
        if conn:
            # Recording the failure is itself best-effort; never let it escape the task.
            try:
                conn.rollback()
                if cursor is None:
                    cursor = conn.cursor()
                cursor.execute("UPDATE t_document_main SET conversion_status = 3, conversion_error = %s WHERE id = %s", (str(e), doc_id))
                conn.commit()
            except Exception as status_error:
                print(f"记录转换失败状态出错: {status_error}")
    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()
@app.post("/api/v1/documents/convert")
async def convert_document(req: ConvertRequest, background_tasks: BackgroundTasks, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Start converting a document (superuser only).

    If ``req.table_type`` is missing, it is looked up from
    ``t_document_main.source_type``. When ``scripts/miner_u.py`` exists next to
    this file, it is launched in the background (not awaited, output discarded)
    with arguments ``<table_type or "basis"> <id>``. Otherwise
    ``simulate_conversion`` is scheduled as a background task.

    Returns:
        ``ApiResponse`` dict: code 0 once the task has been started, 403/500
        otherwise.
    """
    try:
        payload = verify_token(credentials.credentials)
        if not payload or not payload.get("is_superuser"):
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        table_type = req.table_type
        # Fall back to the source_type stored on the main-table row.
        if not table_type:
            conn = None
            cursor = None
            try:
                conn = get_db_connection()
                if conn:
                    cursor = conn.cursor()
                    cursor.execute("SELECT source_type FROM t_document_main WHERE id = %s", (req.id,))
                    res = cursor.fetchone()
                    if res:
                        table_type = res[0]
            except Exception as e:
                print(f"从主表获取 source_type 失败: {e}")
            finally:
                # Release the lookup connection even if the query failed.
                if cursor:
                    cursor.close()
                if conn:
                    conn.close()
        # 1. Prefer the real conversion script when it is installed.
        script_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "scripts", "miner_u.py"))
        if os.path.exists(script_path):
            import subprocess
            # Pass table_type and id to the script; CREATE_NO_WINDOW only exists on Windows.
            subprocess.Popen([sys.executable, script_path, str(table_type or "basis"), str(req.id)],
                             stdout=subprocess.DEVNULL,
                             stderr=subprocess.DEVNULL,
                             creationflags=subprocess.CREATE_NO_WINDOW if os.name == 'nt' else 0)
            return ApiResponse(code=0, message="转换任务已在后台启动", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        # 2. Otherwise fall back to the simulated conversion.
        background_tasks.add_task(simulate_conversion, str(req.id))
        return ApiResponse(
            code=0,
            message="转换任务已启动 (模拟模式)",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"启动转换失败: {e}")
        return ApiResponse(code=500, message=f"启动转换失败: {str(e)}", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
@app.post("/api/v1/documents/add")
async def add_document(doc: DocumentAdd, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Create a document: one row in its type-specific sub-table plus one row in t_document_main.

    For ``table_type`` 'basis', 'work' or 'job', a sub-table row (id, title
    column, created_by) is inserted first. The main-table row then references
    it through ``source_type`` / ``source_id`` and starts with
    ``whether_to_enter = 0``. Both inserts share one transaction, which is
    rolled back on failure.

    Returns:
        ``ApiResponse`` dict with ``data = {"id": <new document id>}`` on
        success; 401/500 otherwise (500 carries the error text).
    """
    # Sub-table column that stores the document title, keyed by table_type.
    title_columns = {'basis': 'chinese_name', 'work': 'plan_name', 'job': 'file_name'}
    try:
        claims = verify_token(credentials.credentials)
        if not claims:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        creator = claims.get("username", "admin")
        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = conn.cursor()
        doc_id = str(uuid.uuid4())
        source_id = str(uuid.uuid4())
        table_name = TABLE_MAP.get(doc.table_type, "t_basis_of_preparation")
        try:
            # 1. Sub-table row (only for the known table types).
            title_column = title_columns.get(doc.table_type)
            if title_column is not None:
                cursor.execute(
                    f"INSERT INTO {table_name} (id, {title_column}, created_by) VALUES (%s, %s, %s)",
                    (source_id, doc.title, creator)
                )
            # 2. Main-table row pointing at the sub-table row.
            main_values = (
                doc_id, doc.title, doc.content, creator, doc.table_type, source_id, 0,
                doc.primary_category_id, doc.secondary_category_id, doc.year,
                doc.file_url, doc.file_extension,
            )
            cursor.execute("""
                INSERT INTO t_document_main
                (id, title, content, created_by, source_type, source_id, whether_to_enter, primary_category_id, secondary_category_id, year, file_url, file_extension)
                VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
            """, main_values)
            conn.commit()
            return ApiResponse(code=0, message="文档添加成功", data={"id": doc_id}, timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        except Exception:
            conn.rollback()
            raise
        finally:
            cursor.close()
            conn.close()
    except Exception as e:
        print(f"添加文档失败: {e}")
        return ApiResponse(code=500, message=str(e), timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
@app.get("/api/v1/documents/detail/{doc_id}")
async def get_document_detail(doc_id: str, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return a t_document_main row enriched with fields from its sub-table.

    The sub-table is chosen via ``TABLE_MAP[source_type]`` and its row via
    ``source_id``. Selected sub-table columns are copied under front-end-friendly
    names. ``release_date`` is always stringified, and the main-table
    ``created_time`` / ``updated_time`` are converted to ISO-8601 strings.

    Returns:
        ``ApiResponse`` dict with the merged document in ``data``; 401/404/500
        otherwise (500 carries the error text).
    """
    print(f"🔍 正在获取文档详情: {doc_id}")
    # (response key, sub-table column) pairs to copy, per source_type.
    field_map = {
        'basis': (
            ('standard_no', 'standard_number'),
            ('issuing_authority', 'issuing_authority'),
            ('release_date', 'release_date'),
            ('document_type', 'document_type'),
            ('professional_field', 'professional_field'),
            ('validity', 'validity'),
        ),
        'work': (
            ('project_name', 'project_name'),
            ('project_section', 'project_section'),
            ('issuing_authority', 'compiling_unit'),
            ('release_date', 'compiling_date'),
        ),
        'job': (
            ('issuing_authority', 'issuing_department'),
            ('document_type', 'document_type'),
            ('release_date', 'publish_date'),
        ),
    }
    try:
        claims = verify_token(credentials.credentials)
        if not claims:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = conn.cursor()
        try:
            # 1. Main-table row.
            cursor.execute("SELECT * FROM t_document_main WHERE id = %s", (doc_id,))
            main_row = cursor.fetchone()
            if not main_row:
                print(f"❌ 文档不存在: {doc_id}")
                return ApiResponse(code=404, message="文档不存在", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
            doc = dict(zip((col[0] for col in cursor.description), main_row))
            print(f"✅ 找到主表数据: {doc.get('title')}")
            # 2. Linked sub-table row, if any.
            source_type = doc.get('source_type')
            source_id = doc.get('source_id')
            table_name = TABLE_MAP.get(source_type)
            if table_name and source_id:
                cursor.execute(f"SELECT * FROM {table_name} WHERE id = %s", (source_id,))
                sub_row = cursor.fetchone()
                if sub_row:
                    sub_data = dict(zip((col[0] for col in cursor.description), sub_row))
                    for target, column in field_map.get(source_type, ()):
                        value = sub_data.get(column)
                        if target == 'release_date':
                            # Dates are sent as strings; missing/falsy dates become None.
                            value = str(value) if value else None
                        doc[target] = value
            # Serialize main-table timestamps.
            for key in ('created_time', 'updated_time'):
                if doc.get(key):
                    doc[key] = doc[key].isoformat()
            release = doc.get('release_date')
            if release and not isinstance(release, str):
                doc['release_date'] = release.isoformat()
            return ApiResponse(code=0, message="获取详情成功", data=doc, timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        finally:
            cursor.close()
            conn.close()
    except Exception as e:
        print(f"获取文档详情失败: {e}")
        return ApiResponse(code=500, message=str(e), timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
@app.get("/api/v1/documents/list")
async def get_document_list(
    whether_to_enter: Optional[int] = None,
    keyword: Optional[str] = None,
    table_type: Optional[str] = None,
    page: int = 1,
    size: int = 50,
    credentials: HTTPAuthorizationCredentials = Depends(security),
    primaryCategoryId: Optional[str] = None,
    secondaryCategoryId: Optional[str] = None,
    year: Optional[int] = None,
):
    """List documents from the main table ``t_document_main``.

    All filters are optional and combined with AND:
      * ``table_type``          -> ``source_type = ?``
      * ``whether_to_enter``    -> ``whether_to_enter = ?``
      * ``keyword``             -> substring match on ``title`` or ``content``
      * ``primaryCategoryId``   -> ``primary_category_id = ?``
      * ``secondaryCategoryId`` -> ``secondary_category_id = ?``
      * ``year``                -> ``year = ?``

    Rows are ordered by ``created_time`` descending. ``page`` is 1-based;
    ``page`` and ``size`` values below 1 are clamped to 1 so LIMIT/OFFSET can
    never go negative.

    Returns an ``ApiResponse`` dict. Its ``data`` holds ``items`` (the page),
    ``total`` (rows matching the filters), the echoed ``page``/``size``,
    ``all_total`` (all rows in the table) and ``total_entered`` (rows with
    ``whether_to_enter = 1``). Failures are reported as code 401/500 responses
    rather than raised.
    """
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        # Guard against negative OFFSET / non-positive LIMIT from bad query input.
        page = max(page, 1)
        size = max(size, 1)
        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = conn.cursor()
        try:
            where_clauses = []
            params = []
            if table_type:
                where_clauses.append("source_type = %s")
                params.append(table_type)
            if whether_to_enter is not None:
                where_clauses.append("whether_to_enter = %s")
                params.append(whether_to_enter)
            if keyword:
                where_clauses.append("(title LIKE %s OR content LIKE %s)")
                params.extend([f"%{keyword}%", f"%{keyword}%"])
            if primaryCategoryId:
                where_clauses.append("primary_category_id = %s")
                params.append(primaryCategoryId)
            if secondaryCategoryId:
                where_clauses.append("secondary_category_id = %s")
                params.append(secondaryCategoryId)
            if year is not None:
                where_clauses.append("year = %s")
                params.append(year)
            # Only the fixed clause strings above are interpolated into SQL;
            # every user-supplied value is bound through params.
            where_sql = (" WHERE " + " AND ".join(where_clauses)) if where_clauses else ""
            filter_params = tuple(params)

            offset = (page - 1) * size
            cursor.execute(
                f"SELECT * FROM t_document_main {where_sql} ORDER BY created_time DESC LIMIT %s OFFSET %s",
                filter_params + (size, offset),
            )
            columns = [desc[0] for desc in cursor.description]
            items = []
            for row in cursor.fetchall():
                item = dict(zip(columns, row))
                # Serialize date/datetime values so the payload is JSON-friendly.
                for key in ('created_time', 'updated_time', 'release_date'):
                    if item.get(key) and hasattr(item[key], 'isoformat'):
                        item[key] = item[key].isoformat()
                items.append(item)

            # Count of rows matching the same filters (without pagination).
            cursor.execute(f"SELECT COUNT(*) FROM t_document_main {where_sql}", filter_params)
            total = cursor.fetchone()[0]
            # Table-wide statistics, independent of the filters above.
            cursor.execute("SELECT COUNT(*) FROM t_document_main")
            all_total = cursor.fetchone()[0]
            cursor.execute("SELECT COUNT(*) FROM t_document_main WHERE whether_to_enter = 1")
            total_entered = cursor.fetchone()[0]
        finally:
            # Release DB resources even when a query raises.
            cursor.close()
            conn.close()
        return ApiResponse(
            code=0,
            message="查询成功",
            data={
                "items": items,
                "total": total,
                "page": page,
                "size": size,
                "all_total": all_total,
                "total_entered": total_entered
            },
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"获取文档列表失败: {e}")
        return ApiResponse(code=500, message=str(e), timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
@app.post("/api/v1/documents/edit")
async def edit_document(doc: DocumentAdd, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Edit a document, updating both its source (sub) table row and its main-table row.

    The sub table is chosen from ``doc.table_type`` via ``TABLE_MAP`` and the
    row is addressed by ``doc.source_id``. The ``t_document_main`` row is
    addressed by ``doc.id``. Both updates are committed together, or rolled
    back together if either fails. A ``table_type`` other than
    ``'basis'``/``'work'``/``'job'`` updates only the main table.

    Returns an ``ApiResponse`` dict with these codes:
      * 0 on success
      * 401 for an invalid token
      * 400 when ``id`` or ``source_id`` is missing
      * 500 on connection or database errors
    """
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        if not doc.id or not doc.source_id:
            return ApiResponse(code=400, message="缺少ID参数", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        table_name = TABLE_MAP.get(doc.table_type, "t_basis_of_preparation")
        try:
            # Created inside the outer try so the connection is closed even if
            # obtaining a cursor fails.
            cursor = conn.cursor()
            try:
                # 1. Update the sub-table row; its column names differ per source type.
                if doc.table_type == 'basis':
                    cursor.execute(f"""
                        UPDATE {table_name}
                        SET chinese_name = %s, standard_number = %s, issuing_authority = %s,
                            release_date = %s, document_type = %s, professional_field = %s, validity = %s
                        WHERE id = %s
                    """, (doc.title, doc.standard_no, doc.issuing_authority, doc.release_date,
                          doc.document_type, doc.professional_field, doc.validity, doc.source_id))
                elif doc.table_type == 'work':
                    cursor.execute(f"""
                        UPDATE {table_name}
                        SET plan_name = %s, project_name = %s, project_section = %s,
                            compiling_unit = %s, compiling_date = %s
                        WHERE id = %s
                    """, (doc.title, doc.project_name, doc.project_section, doc.issuing_authority,
                          doc.release_date, doc.source_id))
                elif doc.table_type == 'job':
                    cursor.execute(f"""
                        UPDATE {table_name}
                        SET file_name = %s, issuing_department = %s, document_type = %s, publish_date = %s
                        WHERE id = %s
                    """, (doc.title, doc.issuing_authority, doc.document_type, doc.release_date, doc.source_id))
                # 2. Update the main-table row.
                cursor.execute("""
                    UPDATE t_document_main
                    SET title = %s, content = %s, updated_time = NOW(),
                        primary_category_id = %s, secondary_category_id = %s, year = %s,
                        file_url = %s, file_extension = %s
                    WHERE id = %s
                """, (doc.title, doc.content, doc.primary_category_id, doc.secondary_category_id, doc.year,
                      doc.file_url, doc.file_extension, doc.id))
                conn.commit()
            except Exception:
                # Undo the partial update, then let the outer handler report it.
                conn.rollback()
                raise
            finally:
                cursor.close()
        finally:
            conn.close()
        return ApiResponse(code=0, message="文档更新成功", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"编辑文档失败: {e}")
        return ApiResponse(code=500, message=str(e), timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
@app.post("/api/v1/documents/enter")
async def enter_document(data: dict, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Mark a document as entered (``whether_to_enter = 1``).

    ``data["id"]`` is the ``t_document_main`` id. The main-table update is
    required. Mirroring the flag onto the matching sub-table row is
    best-effort: that row is found via the main row's ``source_type`` and
    ``source_id`` and ``get_table_name``. Sync failures are printed and do not
    prevent the commit.

    Returns an ``ApiResponse`` dict with these codes:
      * 0 on success
      * 401 for an invalid token
      * 400 when ``id`` is missing
      * 500 on connection or database errors
    """
    try:
        # An invalid token must be rejected before any data is modified,
        # consistent with the other document endpoints.
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        doc_id = data.get("id")
        if not doc_id:
            return ApiResponse(code=400, message="缺少ID", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        username = payload.get("username", "admin")
        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        try:
            cursor = conn.cursor()
            try:
                # 1. Update the main table.
                cursor.execute("UPDATE t_document_main SET whether_to_enter = 1, updated_time = NOW() WHERE id = %s", (doc_id,))
                # 2. Best-effort sync of the sub-table row.
                try:
                    cursor.execute("SELECT source_type, source_id FROM t_document_main WHERE id = %s", (doc_id,))
                    res = cursor.fetchone()
                    if res and res[0] and res[1]:
                        s_type, s_id = res
                        sub_table = get_table_name(s_type)
                        if sub_table:
                            sub_sql = f"UPDATE {sub_table} SET whether_to_enter = 1, updated_at = NOW(), updated_by = %s WHERE id = %s"
                            try:
                                cursor.execute(sub_sql, (username, s_id))
                            except Exception as sub_e:
                                print(f"入库同步子表 {sub_table} 失败: {sub_e}")
                except Exception as sync_e:
                    print(f"入库同步子表异常: {sync_e}")
                conn.commit()
            finally:
                cursor.close()
        finally:
            # Always release the connection, even if an update raises.
            conn.close()
        return ApiResponse(code=0, message="入库成功", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as e:
        print(f"入库失败: {e}")
        return ApiResponse(code=500, message=str(e), timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
@app.get("/api/v1/basic-info/list")
async def get_basic_info_list(
    type: str,
    page: int = 1,
    size: int = 50,
    keyword: Optional[str] = None,
    title: Optional[str] = None,
    standard_no: Optional[str] = None,
    document_type: Optional[str] = None,
    professional_field: Optional[str] = None,
    validity: Optional[str] = None,
    issuing_authority: Optional[str] = None,
    release_date_start: Optional[str] = None,
    release_date_end: Optional[str] = None,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """List rows of one source table with optional multi-field filtering.

    ``type`` selects the table:
      * ``'basis'`` -> ``t_basis_of_preparation``
      * ``'work'``  -> ``t_work_of_preparation``
      * ``'job'``   -> ``t_job_of_preparation``

    Columns are aliased to a common shape: title, standard_no,
    issuing_authority, release_date, document_type, professional_field,
    validity, created_by, created_at. Columns a table lacks come back as NULL.

    Filters (all optional, combined with AND):
      * ``keyword``: substring match on the table's title column, plus
        ``standard_number`` for ``'basis'``
      * ``title``, ``standard_no``, ``issuing_authority``: substring match
      * ``document_type``, ``professional_field``, ``validity``: exact match
      * ``release_date_start`` / ``release_date_end``: inclusive bounds

    A filter whose field does not exist for the selected table is ignored.
    ``page`` is 1-based; ``page`` and ``size`` below 1 are clamped to 1.

    Returns an ``ApiResponse`` dict with ``items``, ``total``, ``page`` and
    ``size``. Error codes: 400 for an unknown ``type``, 401 for an invalid
    token, 500 on connection or database errors.
    """
    try:
        payload = verify_token(credentials.credentials)
        if not payload:
            return ApiResponse(code=401, message="无效的访问令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        # Guard against negative OFFSET / non-positive LIMIT.
        page = max(page, 1)
        size = max(size, 1)

        # Resolve the table before opening a connection so that an invalid
        # type cannot leave a connection open.
        if type == 'basis':
            table_name = "t_basis_of_preparation"
            fields = "id, chinese_name as title, standard_number as standard_no, issuing_authority, release_date, document_type, professional_field, validity, created_by, created_time as created_at"
            # API filter name -> real column name in this table.
            field_map = {
                'title': 'chinese_name',
                'standard_no': 'standard_number',
                'issuing_authority': 'issuing_authority',
                'release_date': 'release_date',
                'document_type': 'document_type',
                'professional_field': 'professional_field',
                'validity': 'validity'
            }
            keyword_columns = ("chinese_name", "standard_number")
        elif type == 'work':
            table_name = "t_work_of_preparation"
            fields = "id, plan_name as title, NULL as standard_no, compiling_unit as issuing_authority, compiling_date as release_date, NULL as document_type, NULL as professional_field, NULL as validity, created_by, created_time as created_at"
            field_map = {
                'title': 'plan_name',
                'issuing_authority': 'compiling_unit',
                'release_date': 'compiling_date'
            }
            keyword_columns = ("plan_name",)
        elif type == 'job':
            table_name = "t_job_of_preparation"
            fields = "id, file_name as title, NULL as standard_no, issuing_department as issuing_authority, publish_date as release_date, document_type, NULL as professional_field, NULL as validity, created_by, created_time as created_at"
            field_map = {
                'title': 'file_name',
                'issuing_authority': 'issuing_department',
                'release_date': 'publish_date',
                'document_type': 'document_type'
            }
            keyword_columns = ("file_name",)
        else:
            return ApiResponse(code=400, message="无效的类型", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()

        where_clauses = []
        params = []
        # 1. Generic keyword search (kept for backward compatibility).
        if keyword:
            where_clauses.append("(" + " OR ".join(f"{col} LIKE %s" for col in keyword_columns) + ")")
            params.extend([f"%{keyword}%"] * len(keyword_columns))
        # 2. Per-field filters as (field key, value, SQL operator). The order is
        #    fixed so params stay aligned with their clauses.
        refinements = (
            ('title', title, 'LIKE'),
            ('standard_no', standard_no, 'LIKE'),
            ('document_type', document_type, '='),
            ('professional_field', professional_field, '='),
            ('validity', validity, '='),
            ('issuing_authority', issuing_authority, 'LIKE'),
            ('release_date', release_date_start, '>='),
            ('release_date', release_date_end, '<='),
        )
        for field_key, value, op in refinements:
            if value and field_key in field_map:
                # Column names come from field_map (trusted); values are bound params.
                where_clauses.append(f"{field_map[field_key]} {op} %s")
                params.append(f"%{value}%" if op == 'LIKE' else value)
        where_sql = (" WHERE " + " AND ".join(where_clauses)) if where_clauses else ""
        filter_params = tuple(params)
        offset = (page - 1) * size

        conn = get_db_connection()
        if not conn:
            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        cursor = conn.cursor()
        try:
            cursor.execute(
                f"SELECT {fields} FROM {table_name}{where_sql} ORDER BY created_at DESC LIMIT %s OFFSET %s",
                filter_params + (size, offset),
            )
            columns = [desc[0] for desc in cursor.description]
            items = []
            for row in cursor.fetchall():
                item = dict(zip(columns, row))
                # Serialize dates: ISO format when possible, otherwise str().
                for date_key in ('release_date', 'created_at'):
                    if item.get(date_key) and hasattr(item[date_key], 'isoformat'):
                        item[date_key] = item[date_key].isoformat()
                    elif item.get(date_key):
                        item[date_key] = str(item[date_key])
                items.append(item)
            # Count of rows matching the same filters (without pagination).
            cursor.execute(f"SELECT COUNT(*) FROM {table_name}{where_sql}", filter_params)
            total = cursor.fetchone()[0]
        finally:
            # Release DB resources even when a query raises.
            cursor.close()
            conn.close()
        return ApiResponse(
            code=0,
            message="查询成功",
            data={"items": items, "total": total, "page": page, "size": size},
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"查询基本信息失败: {e}")
        return ApiResponse(code=500, message=f"服务器内部错误: {str(e)}", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
@app.get("/api/v1/documents/categories/primary")
async def get_primary_categories(credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return the fixed set of first-level document categories (superusers only).

    Each category is ``{"id": name, "name": name}``; the id is the name itself.
    The response has code 403 when the token is invalid or its payload lacks a
    truthy ``is_superuser``, and code 500 if anything raises.
    """
    try:
        claims = verify_token(credentials.credentials)
        is_admin = bool(claims) and bool(claims.get("is_superuser"))
        if not is_admin:
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        # Only these categories are exposed.
        allowed_names = ("办公制度", "行业标准", "法律法规", "施工方案", "施工图片")
        entries = [dict(id=label, name=label) for label in allowed_names]
        return ApiResponse(code=0, message="获取成功", data=entries, timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as exc:
        return ApiResponse(code=500, message=str(exc), timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
@app.get("/api/v1/documents/categories/secondary")
async def get_secondary_categories(primaryId: str, credentials: HTTPAuthorizationCredentials = Depends(security)):
    """Return the preset second-level categories under ``primaryId`` (superusers only).

    Only "办公制度" currently has children; any other ``primaryId`` yields an
    empty list. Entries are ``{"id": name, "name": name}``. The response has
    code 403 when the token is invalid or not a superuser's, and code 500 if
    anything raises.
    """
    try:
        claims = verify_token(credentials.credentials)
        if not (claims and claims.get("is_superuser")):
            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
        preset_children = {"办公制度": ("采购", "报销", "审批")}
        children = preset_children.get(primaryId, ())
        entries = [{"id": label, "name": label} for label in children]
        return ApiResponse(code=0, message="获取成功", data=entries, timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
    except Exception as exc:
        return ApiResponse(code=500, message=str(exc), timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
@app.get("/api/v1/documents/search")
async def search_documents(
    keyword: str,
    primaryCategoryId: Optional[str] = None,
    secondaryCategoryId: Optional[str] = None,
    year: Optional[int] = None,
    whether_to_enter: Optional[int] = None,
    table_type: Optional[str] = "basis",
    page: int = 1,
    size: int = 50,
    credentials: HTTPAuthorizationCredentials = Depends(security)
):
    """Keyword search over documents, delegating to ``get_document_list``.

    Every argument is forwarded unchanged by keyword, so filtering, pagination
    and auth behave exactly as in the list endpoint. Unlike the list endpoint,
    ``table_type`` defaults to ``"basis"`` here.

    NOTE(review): this requires ``get_document_list`` to accept
    ``primaryCategoryId``, ``secondaryCategoryId`` and ``year`` keyword
    arguments; otherwise the call raises ``TypeError``.
    """
    forwarded = dict(
        primaryCategoryId=primaryCategoryId,
        secondaryCategoryId=secondaryCategoryId,
        year=year,
        whether_to_enter=whether_to_enter,
        keyword=keyword,
        table_type=table_type,
        page=page,
        size=size,
        credentials=credentials,
    )
    return await get_document_list(**forwarded)
@app.get("/api/v1/auth/captcha")
async def get_captcha():
    """Issue a new captcha.

    Returns an ``ApiResponse`` dict whose ``data`` holds a random
    ``captcha_id``, the captcha image as an SVG data URI and the captcha text.
    On failure, returns code 500001.
    """
    try:
        captcha_text, captcha_image = generate_captcha()
        # NOTE(review): the text is not stored anywhere (e.g. keyed by
        # captcha_id), and it is echoed to the client below. As written the
        # captcha offers no real protection. Persist it server-side and stop
        # returning captcha_text before production use.
        captcha_id = secrets.token_hex(16)
        return ApiResponse(
            code=0,
            message="获取验证码成功",
            data={
                "captcha_id": captcha_id,
                "captcha_image": captcha_image,
                "captcha_text": captcha_text  # must not be returned in production
            },
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()
    except Exception as e:
        print(f"生成验证码错误: {e}")
        return ApiResponse(
            code=500001,
            message="生成验证码失败",
            timestamp=datetime.now(timezone.utc).isoformat()
        ).model_dump()


def generate_captcha(length: int = 4):
    """Generate a random captcha text and a matching SVG image.

    Args:
        length: number of characters (uppercase ASCII letters and digits).

    Returns:
        ``(text, data_uri)``, where ``data_uri`` is a base64
        ``data:image/svg+xml`` URI that renders ``text`` in a 120x40 box.
        If anything fails, returns the fixed text "1234" with a pre-rendered
        image of it.
    """
    try:
        import base64
        import string
        alphabet = string.ascii_uppercase + string.digits
        # secrets (not random) so the text cannot be predicted from PRNG state.
        captcha_text = ''.join(secrets.choice(alphabet) for _ in range(length))
        # Same layout as the fallback image below, with the generated text.
        svg_captcha = (
            '<svg width="120" height="40" xmlns="http://www.w3.org/2000/svg">'
            '<rect width="120" height="40" fill="#f0f0f0" stroke="#ccc"/>'
            '<text x="60" y="25" font-family="Arial" font-size="18" '
            f'text-anchor="middle" fill="#333">{captcha_text}</text>'
            '</svg>'
        )
        svg_base64 = base64.b64encode(svg_captcha.encode('utf-8')).decode('utf-8')
        return captcha_text, f"data:image/svg+xml;base64,{svg_base64}"
    except Exception as e:
        print(f"生成验证码失败: {e}")
        # Fallback: fixed captcha "1234".
        return "1234", "data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iMTIwIiBoZWlnaHQ9IjQwIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPjxyZWN0IHdpZHRoPSIxMjAiIGhlaWdodD0iNDAiIGZpbGw9IiNmMGYwZjAiIHN0cm9rZT0iI2NjYyIvPjx0ZXh0IHg9IjYwIiB5PSIyNSIgZm9udC1mYW1pbHk9IkFyaWFsIiBmb250LXNpemU9IjE4IiB0ZXh0LWFuY2hvcj0ibWlkZGxlIiBmaWxsPSIjMzMzIj4xMjM0PC90ZXh0Pjwvc3ZnPg=="


# Keep this guard at the very end of the module: uvicorn.run() blocks, so any
# route or helper defined below it would never be registered when this file
# is executed directly.
if __name__ == "__main__":
    import uvicorn
    # Find a free port.
    port = find_available_port()
    if port is None:
        print("❌ 无法找到可用端口 (8000-8010)")
        print("请手动停止占用这些端口的进程")
        sys.exit(1)
    print("=" * 60)
    print("🚀 SSO认证中心完整服务器")
    print("=" * 60)
    print(f"✅ 找到可用端口: {port}")
    print(f"🌐 访问地址: http://localhost:{port}")
    print(f"📚 API文档: http://localhost:{port}/docs")
    print(f"❤️ 健康检查: http://localhost:{port}/health")
    print(f"🔐 登录API: http://localhost:{port}/api/v1/auth/login")
    print("=" * 60)
    print("📝 前端配置:")
    print(f" VITE_API_BASE_URL=http://localhost:{port}")
    print("=" * 60)
    print("👤 测试账号:")
    print(" 用户名: admin")
    print(" 密码: Admin123456")
    print("=" * 60)
    print("按 Ctrl+C 停止服务器")
    print()
    try:
        uvicorn.run(
            app,
            host="0.0.0.0",
            port=port,
            log_level="info"
        )
    except KeyboardInterrupt:
        print("\n👋 服务器已停止")
    except Exception as e:
        print(f"❌ 启动失败: {e}")
        sys.exit(1)