chenkun 1 месяц назад
Родитель
Сommit
8b42a9ad3f
3 измененных файлов с 711 добавлено и 8 удалено
  1. 60 0
      fix_db_indexes.py
  2. 443 8
      full_server.py
  3. 208 0
      scripts/miner_u.py

+ 60 - 0
fix_db_indexes.py

@@ -0,0 +1,60 @@
+import os
+import pymysql
+from urllib.parse import urlparse
+from dotenv import load_dotenv
+
+load_dotenv()
+
+def fix_indexes():
+    """执行索引添加 SQL"""
+    database_url = os.getenv('DATABASE_URL', '')
+    if not database_url:
+        print("❌ 错误: 未在 .env 中找到 DATABASE_URL")
+        return
+        
+    parsed = urlparse(database_url)
+    config = {
+        'host': parsed.hostname or 'localhost',
+        'port': parsed.port or 3306,
+        'user': parsed.username or 'root',
+        'password': parsed.password or '',
+        'database': parsed.path[1:] if parsed.path else 'sso_db',
+        'charset': 'utf8mb4',
+        'autocommit': True
+    }
+    
+    print(f"📡 正在尝试连接数据库: {config['host']}...")
+    
+    conn = None
+    try:
+        conn = pymysql.connect(**config)
+        cursor = conn.cursor()
+        
+        tables = ['t_basis_of_preparation', 't_work_of_preparation', 't_job_of_preparation']
+        
+        for table in tables:
+            print(f"⚡ 正在为 {table} 添加索引...")
+            try:
+                # 检查索引是否已存在,防止重复添加报错
+                cursor.execute(f"SHOW INDEX FROM {table} WHERE Key_name = 'idx_enter_status'")
+                if cursor.fetchone():
+                    print(f"   ✅ {table} 的索引已存在,跳过。")
+                    continue
+                    
+                sql = f"ALTER TABLE {table} ADD INDEX idx_enter_status (whether_to_enter)"
+                cursor.execute(sql)
+                print(f"   ✅ {table} 索引添加成功!")
+            except Exception as e:
+                print(f"   ❌ {table} 处理失败: {e}")
+        
+        print("\n🎉 所有任务处理完成!现在您可以重新启动后端服务了。")
+        
+    except Exception as e:
+        print(f"\n❌ 数据库连接失败: {e}")
+        print("💡 提示: 请确保您已经关闭了 full_server.py,否则连接可能被占用。")
+    finally:
+        if conn:
+            conn.close()
+
+if __name__ == "__main__":
+    fix_indexes()

+ 443 - 8
full_server.py

@@ -15,11 +15,11 @@ sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
 from dotenv import load_dotenv
 from dotenv import load_dotenv
 load_dotenv()
 load_dotenv()
 
 
-from fastapi import FastAPI, HTTPException, Depends, Request
+from fastapi import FastAPI, HTTPException, Depends, Request, Response
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.middleware.cors import CORSMiddleware
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
 from pydantic import BaseModel
 from pydantic import BaseModel
-from typing import Optional, Any
+from typing import Optional, Any, Union
 import hashlib
 import hashlib
 import secrets
 import secrets
 # 修复JWT导入 - 确保使用正确的JWT库
 # 修复JWT导入 - 确保使用正确的JWT库
@@ -50,7 +50,7 @@ except (ImportError, AttributeError, TypeError) as e:
             print(f"❌ PyJWT安装失败: {install_error}")
             print(f"❌ PyJWT安装失败: {install_error}")
             raise ImportError("无法导入JWT库,请手动安装: pip install PyJWT")
             raise ImportError("无法导入JWT库,请手动安装: pip install PyJWT")
 
 
-from datetime import datetime, timedelta, timezone
+from datetime import datetime, timedelta, timezone, date
 import pymysql
 import pymysql
 from urllib.parse import urlparse
 from urllib.parse import urlparse
 
 
@@ -87,6 +87,33 @@ class ApiResponse(BaseModel):
     data: Optional[Any] = None
     data: Optional[Any] = None
     timestamp: str
     timestamp: str
 
 
+# 文档管理数据模型
+# --- 文档中心配置 ---
+TABLE_MAP = {
+    "basis": "t_basis_of_preparation", # 编制依据
+    "work": "t_work_of_preparation",   # 施工方案
+    "job": "t_job_of_preparation"      # 办公制度
+}
+
+def get_table_name(table_type: Optional[str]) -> str:
+    """根据类型获取对应的数据库表名,默认为编制依据"""
+    return TABLE_MAP.get(table_type, "t_basis_of_preparation")
+
+class DocumentAdd(BaseModel):
+    title: str
+    content: str
+    primary_category_id: Optional[Any] = None
+    secondary_category_id: Optional[Any] = None
+    year: Optional[int] = None
+    table_type: Optional[str] = "basis" # 增加表类型参数
+
+class DocumentListRequest(BaseModel):
+    primaryCategoryId: Optional[int] = None
+    secondaryCategoryId: Optional[int] = None
+    page: int = 1
+    size: int = 50
+    sort_by: str = "created_at"  # created_at or updated_at
+
 # 配置
 # 配置
 JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", "dev-jwt-secret-key-12345")
 JWT_SECRET_KEY = os.getenv("JWT_SECRET_KEY", "dev-jwt-secret-key-12345")
 ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
 ACCESS_TOKEN_EXPIRE_MINUTES = int(os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES", "30"))
@@ -176,6 +203,7 @@ app.add_middleware(
 )
 )
 
 
 security = HTTPBearer()
 security = HTTPBearer()
+security_optional = HTTPBearer(auto_error=False)
 
 
 @app.get("/")
 @app.get("/")
 async def root():
 async def root():
@@ -210,6 +238,8 @@ async def login(request: Request, login_data: LoginRequest):
     """用户登录"""
     """用户登录"""
     print(f"🔐 收到登录请求: username={login_data.username}")
     print(f"🔐 收到登录请求: username={login_data.username}")
     
     
+    conn = None
+    cursor = None
     try:
     try:
         # 获取数据库连接
         # 获取数据库连接
         print("📊 尝试连接数据库...")
         print("📊 尝试连接数据库...")
@@ -235,9 +265,6 @@ async def login(request: Request, login_data: LoginRequest):
         user_data = cursor.fetchone()
         user_data = cursor.fetchone()
         print(f"👤 用户查询结果: {user_data is not None}")
         print(f"👤 用户查询结果: {user_data is not None}")
         
         
-        cursor.close()
-        conn.close()
-        
         if not user_data:
         if not user_data:
             print("❌ 用户不存在")
             print("❌ 用户不存在")
             return ApiResponse(
             return ApiResponse(
@@ -307,6 +334,11 @@ async def login(request: Request, login_data: LoginRequest):
             message="服务器内部错误",
             message="服务器内部错误",
             timestamp=datetime.now(timezone.utc).isoformat()
             timestamp=datetime.now(timezone.utc).isoformat()
         ).model_dump()
         ).model_dump()
+    finally:
+        if cursor:
+            cursor.close()
+        if conn:
+            conn.close()
 
 
 @app.get("/api/v1/users/profile")
 @app.get("/api/v1/users/profile")
 async def get_user_profile(credentials: HTTPAuthorizationCredentials = Depends(security)):
 async def get_user_profile(credentials: HTTPAuthorizationCredentials = Depends(security)):
@@ -2277,6 +2309,16 @@ async def api_get_user_menus(credentials: HTTPAuthorizationCredentials = Depends
         
         
         menus = []
         menus = []
         for row in cursor.fetchall():
         for row in cursor.fetchall():
+            menu_id = str(row[0])
+            menu_name = str(row[2])
+            menu_title = str(row[3])
+            menu_path = str(row[4])
+            
+            # 只过滤掉明确不想要的“文档处理中心”
+            # 保留数据库中原本就有的“文档管理中心” (/admin/documents)
+            if "文档处理中心" in menu_title:
+                continue
+                
             menu = {
             menu = {
                 "id": row[0],
                 "id": row[0],
                 "parent_id": row[1],
                 "parent_id": row[1],
@@ -2293,8 +2335,9 @@ async def api_get_user_menus(credentials: HTTPAuthorizationCredentials = Depends
             }
             }
             menus.append(menu)
             menus.append(menu)
         
         
-        # 构建菜单树
-        menu_tree = build_menu_tree(menus)
+        # 构建菜单树前,过滤掉 button 类型的项,侧边栏只显示 menu 类型
+        sidebar_menus = [m for m in menus if m.get("menu_type") == "menu"]
+        menu_tree = build_menu_tree(sidebar_menus)
         
         
         cursor.close()
         cursor.close()
         conn.close()
         conn.close()
@@ -3506,6 +3549,398 @@ async def get_all_roles_simple(credentials: HTTPAuthorizationCredentials = Depen
         print(f"获取角色列表错误: {e}")
         print(f"获取角色列表错误: {e}")
         return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
         return ApiResponse(code=500, message="服务器内部错误", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
 
 
+import httpx
+from fastapi.responses import HTMLResponse
+
+class BatchEnterRequest(BaseModel):
+    ids: list[int]
+    table_type: Optional[str] = "basis"
+
+class BatchDeleteRequest(BaseModel):
+    ids: list[Union[int, str]]
+    table_type: Optional[str] = "basis"
+
+class ConvertRequest(BaseModel):
+    id: Union[int, str]
+    table_type: Optional[str] = "basis"
+
+# --- 文档管理中心 API ---
+
+@app.get("/api/v1/documents/proxy-view")
+async def proxy_view(url: str, token: Optional[str] = None, credentials: Optional[HTTPAuthorizationCredentials] = Depends(security_optional)):
+    """抓取外部文档内容并返回,支持 HTML 和 PDF 等二进制文件。支持从 Header 或 Query 参数获取 Token。"""
+    try:
+        # 优先从 Header 获取,如果没有则从参数获取
+        actual_token = None
+        if credentials:
+            actual_token = credentials.credentials
+        elif token:
+            actual_token = token
+            
+        if not actual_token:
+            return ApiResponse(code=401, message="未提供认证令牌", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+
+        payload = verify_token(actual_token)
+        if not payload or not payload.get("is_superuser"):
+            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+            
+        # 增加超时时间,支持大文件下载
+        async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client:
+            headers = {
+                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36"
+            }
+            response = await client.get(url, headers=headers)
+            response.raise_for_status()
+            
+            content_type = response.headers.get("content-type", "").lower()
+            
+            # 如果是 PDF 或其他二进制文件
+            if "application/pdf" in content_type or any(ext in url.lower() for ext in [".pdf", ".png", ".jpg", ".jpeg", ".gif"]):
+                return Response(
+                    content=response.content,
+                    media_type=content_type,
+                    headers={"Content-Disposition": "inline"}
+                )
+            
+            # 默认处理为 HTML
+            try:
+                content = response.text
+                
+                # 简单的注入一些基础样式,确保内容在 iframe 中显示良好
+                base_style = """
+                <style>
+                    body { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, "Helvetica Neue", Arial, sans-serif; padding: 20px; line-height: 1.6; color: #333; }
+                    img { max-width: 100%; height: auto; }
+                </style>
+                """
+                if "</head>" in content:
+                    content = content.replace("</head>", f"{base_style}</head>")
+                else:
+                    content = f"{base_style}{content}"
+                    
+                return HTMLResponse(content=content)
+            except Exception:
+                # 如果文本解析失败,返回原始字节
+                return Response(content=response.content, media_type=content_type)
+                
+    except Exception as e:
+        error_msg = f"<html><body><h3>无法加载内容</h3><p>错误原因: {str(e)}</p><p>URL: {url}</p></body></html>"
+        return HTMLResponse(content=error_msg, status_code=500)
+
+@app.post("/api/v1/documents/batch-enter")
+async def batch_enter_knowledge_base(req: BatchEnterRequest, credentials: HTTPAuthorizationCredentials = Depends(security)):
+    """批量将文档加入知识库"""
+    try:
+        payload = verify_token(credentials.credentials)
+        if not payload or not payload.get("is_superuser"):
+            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+            
+        conn = get_db_connection()
+        if not conn:
+            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+        
+        cursor = conn.cursor()
+        table_name = get_table_name(req.table_type)
+        # 批量更新 whether_to_enter 为 1
+        # 只更新尚未入库的数据 (whether_to_enter = 0)
+        placeholders = ', '.join(['%s'] * len(req.ids))
+        sql = f"UPDATE {table_name} SET whether_to_enter = 1, updated_at = NOW() WHERE id IN ({placeholders}) AND whether_to_enter = 0"
+        cursor.execute(sql, req.ids)
+        conn.commit()
+        
+        affected_rows = cursor.rowcount
+        cursor.close()
+        conn.close()
+        
+        message = f"成功将 {affected_rows} 条数据加入知识库"
+        if affected_rows < len(req.ids):
+            message += f"(跳过了 {len(req.ids) - affected_rows} 条已入库数据)"
+            
+        return ApiResponse(code=0, message=message, timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+    except Exception as e:
+        return ApiResponse(code=500, message=f"批量操作失败: {str(e)}", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+
+@app.post("/api/v1/documents/batch-delete")
+async def batch_delete_documents(req: BatchDeleteRequest, credentials: HTTPAuthorizationCredentials = Depends(security)):
+    """批量删除文档"""
+    conn = None
+    cursor = None
+    try:
+        payload = verify_token(credentials.credentials)
+        if not payload or not payload.get("is_superuser"):
+            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+            
+        conn = get_db_connection()
+        if not conn:
+            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+        
+        cursor = conn.cursor()
+        table_name = get_table_name(req.table_type)
+        
+        if not req.ids:
+            return ApiResponse(code=400, message="未指定要删除的文档 ID", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+            
+        placeholders = ', '.join(['%s'] * len(req.ids))
+        sql = f"DELETE FROM {table_name} WHERE id IN ({placeholders})"
+        cursor.execute(sql, req.ids)
+        conn.commit()
+        
+        affected_rows = cursor.rowcount
+        
+        return ApiResponse(
+            code=0, 
+            message=f"成功删除 {affected_rows} 条文档数据", 
+            timestamp=datetime.now(timezone.utc).isoformat()
+        ).model_dump()
+    except Exception as e:
+        print(f"批量删除失败: {e}")
+        return ApiResponse(code=500, message=f"批量删除失败: {str(e)}", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+    finally:
+        if cursor:
+            cursor.close()
+        if conn:
+            conn.close()
+
+@app.post("/api/v1/documents/convert")
+async def convert_document(req: ConvertRequest, credentials: HTTPAuthorizationCredentials = Depends(security)):
+    """异步启动文档转换"""
+    import subprocess
+    try:
+        payload = verify_token(credentials.credentials)
+        if not payload or not payload.get("is_superuser"):
+            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+        
+        # 启动后台进程执行转换
+        # 脚本位于 d:\UGit\LQAdminPlatform\scripts\miner_u.py
+        script_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "scripts", "miner_u.py"))
+        # 使用当前 python 解释器
+        python_exe = sys.executable
+        
+        # 异步启动,不等待结束
+        subprocess.Popen([python_exe, script_path, str(req.table_type), str(req.id)], 
+                         stdout=subprocess.DEVNULL, 
+                         stderr=subprocess.DEVNULL,
+                         creationflags=subprocess.CREATE_NO_WINDOW if os.name == 'nt' else 0)
+        
+        return ApiResponse(
+            code=0, 
+            message="转换任务已启动", 
+            timestamp=datetime.now(timezone.utc).isoformat()
+        ).model_dump()
+    except Exception as e:
+        print(f"启动转换失败: {e}")
+        return ApiResponse(code=500, message=f"启动转换失败: {str(e)}", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+
+@app.post("/api/v1/documents/add")
+async def add_document(doc: DocumentAdd, credentials: HTTPAuthorizationCredentials = Depends(security)):
+    """添加新文档"""
+    try:
+        payload = verify_token(credentials.credentials)
+        if not payload or not payload.get("is_superuser"):
+            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+            
+        conn = get_db_connection()
+        if not conn:
+            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+        
+        cursor = conn.cursor()
+        table_name = get_table_name(doc.table_type)
+        # 修正列名:reference_basis -> reference_basis_list
+        sql = f"""
+            INSERT INTO {table_name} 
+            (chinese_name, reference_basis_list, document_type, professional_field, release_date, created_at, updated_at)
+            VALUES (%s, %s, %s, %s, %s, NOW(), NOW())
+        """
+        # 构造日期:如果是年份,转为 YYYY-01-01
+        release_date = f"{doc.year}-01-01" if doc.year else None
+        
+        cursor.execute(sql, (doc.title, doc.content, str(doc.primary_category_id) if doc.primary_category_id else None, 
+                             str(doc.secondary_category_id) if doc.secondary_category_id else None, release_date))
+        conn.commit()
+        cursor.close()
+        conn.close()
+        
+        return ApiResponse(code=0, message="文档添加成功", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+    except Exception as e:
+        print(f"添加文档错误: {e}")
+        return ApiResponse(code=500, message=f"服务器内部错误: {str(e)}", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+
+@app.get("/api/v1/documents/list")
+async def get_document_list(
+    primaryCategoryId: Optional[str] = None,
+    secondaryCategoryId: Optional[str] = None,
+    year: Optional[int] = None,
+    whether_to_enter: Optional[int] = None,
+    keyword: Optional[str] = None,
+    table_type: Optional[str] = "basis",
+    page: int = 1, 
+    size: int = 50,
+    sort_by: str = "created_at",
+    credentials: HTTPAuthorizationCredentials = Depends(security)
+):
+    """获取文档列表(支持过滤与搜索)"""
+    conn = None
+    cursor = None
+    try:
+        payload = verify_token(credentials.credentials)
+        if not payload or not payload.get("is_superuser"):
+            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+            
+        conn = get_db_connection()
+        if not conn:
+            return ApiResponse(code=500, message="数据库连接失败", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+        
+        cursor = conn.cursor()
+        table_name = get_table_name(table_type)
+        
+        where_clauses = []
+        params = []
+        
+        if primaryCategoryId:
+            where_clauses.append("document_type = %s")
+            params.append(primaryCategoryId)
+        if secondaryCategoryId:
+            where_clauses.append("professional_field = %s")
+            params.append(secondaryCategoryId)
+        if year:
+            where_clauses.append("YEAR(release_date) = %s")
+            params.append(year)
+        if whether_to_enter is not None:
+            where_clauses.append("CAST(whether_to_enter AS UNSIGNED) = %s")
+            params.append(whether_to_enter)
+        if keyword:
+            where_clauses.append("(chinese_name LIKE %s OR reference_basis_list LIKE %s OR standard_no LIKE %s)")
+            like_keyword = f"%{keyword}%"
+            params.extend([like_keyword, like_keyword, like_keyword])
+            
+        where_stmt = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
+        
+        # 排序逻辑:按创建时间倒序
+        sort_field = "created_at" if sort_by == "created_at" else "updated_at"
+        order_by = f"ORDER BY {sort_field} DESC"
+        
+        # 分页
+        offset = (page - 1) * size
+        
+        # 返回更多字段
+        sql = f"""
+            SELECT id, chinese_name as title, reference_basis_list as content, 
+            document_type, professional_field, 
+            YEAR(release_date) as year, release_date, standard_no, status, 
+            CAST(whether_to_enter AS UNSIGNED) as whether_to_enter, file_url,
+            conversion_status, conversion_progress, conversion_error,
+            created_at, updated_at 
+            FROM {table_name} {where_stmt} 
+            {order_by} LIMIT %s OFFSET %s
+        """
+        params.extend([size, offset])
+        
+        cursor.execute(sql, params)
+        columns = [col[0] for col in cursor.description]
+        items = [dict(zip(columns, row)) for row in cursor.fetchall()]
+        
+        # 格式化时间
+        for item in items:
+            for key, value in item.items():
+                if isinstance(value, (datetime, date)):
+                    item[key] = value.isoformat()
+        
+        # 获取总数
+        count_sql = f"SELECT COUNT(*) FROM {table_name} {where_stmt}"
+        cursor.execute(count_sql, params[:-2])
+        total = cursor.fetchone()[0]
+        
+        # 优化统计查询:合并全局总数和已入库总数的查询,减少数据库交互
+        stats_sql = f"SELECT COUNT(*), SUM(CASE WHEN CAST(whether_to_enter AS UNSIGNED) = 1 THEN 1 ELSE 0 END) FROM {table_name}"
+        cursor.execute(stats_sql)
+        stats_result = cursor.fetchone()
+        
+        all_total = 0
+        total_entered = 0
+        if stats_result:
+            all_total = stats_result[0] or 0
+            total_entered = int(stats_result[1] or 0)
+        
+        return ApiResponse(
+            code=0,
+            message="获取成功",
+            data={
+                "items": items, 
+                "total": total, 
+                "all_total": all_total,
+                "total_entered": total_entered,
+                "page": page, 
+                "size": size
+            },
+            timestamp=datetime.now(timezone.utc).isoformat()
+        ).model_dump()
+    except Exception as e:
+        print(f"获取文档列表错误: {e}")
+        return ApiResponse(code=500, message=f"服务器内部错误: {str(e)}", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+    finally:
+        if cursor:
+            cursor.close()
+        if conn:
+            conn.close()
+
+@app.get("/api/v1/documents/categories/primary")
+async def get_primary_categories(credentials: HTTPAuthorizationCredentials = Depends(security)):
+    """获取所有一级分类(仅保留指定的分类)"""
+    try:
+        payload = verify_token(credentials.credentials)
+        if not payload or not payload.get("is_superuser"):
+            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+            
+        # 仅保留用户要求的分类
+        default_categories = ["办公制度", "行业标准", "法律法规", "施工方案", "施工图片"]
+        categories = [{"id": name, "name": name} for name in default_categories]
+        return ApiResponse(code=0, message="获取成功", data=categories, timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+    except Exception as e:
+        return ApiResponse(code=500, message=str(e), timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+
+@app.get("/api/v1/documents/categories/secondary")
+async def get_secondary_categories(primaryId: str, credentials: HTTPAuthorizationCredentials = Depends(security)):
+    """根据一级分类获取二级分类(仅保留指定的分类)"""
+    try:
+        payload = verify_token(credentials.credentials)
+        if not payload or not payload.get("is_superuser"):
+            return ApiResponse(code=403, message="权限不足", timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+            
+        # 针对“办公制度”的预设二级分类,其他分类暂时没有二级分类
+        categories = []
+        if primaryId == "办公制度":
+            secondary_names = ["采购", "报销", "审批"]
+            categories = [{"id": name, "name": name} for name in secondary_names]
+        
+        return ApiResponse(code=0, message="获取成功", data=categories, timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+    except Exception as e:
+        return ApiResponse(code=500, message=str(e), timestamp=datetime.now(timezone.utc).isoformat()).model_dump()
+
+@app.get("/api/v1/documents/search")
+async def search_documents(
+    keyword: str, 
+    primaryCategoryId: Optional[str] = None,
+    secondaryCategoryId: Optional[str] = None,
+    year: Optional[int] = None,
+    whether_to_enter: Optional[int] = None,
+    table_type: Optional[str] = "basis",
+    page: int = 1, 
+    size: int = 50,
+    credentials: HTTPAuthorizationCredentials = Depends(security)
+):
+    """关键词搜索文档,统一调用 get_document_list 以支持组合过滤"""
+    return await get_document_list(
+        primaryCategoryId=primaryCategoryId,
+        secondaryCategoryId=secondaryCategoryId,
+        year=year,
+        whether_to_enter=whether_to_enter,
+        keyword=keyword,
+        table_type=table_type,
+        page=page,
+        size=size,
+        credentials=credentials
+    )
+
 if __name__ == "__main__":
 if __name__ == "__main__":
     import uvicorn
     import uvicorn
     
     

+ 208 - 0
scripts/miner_u.py

@@ -0,0 +1,208 @@
+import os
+import time
+import json
+import requests
+import pymysql
+import zipfile
+import io
+from pathlib import Path
+from urllib.parse import urlparse
+from dotenv import load_dotenv
+
+# 加载环境变量 - 配置文件在脚本所在目录的上一级
+env_path = os.path.join(os.path.dirname(__file__), "..", ".env")
+load_dotenv(dotenv_path=env_path)
+
+TOKEN = "eyJ0eXBlIjoiSldUIiwiYWxnIjoiSFM1MTIifQ.eyJqdGkiOiI1MzgwMDYyNSIsInJvbCI6IlJPTEVfUkVHSVNURVIiLCJpc3MiOiJPcGVuWExhYiIsImlhdCI6MTc2Nzg1OTg5NywiY2xpZW50SWQiOiJsa3pkeDU3bnZ5MjJqa3BxOXgydyIsInBob25lIjoiMTgwMzA5ODIxNTQiLCJvcGVuSWQiOm51bGwsInV1aWQiOiI0NTYyZTUyNi1iZjE3LTRhMmItODExMi04YmM5ZjNjYzMwZGMiLCJlbWFpbCI6IiIsImV4cCI6MTc2OTA2OTQ5N30.mNH7afPPANNQq_BRsBOlbk-2P7e_ewdfzPQXO4woeoT15mDEbPKc45Auk_BuRuNaAS-Gm2GK3qKGjQ2VDtepvA"
+API_APPLY = "https://mineru.net/api/v4/file-urls/batch"
+API_BATCH_RESULT = "https://mineru.net/api/v4/extract-results/batch/{}"
+
+HEADERS = {
+    "Content-Type": "application/json",
+    "Authorization": f"Bearer {TOKEN}",
+}
+
+SUPPORTED_SUFFIX = {".pdf", ".doc", ".docx", ".ppt", ".pptx", ".png", ".jpg", ".jpeg", ".html"}
+
+def get_db_connection():
+    database_url = os.getenv('DATABASE_URL')
+    if not database_url:
+        print("DATABASE_URL not found in environment")
+        return None
+    try:
+        parsed = urlparse(database_url)
+        return pymysql.connect(
+            host=parsed.hostname,
+            port=parsed.port or 3306,
+            user=parsed.username,
+            password=parsed.password,
+            database=parsed.path[1:],
+            charset='utf8mb4',
+            autocommit=True
+        )
+    except Exception as e:
+        print(f"Database connection error: {e}")
+        return None
+
+def update_db_status(table_name, doc_id, status=None, progress=None, error=None):
+    conn = get_db_connection()
+    if not conn:
+        return
+    try:
+        with conn.cursor() as cursor:
+            updates = []
+            params = []
+            if status is not None:
+                updates.append("conversion_status = %s")
+                params.append(status)
+            if progress is not None:
+                updates.append("conversion_progress = %s")
+                params.append(progress)
+            if error is not None:
+                updates.append("conversion_error = %s")
+                params.append(error)
+            
+            if not updates:
+                return
+                
+            sql = f"UPDATE {table_name} SET {', '.join(updates)} WHERE id = %s"
+            params.append(doc_id)
+            cursor.execute(sql, params)
+    except Exception as e:
+        print(f"Update DB failed: {e}")
+    finally:
+        conn.close()
+
+def apply_upload_urls(files_meta, model_version="vlm"):
+    payload = {
+        "files": files_meta,
+        "model_version": model_version,
+    }
+    r = requests.post(API_APPLY, headers=HEADERS, json=payload, timeout=60)
+    r.raise_for_status()
+    j = r.json()
+    if j.get("code") != 0:
+        raise RuntimeError(f"apply upload urls failed: {j.get('msg')}")
+    return j["data"]["batch_id"], j["data"]["file_urls"]
+
+def upload_files(file_data_list, upload_urls):
+    for data, url in zip(file_data_list, upload_urls):
+        res = requests.put(url, data=data, timeout=300)
+        if res.status_code != 200:
+            raise RuntimeError(f"upload failed to {url}, status={res.status_code}")
+
+def poll_batch(batch_id, interval_sec=5, timeout_sec=1800):
+    deadline = time.time() + timeout_sec
+    while True:
+        r = requests.get(API_BATCH_RESULT.format(batch_id), headers=HEADERS, timeout=60)
+        r.raise_for_status()
+        j = r.json()
+        if j.get("code") != 0:
+            raise RuntimeError(f"poll failed: {j.get('msg')}")
+        results = j["data"]["extract_result"]
+        states = [it.get("state") for it in results]
+
+        if all(s in ("done", "failed") for s in states):
+            return results
+
+        if time.time() > deadline:
+            raise TimeoutError(f"poll timeout for batch_id={batch_id}")
+        time.sleep(interval_sec)
+
+def process_document(table_name, doc_id, chinese_name, file_url, out_dir):
+    try:
+        # 1. 更新状态:开始转换
+        update_db_status(table_name, doc_id, status=1, progress=10)
+        
+        # 2. 下载原始文件
+        print(f"Downloading {file_url}...")
+        resp = requests.get(file_url, timeout=60)
+        resp.raise_for_status()
+        file_content = resp.content
+        
+        file_ext = Path(urlparse(file_url).path).suffix.lower()
+        if not file_ext:
+            file_ext = ".pdf" # Default
+            
+        file_name = f"{chinese_name}{file_ext}"
+        update_db_status(table_name, doc_id, progress=30)
+        
+        # 3. 提交到 MinerU
+        files_meta = [{"name": file_name, "data_id": doc_id}]
+        batch_id, upload_urls = apply_upload_urls(files_meta)
+        upload_files([file_content], upload_urls)
+        
+        update_db_status(table_name, doc_id, progress=50)
+        
+        # 4. 轮询结果
+        results = poll_batch(batch_id)
+        result = results[0]
+        
+        if result.get("state") == "done":
+            zip_url = result.get("full_zip_url")
+            if zip_url:
+                # 5. 下载并处理结果
+                update_db_status(table_name, doc_id, progress=80)
+                zip_resp = requests.get(zip_url, timeout=300)
+                zip_resp.raise_for_status()
+                
+                # 解压并保存 Markdown
+                with zipfile.ZipFile(io.BytesIO(zip_resp.content)) as z:
+                    # 查找 .md 文件
+                    md_files = [f for f in z.namelist() if f.endswith(".md")]
+                    if md_files:
+                        md_content = z.read(md_files[0])
+                        save_path = Path(out_dir) / f"{chinese_name}.md"
+                        save_path.parent.mkdir(parents=True, exist_ok=True)
+                        with open(save_path, "wb") as f:
+                            f.write(md_content)
+                        print(f"Saved Markdown to {save_path}")
+                
+                update_db_status(table_name, doc_id, status=2, progress=100)
+                return True
+            else:
+                raise RuntimeError("No zip URL in result")
+        else:
+            err_msg = result.get("err_msg", "Unknown error")
+            raise RuntimeError(f"MinerU extraction failed: {err_msg}")
+            
+    except Exception as e:
+        print(f"Process failed: {e}")
+        update_db_status(table_name, doc_id, status=3, error=str(e))
+        return False
+
+def main_cli(table_type, doc_id, out_dir=r"d:\UGit\MinerU"):
+    # 获取表名
+    TABLE_MAP = {
+        "basis": "t_basis_of_preparation",
+        "work": "t_work_of_preparation",
+        "job": "t_job_of_preparation"
+    }
+    table_name = TABLE_MAP.get(table_type, "t_basis_of_preparation")
+    
+    # 从数据库获取详细信息
+    conn = get_db_connection()
+    if not conn:
+        print("Database connection failed")
+        return
+        
+    try:
+        with conn.cursor() as cursor:
+            cursor.execute(f"SELECT chinese_name, file_url FROM {table_name} WHERE id = %s", (doc_id,))
+            row = cursor.fetchone()
+            if not row:
+                print(f"Document not found: {doc_id} in {table_name}")
+                return
+            chinese_name, file_url = row
+            
+        process_document(table_name, doc_id, chinese_name, file_url, out_dir)
+    finally:
+        conn.close()
+
+if __name__ == "__main__":
+    # 示例用法:python miner_u.py basis <doc_id>
+    import sys
+    if len(sys.argv) > 2:
+        main_cli(sys.argv[1], sys.argv[2])
+    else:
+        print("Usage: python miner_u.py <table_type> <doc_id>")