Просмотр исходного кода

Merge remote-tracking branch 'origin/zkn' into dev

zkn 2 дней назад
Родитель
Сommit
34b42fa979

+ 1 - 1
shudao-chat-py/database.py

@@ -11,7 +11,7 @@ engine = create_engine(
     DATABASE_URL,
     pool_size=settings.database.pool_size,
     max_overflow=settings.database.max_overflow,
-    pool_recycle=settings.database.pool_recycle,
+    pool_recycle=600,
     pool_pre_ping=True,
     echo=settings.app.debug
 )

+ 73 - 109
shudao-chat-py/routers/chat.py

@@ -1,6 +1,5 @@
 from fastapi import APIRouter, Depends, Request
 from fastapi.responses import StreamingResponse, JSONResponse
-from sqlalchemy import or_
 from sqlalchemy.orm import Session
 from pydantic import BaseModel
 from typing import Optional
@@ -37,28 +36,6 @@ def _build_conversation_title(conversation: AIConversation) -> str:
     return _build_conversation_preview(conversation.content or "", limit=30)
 
 
-def _is_exam_workshop_conversation(conversation: Optional[AIConversation]) -> bool:
-    if not conversation:
-        return False
-    return conversation.business_type == 3 or bool((conversation.exam_name or "").strip())
-
-
-def _resolve_conversation_metadata(
-    conversation: Optional[AIConversation],
-    requested_business_type: int,
-    requested_exam_name: str,
-) -> tuple[int, str]:
-    requested_exam_name = (requested_exam_name or "").strip()
-
-    if _is_exam_workshop_conversation(conversation):
-        return 3, requested_exam_name or (conversation.exam_name or "").strip()
-
-    if requested_business_type == 3:
-        return 3, requested_exam_name
-
-    return requested_business_type, ""
-
-
 def _refresh_conversation_snapshot(db: Session, conversation_id: int, user_id: int) -> None:
     latest_message = (
         db.query(AIMessage)
@@ -200,31 +177,13 @@ async def send_deepseek_message(
 
         conversation_id = data.conversation_id or data.ai_conversation_id
 
-        existing_conversation = None
-        if conversation_id:
-            existing_conversation = (
-                db.query(AIConversation)
-                .filter(
-                    AIConversation.id == conversation_id,
-                    AIConversation.user_id == user.user_id,
-                    AIConversation.is_deleted == 0,
-                )
-                .first()
-            )
-
-        effective_business_type, effective_exam_name = _resolve_conversation_metadata(
-            existing_conversation,
-            data.business_type,
-            data.exam_name,
-        )
-
         # 创建或获取对话
         if not conversation_id:
             conversation = AIConversation(
                 user_id=user.user_id,
                 content=message[:100],
-                business_type=effective_business_type,
-                exam_name=effective_exam_name,
+                business_type=data.business_type,
+                exam_name=data.exam_name if data.business_type == 3 else "",
                 created_at=int(time.time()),
                 updated_at=int(time.time()),
                 is_deleted=0,
@@ -241,8 +200,8 @@ async def send_deepseek_message(
                 AIConversation.is_deleted == 0,
             ).update({
                 "content": message[:100],
-                "business_type": effective_business_type,
-                "exam_name": effective_exam_name,
+                "business_type": data.business_type,
+                "exam_name": data.exam_name if data.business_type == 3 else "",
                 "updated_at": int(time.time()),
             })
             db.commit()
@@ -287,8 +246,9 @@ async def send_deepseek_message(
                 except Exception:
                     response_text = qwen_response
             except Exception as e:
-                logger.error(f"[send_deepseek_message] AI问答异常: {e}")
-                response_text = f"处理失败: {str(e)}"
+                error_detail = str(e).strip() if str(e).strip() else f"未知错误({type(e).__name__})"
+                logger.error(f"[send_deepseek_message] AI问答异常: {type(e).__name__}: {error_detail}")
+                response_text = f"处理失败: {error_detail}"
 
         elif data.business_type == 1:
             # PPT大纲生成
@@ -308,8 +268,9 @@ async def send_deepseek_message(
 
                 response_text = await qwen_service.chat(messages)
             except Exception as e:
-                logger.error(f"[send_deepseek_message] PPT大纲生成异常: {e}")
-                response_text = f"处理失败: {str(e)}"
+                error_detail = str(e).strip() if str(e).strip() else f"未知错误({type(e).__name__})"
+                logger.error(f"[send_deepseek_message] PPT大纲生成异常: {type(e).__name__}: {error_detail}")
+                response_text = f"处理失败: {error_detail}"
 
         elif data.business_type == 2:
             # AI写作
@@ -329,8 +290,9 @@ async def send_deepseek_message(
 
                 response_text = await qwen_service.chat(messages)
             except Exception as e:
-                logger.error(f"[send_deepseek_message] AI写作异常: {e}")
-                response_text = f"处理失败: {str(e)}"
+                error_detail = str(e).strip() if str(e).strip() else f"未知错误({type(e).__name__})"
+                logger.error(f"[send_deepseek_message] AI写作异常: {type(e).__name__}: {error_detail}")
+                response_text = f"处理失败: {error_detail}"
 
         elif data.business_type == 3:
             # 考试工坊:生成题目
@@ -386,16 +348,16 @@ async def send_deepseek_message(
                 _refresh_conversation_snapshot(db, conv_id, user.user_id)
                 db.commit()
 
-                if effective_exam_name:
+                if data.exam_name:
                     db.query(AIConversation).filter(AIConversation.id == conv_id).update(
-                        {"business_type": 3,
-                            "exam_name": effective_exam_name,
+                        {"exam_name": data.exam_name,
                             "updated_at": int(time.time())}
                     )
                     db.commit()
             except Exception as e:
-                logger.error(f"[send_deepseek_message] 考试工坊异常: {e}")
-                response_text = f"处理失败: {str(e)}"
+                error_detail = str(e).strip() if str(e).strip() else f"未知错误({type(e).__name__})"
+                logger.error(f"[send_deepseek_message] 考试工坊异常: {type(e).__name__}: {error_detail}")
+                response_text = f"处理失败: {error_detail}"
 
         else:
             return {"statusCode": 400, "msg": f"不支持的业务类型: {data.business_type}"}
@@ -411,7 +373,7 @@ async def send_deepseek_message(
                 "content": response_text,
                 "message": response_text,
                 "user_id": user.user_id,
-                "business_type": effective_business_type,
+                "business_type": data.business_type,
             },
         }
     except Exception as e:
@@ -471,18 +433,9 @@ async def get_history_record(
     )
 
     if business_type is not None:
-        if business_type == 3:
-            conversations_query = conversations_query.filter(
-                or_(
-                    AIConversation.business_type == 3,
-                    AIConversation.exam_name.isnot(None),
-                    AIConversation.exam_name != "",
-                )
-            )
-        else:
-            conversations_query = conversations_query.filter(
-                AIConversation.business_type == business_type
-            )
+        conversations_query = conversations_query.filter(
+            AIConversation.business_type == business_type
+        )
 
     total = conversations_query.count()
     conversations = (
@@ -778,50 +731,61 @@ async def stream_chat_with_db(request: Request, data: StreamChatWithDBRequest):
             # 4. 发送 initial 事件
             yield f"data: {json.dumps({'type': 'initial', 'ai_conversation_id': conv_id, 'ai_message_id': ai_msg.id}, ensure_ascii=False)}\n\n"
 
-            # 5. RAG 检索
+            # 5. RAG search
             rag_context = await _rag_search(message, top_k=10)
 
-            # 6. 获取历史上下文(最近 4 条,2 轮对话)
-            history_msgs = (
-                db.query(AIMessage)
-                .filter(
-                    AIMessage.ai_conversation_id == conv_id,
-                    AIMessage.id < ai_msg.id,
-                    AIMessage.is_deleted == 0,
+            if data.business_type in (1, 2):
+                # PPT outline / AI writing: use dedicated prompt
+                prompt_name = "ppt_outline" if data.business_type == 1 else "document_writing"
+                system_content = load_prompt(
+                    prompt_name,
+                    userMessage=message,
+                    contextJSON=rag_context if rag_context else "?????????"
                 )
-                .order_by(AIMessage.updated_at.desc())
-                .limit(4)
-                .all()
-            )
-            history_msgs.reverse()
-
-            history_context = ""
-            for msg in history_msgs:
-                role = "用户" if msg.type == "user" else "助手"
-                history_context += f"{role}: {msg.content}\n\n"
-
-            # 7. 构建完整 prompt
-            # 构建上下文JSON
-            context_parts = []
-            if rag_context:
-                context_parts.append(f"知识库内容:\n{rag_context}")
-            if data.online_search_content:
-                context_parts.append(f"联网搜索结果:\n{data.online_search_content}")
-
-            context_json = "\n\n".join(
-                context_parts) if context_parts else "暂无相关知识库内容"
-
-            # 使用prompt加载器加载最终回答prompt
-            system_content = load_prompt(
-                "final_answer",
-                userMessage=message,
-                contextJSON=context_json,
-                historyContext=history_context if history_context else ""
-            )
 
-            messages = [
-                {"role": "user", "content": system_content},
-            ]
+                messages = [
+                    {"role": "user", "content": system_content},
+                ]
+            else:
+                # 6. History context (last 4 items, 2 turns)
+                history_msgs = (
+                    db.query(AIMessage)
+                    .filter(
+                        AIMessage.ai_conversation_id == conv_id,
+                        AIMessage.id < ai_msg.id,
+                        AIMessage.is_deleted == 0,
+                    )
+                    .order_by(AIMessage.updated_at.desc())
+                    .limit(4)
+                    .all()
+                )
+                history_msgs.reverse()
+
+                history_context = ""
+                for msg in history_msgs:
+                    role = "??" if msg.type == "user" else "??"
+                    history_context += f"{role}: {msg.content}\n\n"
+
+                # 7. Build final prompt
+                context_parts = []
+                if rag_context:
+                    context_parts.append(f"??????\n{rag_context}")
+                if data.online_search_content:
+                    context_parts.append(f"???????\n{data.online_search_content}")
+
+                context_json = "\n\n".join(
+                    context_parts) if context_parts else "?????????"
+
+                system_content = load_prompt(
+                    "final_answer",
+                    userMessage=message,
+                    contextJSON=context_json,
+                    historyContext=history_context if history_context else ""
+                )
+
+                messages = [
+                    {"role": "user", "content": system_content},
+                ]
 
             # 8. 流式输出并收集完整回复
             full_response = ""

+ 44 - 37
shudao-chat-py/routers/file.py

@@ -1,14 +1,10 @@
-from fastapi import APIRouter, Depends, Request, UploadFile, File
-from fastapi.responses import FileResponse
-from sqlalchemy.orm import Session
-from pydantic import BaseModel
 from typing import Optional
-from database import get_db
-from models.total import PolicyFile
-from services.oss_service import oss_service
-import time
 import json
-import os
+
+from fastapi import APIRouter, File, Request, UploadFile
+from pydantic import BaseModel
+
+from services.oss_service import oss_service
 
 router = APIRouter()
 
@@ -16,20 +12,20 @@ router = APIRouter()
 @router.post("/oss/upload")
 async def upload(
     request: Request,
-    file: UploadFile = File(...)
+    file: UploadFile = File(...),
 ):
-    """OSS上传 - 对齐Go版本函数名"""
+    """Upload a generic file to OSS."""
     user = request.state.user
     if not user:
         return {"statusCode": 401, "msg": "未授权"}
-    
+
     try:
         content = await file.read()
         file_url = oss_service.upload_file(content, file.filename)
         return {
             "statusCode": 200,
             "msg": "上传成功",
-            "data": {"file_url": file_url}
+            "data": {"file_url": file_url},
         }
     except Exception as e:
         return {"statusCode": 500, "msg": f"上传失败: {str(e)}"}
@@ -38,20 +34,38 @@ async def upload(
 @router.post("/oss/shudao/upload_image")
 async def upload_image(
     request: Request,
-    file: UploadFile = File(...)
+    file: Optional[UploadFile] = File(None),
+    image: Optional[UploadFile] = File(None),
 ):
-    """上传图片"""
+    """Upload an image to OSS.
+
+    Supports both the current `file` form field and the legacy `image` field
+    used by the existing frontend.
+    """
     user = request.state.user
     if not user:
         return {"statusCode": 401, "msg": "未授权"}
-    
+
     try:
-        content = await file.read()
-        file_url = oss_service.upload_image(content, file.filename)
+        upload_file = image or file
+        if not upload_file:
+            return {"statusCode": 422, "msg": "缺少图片文件"}
+
+        content = await upload_file.read()
+        file_url = oss_service.upload_image(content, upload_file.filename)
         return {
             "statusCode": 200,
             "msg": "上传成功",
-            "data": {"image_url": file_url}
+            "fileUrl": file_url,
+            "fileURL": file_url,
+            "fileName": upload_file.filename,
+            "fileSize": len(content),
+            "data": {
+                "image_url": file_url,
+                "file_url": file_url,
+                "file_name": upload_file.filename,
+                "file_size": len(content),
+            },
         }
     except Exception as e:
         return {"statusCode": 500, "msg": f"上传失败: {str(e)}"}
@@ -65,20 +79,20 @@ class UploadJsonRequest(BaseModel):
 @router.post("/oss/shudao/upload_json")
 async def upload_ppt_json(
     request: Request,
-    data: UploadJsonRequest
+    data: UploadJsonRequest,
 ):
-    """上传JSON文件 - 对齐Go版本函数名"""
+    """Upload JSON content to OSS."""
     user = request.state.user
     if not user:
         return {"statusCode": 401, "msg": "未授权"}
-    
+
     try:
         json_str = json.dumps(data.content, ensure_ascii=False)
         file_url = oss_service.upload_json(json_str, data.filename)
         return {
             "statusCode": 200,
             "msg": "上传成功",
-            "data": {"file_url": file_url}
+            "data": {"file_url": file_url},
         }
     except Exception as e:
         return {"statusCode": 500, "msg": f"上传失败: {str(e)}"}
@@ -86,17 +100,17 @@ async def upload_ppt_json(
 
 @router.get("/oss/parse")
 async def parse_oss(url: str, request: Request):
-    """OSS解析 - 对齐Go版本函数名"""
+    """Resolve an OSS proxy URL."""
     user = request.state.user
     if not user:
         return {"statusCode": 401, "msg": "未授权"}
-    
+
     try:
         decrypted_url = oss_service.parse_url(url)
         return {
             "statusCode": 200,
             "msg": "success",
-            "data": {"url": decrypted_url}
+            "data": {"url": decrypted_url},
         }
     except Exception as e:
         return {"statusCode": 500, "msg": f"解析失败: {str(e)}"}
@@ -105,26 +119,19 @@ async def parse_oss(url: str, request: Request):
 @router.get("/get_file_link")
 async def get_file_link(
     filename: str,
-    request: Request
+    request: Request,
 ):
-    """获取文件链接"""
+    """Get a signed OSS URL by filename."""
     user = request.state.user
     if not user:
         return {"statusCode": 401, "msg": "未授权"}
-    
+
     try:
         file_url = oss_service.get_signed_url(filename)
         return {
             "statusCode": 200,
             "msg": "success",
-            "data": {"file_url": file_url}
+            "data": {"file_url": file_url},
         }
     except Exception as e:
         return {"statusCode": 500, "msg": f"获取失败: {str(e)}"}
-
-
-# 以下路由已在 total.py / chat.py 中实现(含完整逻辑),此处不重复定义:
-# - GET  /download_file       → routers/total.py(流式代理下载OSS)
-# - POST /policy_file_count   → routers/total.py(view/download计数,字段 count_type)
-# - POST /save_ppt_outline    → routers/chat.py(更新AIMessage.content)
-# - POST /save_edit_document  → routers/chat.py(更新AIMessage.content)

+ 244 - 145
shudao-chat-py/routers/hazard.py

@@ -1,112 +1,215 @@
 """
-隐患识别路由
+Hazard detection routes.
 """
-from fastapi import APIRouter, Depends, Request, File, UploadFile
-from sqlalchemy.orm import Session
+
+from typing import Any, Dict, List, Optional
+import io
+import json
+import time
+
+import httpx
+from fastapi import APIRouter, Depends, Request
 from pydantic import BaseModel
-from typing import Optional
+from sqlalchemy.orm import Session
+from PIL import Image, ImageDraw, ImageFont
+
 from database import get_db
 from models.scene import RecognitionRecord
-from services.yolo_service import yolo_service
 from services.oss_service import oss_service
-from utils.logger import logger
+from services.yolo_service import yolo_service
 from utils.crypto import decrypt_url
-from PIL import Image, ImageDraw, ImageFont
-import io
-import httpx
-import time
-import math
-import json
+from utils.logger import logger
 
 router = APIRouter()
 
 
 class HazardRequest(BaseModel):
-    """隐患识别请求"""
-    image_url: str
+    """Compatible request model for old and new frontend payloads."""
+
+    image_url: Optional[str] = None
+    image: Optional[str] = None
     scene_type: str = ""
+    scene_name: str = ""
     user_name: str = ""
+    username: str = ""
     user_account: str = ""
+    account: str = ""
+    date: str = ""
 
 
 class SaveStepRequest(BaseModel):
-    """保存步骤请求"""
+    """Save current step for a recognition record."""
+
     record_id: int
     current_step: int
 
 
+SCENE_KEY_ALIASES = {
+    "tunnel": "tunnel",
+    "隧道": "tunnel",
+    "隧道施工": "tunnel",
+    "隧道工程": "tunnel",
+    "simple_supported_bridge": "simple_supported_bridge",
+    "bridge": "simple_supported_bridge",
+    "桥梁": "simple_supported_bridge",
+    "桥梁施工": "simple_supported_bridge",
+    "桥梁工程": "simple_supported_bridge",
+    "gas_station": "gas_station",
+    "加油站": "gas_station",
+    "special_equipment": "special_equipment",
+    "特种设备": "special_equipment",
+    "operate_highway": "operate_highway",
+    "运营高速公路": "operate_highway",
+}
+
+SCENE_DISPLAY_NAMES = {
+    "tunnel": "隧道工程",
+    "simple_supported_bridge": "桥梁工程",
+    "gas_station": "加油站",
+    "special_equipment": "特种设备",
+    "operate_highway": "运营高速公路",
+}
+
+
+def _get_user_code(user: Any) -> str:
+    return (
+        getattr(user, "userCode", None)
+        or getattr(user, "user_code", None)
+        or getattr(user, "account", "")
+    )
+
+
+def _resolve_scene_key(scene_value: str) -> str:
+    if not scene_value:
+        return ""
+    return SCENE_KEY_ALIASES.get(scene_value.strip(), scene_value.strip())
+
+
+def _unique_ordered(items: List[str]) -> List[str]:
+    seen = set()
+    ordered = []
+    for item in items:
+        if not item or item in seen:
+            continue
+        seen.add(item)
+        ordered.append(item)
+    return ordered
+
+
+def _build_frontend_result(hazards: List[Dict[str, Any]]) -> Dict[str, Any]:
+    raw_labels: List[str] = []
+    element_hazards: Dict[str, List[str]] = {}
+    detections: List[Dict[str, Any]] = []
+
+    for hazard in hazards:
+        label = str(hazard.get("label") or "").strip()
+        if not label:
+            continue
+
+        raw_labels.append(label)
+        element_hazards.setdefault(label, [])
+        if label not in element_hazards[label]:
+            element_hazards[label].append(label)
+
+        box = hazard.get("bbox") or hazard.get("box") or []
+        detections.append(
+            {
+                "label": label,
+                "box": box,
+                "bbox": box,
+                "confidence": hazard.get("confidence", 0),
+            }
+        )
+
+    display_labels = _unique_ordered(raw_labels)
+    return {
+        "display_labels": display_labels,
+        "labels": display_labels,
+        "third_scenes": display_labels,
+        "element_hazards": element_hazards,
+        "detections": detections,
+    }
+
+
 @router.post("/hazard")
 async def hazard(
     request: Request,
     data: HazardRequest,
-    db: Session = Depends(get_db)
+    db: Session = Depends(get_db),
 ):
-    """
-    隐患识别接口
-    流程:
-    1. 从 OSS 代理 URL 解密获取真实 URL
-    2. 下载图片到内存
-    3. 调用 YOLO 服务识别
-    4. 绘制边界框 + 水印(用户名/账号/日期)
-    5. 上传结果图片到 OSS
-    6. 插入 RecognitionRecord
-    7. 返回结果
-    """
+    """Run hazard detection and return a frontend-compatible payload."""
+
     user = request.state.user
     if not user:
         return {"statusCode": 401, "msg": "未授权"}
-    
+
     try:
-        # 1. 解密 OSS URL
+        source_image_url = data.image_url or data.image
+        if not source_image_url:
+            return {"statusCode": 422, "msg": "image_url 不能为空"}
+
+        scene_key = _resolve_scene_key(data.scene_type or data.scene_name)
+        user_code = _get_user_code(user)
+        user_name = (
+            data.user_name
+            or data.username
+            or getattr(user, "name", None)
+            or getattr(user, "username", None)
+            or getattr(user, "account", "")
+        )
+        user_account = (
+            data.user_account
+            or data.account
+            or getattr(user, "account", "")
+        )
+
         try:
-            real_image_url = decrypt_url(data.image_url)
-        except:
-            # 如果解密失败,可能是直接的 URL
-            real_image_url = data.image_url
-        
-        # 2. 下载图片到内存
+            real_image_url = decrypt_url(source_image_url)
+        except Exception:
+            real_image_url = source_image_url
+
         async with httpx.AsyncClient(timeout=30.0) as client:
             img_response = await client.get(real_image_url)
             img_response.raise_for_status()
             image_bytes = img_response.content
-        
-        # 3. 调用 YOLO 服务识别
-        # 先上传图片到临时位置,或者传递 URL
-        yolo_result = await yolo_service.detect_hazards(real_image_url, data.scene_type)
-        
-        hazards = yolo_result.get("hazards", [])
+
+        yolo_result = await yolo_service.detect_hazards(real_image_url, scene_key)
+        hazards = yolo_result.get("hazards", []) or []
         hazard_count = len(hazards)
-        
-        # 4. 绘制边界框和水印
+        frontend_result = _build_frontend_result(hazards)
+        current_ts = int(time.time())
+
         result_image_bytes = await _draw_boxes_and_watermark(
             image_bytes,
             hazards,
-            user_name=data.user_name or user.account,
-            user_account=user.account,
+            user_name=user_name,
+            user_account=user_account,
         )
-        
-        # 5. 上传结果图片到 OSS
-        result_filename = f"hazard_detection/{user.userCode}/{int(time.time())}.jpg"
+
+        result_filename = f"hazard_detection/{user_code}/{current_ts}.jpg"
         result_url = await oss_service.upload_bytes(result_image_bytes, result_filename)
-        
-        # 6. 插入 RecognitionRecord
+
+        scene_display_name = SCENE_DISPLAY_NAMES.get(scene_key, scene_key or "隐患提示")
         record = RecognitionRecord(
-            user_id=user.userCode,
-            scene_type=data.scene_type,
-            original_image_url=data.image_url,
+            user_id=user_code,
+            scene_type=scene_key,
+            original_image_url=source_image_url,
             recognition_image_url=result_url,
             hazard_count=hazard_count,
             hazard_details=json.dumps(hazards, ensure_ascii=False),
             current_step=1,
-            created_at=int(time.time()),
-            updated_at=int(time.time()),
-            is_deleted=0
+            title=f"{scene_display_name}隐患提示",
+            description=" ".join(frontend_result["third_scenes"]),
+            labels=",".join(frontend_result["display_labels"]),
+            tag_type=scene_key,
+            created_at=current_ts,
+            updated_at=current_ts,
+            is_deleted=0,
         )
         db.add(record)
         db.commit()
         db.refresh(record)
-        
-        # 7. 返回结果
+
         return {
             "statusCode": 200,
             "msg": "识别成功",
@@ -114,16 +217,25 @@ async def hazard(
                 "record_id": record.id,
                 "hazard_count": hazard_count,
                 "hazards": hazards,
+                "scene_name": scene_key,
+                "annotated_image": result_url,
+                "display_labels": frontend_result["display_labels"],
+                "labels": frontend_result["labels"],
+                "third_scenes": frontend_result["third_scenes"],
+                "element_hazards": frontend_result["element_hazards"],
+                "detections": frontend_result["detections"],
                 "result_image_url": result_url,
-                "original_image_url": data.image_url
-            }
+                "original_image_url": source_image_url,
+            },
         }
-    
+
     except httpx.HTTPError as e:
         logger.error(f"[hazard] 图片下载失败: {e}")
+        db.rollback()
         return {"statusCode": 500, "msg": f"图片下载失败: {str(e)}"}
     except Exception as e:
         logger.error(f"[hazard] 处理异常: {e}")
+        db.rollback()
         return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}
 
 
@@ -131,40 +243,43 @@ async def hazard(
 async def save_step(
     request: Request,
     data: SaveStepRequest,
-    db: Session = Depends(get_db)
+    db: Session = Depends(get_db),
 ):
-    """
-    保存识别步骤
-    更新 RecognitionRecord.current_step
-    """
+    """Update RecognitionRecord.current_step."""
+
     user = request.state.user
     if not user:
         return {"statusCode": 401, "msg": "未授权"}
-    
+
     try:
-        # 更新步骤
-        affected = db.query(RecognitionRecord).filter(
-            RecognitionRecord.id == data.record_id,
-            RecognitionRecord.user_id == user.userCode
-        ).update({
-            "current_step": data.current_step,
-            "updated_at": int(time.time())
-        })
-        
+        affected = (
+            db.query(RecognitionRecord)
+            .filter(
+                RecognitionRecord.id == data.record_id,
+                RecognitionRecord.user_id == _get_user_code(user),
+            )
+            .update(
+                {
+                    "current_step": data.current_step,
+                    "updated_at": int(time.time()),
+                }
+            )
+        )
+
         if affected == 0:
             return {"statusCode": 404, "msg": "记录不存在"}
-        
+
         db.commit()
-        
+
         return {
             "statusCode": 200,
             "msg": "保存成功",
             "data": {
                 "record_id": data.record_id,
-                "current_step": data.current_step
-            }
+                "current_step": data.current_step,
+            },
         }
-    
+
     except Exception as e:
         logger.error(f"[save_step] 异常: {e}")
         db.rollback()
@@ -173,109 +288,93 @@ async def save_step(
 
 async def _draw_boxes_and_watermark(
     image_bytes: bytes,
-    hazards: list,
+    hazards: List[Dict[str, Any]],
     user_name: str,
-    user_account: str
+    user_account: str,
 ) -> bytes:
-    """
-    在图片上绘制边界框和水印(对齐Go版本)
-    
-    功能:
-    1. 绘制检测边界框
-    2. 添加45度角水印(用户名、账号、日期)
-    
-    Args:
-        image_bytes: 原始图片字节
-        hazards: YOLO 检测结果列表,每项包含 bbox, label, confidence
-        user_name: 用户名
-        user_account: 用户账号
-    
-    Returns:
-        处理后的图片字节
-    """
+    """Draw detection boxes and a tiled watermark on the image."""
+
     try:
-        # 打开图片
         image = Image.open(io.BytesIO(image_bytes)).convert("RGBA")
         width, height = image.size
-        
-        # 创建透明图层用于绘制
+
         overlay = Image.new("RGBA", (width, height), (255, 255, 255, 0))
         draw = ImageDraw.Draw(overlay)
-        
-        # 尝试加载字体
+
         try:
-            font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 20)
-            font_small = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14)
-        except:
+            font = ImageFont.truetype(
+                "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 20
+            )
+            font_small = ImageFont.truetype(
+                "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14
+            )
+        except Exception:
             try:
-                # Windows字体路径
                 font = ImageFont.truetype("C:/Windows/Fonts/msyh.ttc", 20)
                 font_small = ImageFont.truetype("C:/Windows/Fonts/msyh.ttc", 14)
-            except:
+            except Exception:
                 font = ImageFont.load_default()
                 font_small = ImageFont.load_default()
-        
-        # 1. 绘制边界框
+
         for hazard in hazards:
-            bbox = hazard.get("bbox", [])
+            bbox = hazard.get("bbox", []) or hazard.get("box", [])
             label = hazard.get("label", "")
             confidence = hazard.get("confidence", 0)
-            
+
             if len(bbox) == 4:
                 x1, y1, x2, y2 = bbox
-                # 绘制矩形框(红色)
                 draw.rectangle([x1, y1, x2, y2], outline=(255, 0, 0, 255), width=3)
-                # 绘制标签
                 text = f"{label} {confidence:.2f}"
-                draw.text((x1, max(0, y1 - 25)), text, fill=(255, 0, 0, 255), font=font)
-        
-        # 2. 添加45度角水印(对齐Go版本)
+                draw.text(
+                    (x1, max(0, y1 - 25)),
+                    text,
+                    fill=(255, 0, 0, 255),
+                    font=font,
+                )
+
         current_date = time.strftime("%Y/%m/%d")
-        watermarks = [user_name, user_account, current_date]
-        
-        # 水印参数
+        watermarks = [user_name or "", user_account or "", current_date]
+        watermarks = [text for text in watermarks if text]
+        if not watermarks:
+            watermarks = [current_date]
+
         text_height_estimate = 50
         text_width_estimate = 150
         angle = 45
-        
-        # 创建水印文本图层
-        watermark_layer = Image.new("RGBA", (width * 2, height * 2), (255, 255, 255, 0))
+
+        watermark_layer = Image.new(
+            "RGBA", (width * 2, height * 2), (255, 255, 255, 0)
+        )
         watermark_draw = ImageDraw.Draw(watermark_layer)
-        
-        # 45度角平铺水印
+
         for y in range(-height, height * 2, text_height_estimate):
             for x in range(-width, width * 2, text_width_estimate):
-                # 计算当前行使用哪个水印文本
                 row_index = int(y / text_height_estimate) % len(watermarks)
-                text = watermarks[row_index]
-                
-                # 使用更深的灰色(对齐Go版本)
                 watermark_draw.text(
                     (x, y),
-                    text,
-                    fill=(128, 128, 128, 60),  # 半透明灰色
-                    font=font_small
+                    watermarks[row_index],
+                    fill=(128, 128, 128, 60),
+                    font=font_small,
                 )
-        
-        # 旋转水印层
-        watermark_layer = watermark_layer.rotate(angle, expand=False, fillcolor=(255, 255, 255, 0))
-        
-        # 裁剪到原始尺寸
+
+        watermark_layer = watermark_layer.rotate(
+            angle, expand=False, fillcolor=(255, 255, 255, 0)
+        )
+
         crop_x = (watermark_layer.width - width) // 2
         crop_y = (watermark_layer.height - height) // 2
-        watermark_layer = watermark_layer.crop((crop_x, crop_y, crop_x + width, crop_y + height))
-        
-        # 合并图层
+        watermark_layer = watermark_layer.crop(
+            (crop_x, crop_y, crop_x + width, crop_y + height)
+        )
+
         image = Image.alpha_composite(image, watermark_layer)
         image = Image.alpha_composite(image, overlay)
-        
-        # 转换为RGB并保存
+
         final_image = image.convert("RGB")
         output = io.BytesIO()
         final_image.save(output, format="JPEG", quality=95)
         return output.getvalue()
-    
+
     except Exception as e:
         logger.error(f"[_draw_boxes_and_watermark] 图片处理失败: {e}")
-        # 如果处理失败,返回原图
         return image_bytes

+ 319 - 184
shudao-chat-py/routers/scene.py

@@ -1,86 +1,211 @@
+import json
+import time
+from typing import Any, Optional
+
 from fastapi import APIRouter, Depends, Request
-from sqlalchemy.orm import Session
 from pydantic import BaseModel
-from typing import Optional
+from sqlalchemy.orm import Session
+
 from database import get_db
-from models.scene import Scene, FirstScene, SecondScene, ThirdScene, RecognitionRecord, SceneTemplate
-import time
+from models.scene import (
+    FirstScene,
+    RecognitionRecord,
+    Scene,
+    SceneTemplate,
+    SecondScene,
+    ThirdScene,
+)
 
 router = APIRouter()
 
 
+def _get_user_code(user: Any) -> str:
+    return (
+        getattr(user, "userCode", None)
+        or getattr(user, "user_code", None)
+        or getattr(user, "account", "")
+    )
+
+
+def _load_hazard_details(record: RecognitionRecord):
+    if not record.hazard_details:
+        return []
+    try:
+        data = json.loads(record.hazard_details)
+        return data if isinstance(data, list) else []
+    except Exception:
+        return []
+
+
+def _split_labels(labels):
+    if not labels:
+        return []
+    if isinstance(labels, list):
+        return [str(item).strip() for item in labels if str(item).strip()]
+    return [
+        item.strip()
+        for item in str(labels).replace(",", ",").split(",")
+        if item.strip()
+    ]
+
+
+def _unique_ordered(items):
+    seen = set()
+    ordered = []
+    for item in items:
+        if not item or item in seen:
+            continue
+        seen.add(item)
+        ordered.append(item)
+    return ordered
+
+
+def _build_record_view(record: RecognitionRecord):
+    hazard_details = _load_hazard_details(record)
+    derived_labels = _unique_ordered(
+        [
+            str(item.get("label") or "").strip()
+            for item in hazard_details
+            if str(item.get("label") or "").strip()
+        ]
+    )
+    display_labels = _split_labels(record.labels) or derived_labels
+
+    if record.description:
+        third_scenes = [item for item in str(record.description).split(" ") if item]
+    else:
+        third_scenes = derived_labels
+
+    detections = [
+        {
+            "label": item.get("label", ""),
+            "box": item.get("bbox") or item.get("box") or [],
+            "bbox": item.get("bbox") or item.get("box") or [],
+            "confidence": item.get("confidence", 0),
+        }
+        for item in hazard_details
+    ]
+
+    return {
+        "id": record.id,
+        "title": record.title or "隐患提示记录",
+        "description": record.description or " ".join(third_scenes),
+        "original_image_url": record.original_image_url,
+        "recognition_image_url": record.recognition_image_url,
+        "labels": record.labels or ",".join(display_labels),
+        "display_labels": display_labels,
+        "third_scenes": third_scenes,
+        "tag_type": record.tag_type or record.scene_type,
+        "scene_type": record.scene_type,
+        "effect_evaluation": record.effect_evaluation,
+        "hazard_details": hazard_details,
+        "detections": detections,
+    }
+
+
+def _resolve_record_id(
+    recognition_id: Optional[int] = None,
+    recognition_record_id: Optional[int] = None,
+):
+    return recognition_id or recognition_record_id
+
+
 @router.get("/get_scene_list")
 async def get_scene_list(db: Session = Depends(get_db)):
-    """获取场景列表"""
     scenes = db.query(Scene).filter(Scene.is_deleted == 0).all()
     return {
         "statusCode": 200,
         "msg": "success",
-        "data": [{"id": s.id, "scene_name": s.scene_name, "scene_en_name": s.scene_en_name} for s in scenes]
+        "data": [
+            {
+                "id": s.id,
+                "scene_name": s.scene_name,
+                "scene_en_name": s.scene_en_name,
+            }
+            for s in scenes
+        ],
     }
 
 
 @router.get("/get_first_scene_list")
 async def get_first_scene_list(scene_id: int, db: Session = Depends(get_db)):
-    """获取一级场景列表"""
-    scenes = db.query(FirstScene).filter(
-        FirstScene.scene_id == scene_id,
-        FirstScene.is_deleted == 0
-    ).all()
+    scenes = (
+        db.query(FirstScene)
+        .filter(FirstScene.scene_id == scene_id, FirstScene.is_deleted == 0)
+        .all()
+    )
     return {
         "statusCode": 200,
         "msg": "success",
-        "data": [{"id": s.id, "first_scene_name": s.first_scene_name} for s in scenes]
+        "data": [{"id": s.id, "first_scene_name": s.first_scene_name} for s in scenes],
     }
 
 
 @router.get("/get_second_scene_list")
-async def get_second_scene_list(first_scene_id: int, db: Session = Depends(get_db)):
-    """获取二级场景列表"""
-    scenes = db.query(SecondScene).filter(
-        SecondScene.first_scene_id == first_scene_id,
-        SecondScene.is_deleted == 0
-    ).all()
+async def get_second_scene_list(
+    first_scene_id: int, db: Session = Depends(get_db)
+):
+    scenes = (
+        db.query(SecondScene)
+        .filter(
+            SecondScene.first_scene_id == first_scene_id,
+            SecondScene.is_deleted == 0,
+        )
+        .all()
+    )
     return {
         "statusCode": 200,
         "msg": "success",
-        "data": [{"id": s.id, "second_scene_name": s.second_scene_name} for s in scenes]
+        "data": [{"id": s.id, "second_scene_name": s.second_scene_name} for s in scenes],
     }
 
 
 @router.get("/get_third_scene_list")
-async def get_third_scene_list(second_scene_id: int, db: Session = Depends(get_db)):
-    """获取三级场景列表"""
-    scenes = db.query(ThirdScene).filter(
-        ThirdScene.second_scene_id == second_scene_id,
-        ThirdScene.is_deleted == 0
-    ).all()
+async def get_third_scene_list(
+    second_scene_id: int, db: Session = Depends(get_db)
+):
+    scenes = (
+        db.query(ThirdScene)
+        .filter(
+            ThirdScene.second_scene_id == second_scene_id,
+            ThirdScene.is_deleted == 0,
+        )
+        .all()
+    )
     return {
         "statusCode": 200,
         "msg": "success",
-        "data": [{
-            "id": s.id,
-            "third_scene_name": s.third_scene_name,
-            "correct_example_image": s.correct_example_image,
-            "wrong_example_image": s.wrong_example_image
-        } for s in scenes]
+        "data": [
+            {
+                "id": s.id,
+                "third_scene_name": s.third_scene_name,
+                "correct_example_image": s.correct_example_image,
+                "wrong_example_image": s.wrong_example_image,
+            }
+            for s in scenes
+        ],
     }
 
 
 @router.get("/get_third_scene_example_image")
-async def get_third_scene_example_image(third_scene_name: str, db: Session = Depends(get_db)):
-    """获取三级场景示例图"""
+async def get_third_scene_example_image(
+    third_scene_name: str, db: Session = Depends(get_db)
+):
     if not third_scene_name:
         return {"statusCode": 400, "msg": "三级场景名称不能为空"}
-    
-    scene = db.query(ThirdScene).filter(
-        ThirdScene.third_scene_name == third_scene_name,
-        ThirdScene.is_deleted == 0
-    ).first()
-    
+
+    scene = (
+        db.query(ThirdScene)
+        .filter(
+            ThirdScene.third_scene_name == third_scene_name,
+            ThirdScene.is_deleted == 0,
+        )
+        .first()
+    )
+
     if not scene:
         return {"statusCode": 404, "msg": "三级场景不存在"}
-    
+
     return {
         "statusCode": 200,
         "msg": "success",
@@ -88,103 +213,122 @@ async def get_third_scene_example_image(third_scene_name: str, db: Session = Dep
             "id": scene.id,
             "third_scene_name": scene.third_scene_name,
             "correct_example_image": scene.correct_example_image,
-            "wrong_example_image": scene.wrong_example_image
-        }
+            "wrong_example_image": scene.wrong_example_image,
+        },
     }
 
 
 @router.get("/get_history_recognition_record")
-async def get_history_recognition_record(request: Request, db: Session = Depends(get_db)):
-    """获取隐患识别历史记录"""
+async def get_history_recognition_record(
+    request: Request, db: Session = Depends(get_db)
+):
     user = request.state.user
     if not user:
         return {"statusCode": 401, "msg": "未授权"}
-    
-    # 获取所有记录(不限制数量)
-    records = db.query(RecognitionRecord).filter(
-        RecognitionRecord.user_id == user.userCode,
-        RecognitionRecord.is_deleted == 0
-    ).order_by(RecognitionRecord.updated_at.desc()).all()
-    
-    # 计算总数
-    total = db.query(RecognitionRecord).filter(
-        RecognitionRecord.user_id == user.userCode,
-        RecognitionRecord.is_deleted == 0
-    ).count()
-    
+
+    user_code = _get_user_code(user)
+    records = (
+        db.query(RecognitionRecord)
+        .filter(RecognitionRecord.user_id == user_code, RecognitionRecord.is_deleted == 0)
+        .order_by(RecognitionRecord.updated_at.desc())
+        .all()
+    )
+
+    total = (
+        db.query(RecognitionRecord)
+        .filter(RecognitionRecord.user_id == user_code, RecognitionRecord.is_deleted == 0)
+        .count()
+    )
+
     return {
         "statusCode": 200,
         "msg": "success",
-        "data": [{
-            "id": r.id,
-            "title": r.title,
-            "original_image_url": r.original_image_url,
-            "recognition_image_url": r.recognition_image_url,
-            "labels": r.labels,
-            "created_at": r.created_at
-        } for r in records],
-        "total": total
+        "data": [
+            {
+                **_build_record_view(record),
+                "created_at": record.created_at,
+            }
+            for record in records
+        ],
+        "total": total,
     }
 
 
 @router.get("/get_recognition_record_detail")
-async def get_recognition_record_detail(recognition_id: int, db: Session = Depends(get_db)):
-    """获取识别记录详情"""
-    record = db.query(RecognitionRecord).filter(
-        RecognitionRecord.id == recognition_id,
-        RecognitionRecord.is_deleted == 0
-    ).first()
+async def get_recognition_record_detail(
+    recognition_id: Optional[int] = None,
+    recognition_record_id: Optional[int] = None,
+    db: Session = Depends(get_db),
+):
+    record_id = _resolve_record_id(recognition_id, recognition_record_id)
+    if not record_id:
+        return {"statusCode": 422, "msg": "recognition_id 不能为空"}
+
+    record = (
+        db.query(RecognitionRecord)
+        .filter(RecognitionRecord.id == record_id, RecognitionRecord.is_deleted == 0)
+        .first()
+    )
     if not record:
         return {"statusCode": 404, "msg": "记录不存在"}
-    
-    # 将 Description 字符串转换为数组
-    third_scenes = []
-    if record.description:
-        third_scenes = record.description.split(" ")
-    
+
+    record_view = _build_record_view(record)
     return {
         "statusCode": 200,
         "msg": "success",
         "data": {
             "id": record.id,
             "user_id": record.user_id,
-            "title": record.title,
-            "description": record.description,
+            "title": record_view["title"],
+            "description": record_view["description"],
             "original_image_url": record.original_image_url,
             "recognition_image_url": record.recognition_image_url,
-            "labels": record.labels,
-            "third_scenes": third_scenes,
-            "tag_type": record.tag_type,
+            "labels": record_view["labels"],
+            "display_labels": record_view["display_labels"],
+            "third_scenes": record_view["third_scenes"],
+            "tag_type": record_view["tag_type"],
+            "scene_type": record.scene_type,
             "scene_match": record.scene_match,
             "tip_accuracy": record.tip_accuracy,
             "effect_evaluation": record.effect_evaluation,
             "user_remark": record.user_remark,
+            "hazard_details": record_view["hazard_details"],
+            "detections": record_view["detections"],
             "created_at": record.created_at,
-            "updated_at": record.updated_at
-        }
+            "updated_at": record.updated_at,
+        },
     }
 
 
 class DeleteRecognitionRequest(BaseModel):
-    recognition_id: int
+    recognition_id: Optional[int] = None
+    recognition_record_id: Optional[int] = None
 
 
 @router.post("/delete_recognition_record")
-async def delete_recognition_record(data: DeleteRecognitionRequest, request: Request, db: Session = Depends(get_db)):
-    """删除识别记录(软删除)"""
+async def delete_recognition_record(
+    data: DeleteRecognitionRequest,
+    request: Request,
+    db: Session = Depends(get_db),
+):
     user = request.state.user
     if not user:
         return {"statusCode": 401, "msg": "未授权"}
-    
-    db.query(RecognitionRecord).filter(
-        RecognitionRecord.id == data.recognition_id,
-        RecognitionRecord.user_id == user.userCode
-    ).update({
-        "is_deleted": 1,
-        "deleted_at": int(time.time())
-    })
+
+    record_id = _resolve_record_id(data.recognition_id, data.recognition_record_id)
+    if not record_id:
+        return {"statusCode": 422, "msg": "recognition_id 不能为空"}
+
+    (
+        db.query(RecognitionRecord)
+        .filter(
+            RecognitionRecord.id == record_id,
+            RecognitionRecord.user_id == _get_user_code(user),
+        )
+        .update({"is_deleted": 1, "deleted_at": int(time.time())})
+    )
     db.commit()
-    
+
     return {"statusCode": 200, "msg": "删除成功"}
 
 
@@ -198,16 +342,15 @@ class EvaluationRequest(BaseModel):
 
 @router.post("/submit_evaluation")
 async def submit_evaluation(data: EvaluationRequest, db: Session = Depends(get_db)):
-    """提交点评"""
-    record = db.query(RecognitionRecord).filter(
-        RecognitionRecord.id == data.id,
-        RecognitionRecord.is_deleted == 0
-    ).first()
-    
+    record = (
+        db.query(RecognitionRecord)
+        .filter(RecognitionRecord.id == data.id, RecognitionRecord.is_deleted == 0)
+        .first()
+    )
+
     if not record:
         return {"statusCode": 404, "msg": "记录不存在"}
-    
-    # 更新评价字段
+
     if data.scene_match is not None:
         record.scene_match = data.scene_match
     if data.tip_accuracy is not None:
@@ -216,35 +359,38 @@ async def submit_evaluation(data: EvaluationRequest, db: Session = Depends(get_d
         record.effect_evaluation = data.effect_evaluation
     if data.user_remark is not None:
         record.user_remark = data.user_remark
-    
+
     record.updated_at = int(time.time())
     db.commit()
-    
+
     return {"statusCode": 200, "msg": "success"}
 
 
 @router.get("/get_latest_recognition_record")
-async def get_latest_recognition_record(request: Request, db: Session = Depends(get_db)):
-    """获取最新识别记录"""
+async def get_latest_recognition_record(
+    request: Request, db: Session = Depends(get_db)
+):
     user = request.state.user
     if not user:
         return {"statusCode": 401, "msg": "未授权"}
-    
-    record = db.query(RecognitionRecord).filter(
-        RecognitionRecord.user_id == user.userCode,
-        RecognitionRecord.is_deleted == 0
-    ).order_by(RecognitionRecord.created_at.desc()).first()
-    
-    # 如果数据为空,则构建一个假数据 effect_evaluation=1 给前端
+
+    record = (
+        db.query(RecognitionRecord)
+        .filter(
+            RecognitionRecord.user_id == _get_user_code(user),
+            RecognitionRecord.is_deleted == 0,
+        )
+        .order_by(RecognitionRecord.created_at.desc())
+        .first()
+    )
+
     if not record:
         return {
             "statusCode": 200,
             "msg": "success",
-            "data": {
-                "effect_evaluation": 1
-            }
+            "data": {"effect_evaluation": 1},
         }
-    
+
     return {
         "statusCode": 200,
         "msg": "success",
@@ -255,17 +401,12 @@ async def get_latest_recognition_record(request: Request, db: Session = Depends(
             "recognition_image_url": record.recognition_image_url,
             "labels": record.labels,
             "created_at": record.created_at,
-            "effect_evaluation": record.effect_evaluation
-        }
+            "effect_evaluation": record.effect_evaluation,
+        },
     }
 
 
-# ============================================================
-# 场景模板接口(对齐Go版本)
-# ============================================================
-
 class SceneTemplateCreate(BaseModel):
-    """创建场景模板请求"""
     scene_name: str
     scene_type: str
     scene_desc: str = ""
@@ -273,8 +414,9 @@ class SceneTemplateCreate(BaseModel):
 
 
 @router.post("/scene_template")
-async def create_scene_template(data: SceneTemplateCreate, db: Session = Depends(get_db)):
-    """创建场景模板"""
+async def create_scene_template(
+    data: SceneTemplateCreate, db: Session = Depends(get_db)
+):
     template = SceneTemplate(
         scene_name=data.scene_name,
         scene_type=data.scene_type,
@@ -282,16 +424,16 @@ async def create_scene_template(data: SceneTemplateCreate, db: Session = Depends
         model_name=data.model_name,
         created_at=int(time.time()),
         updated_at=int(time.time()),
-        is_deleted=0
+        is_deleted=0,
     )
     db.add(template)
     db.commit()
     db.refresh(template)
-    
+
     return {
         "statusCode": 200,
         "msg": "创建成功",
-        "data": {"id": template.id}
+        "data": {"id": template.id},
     }
 
 
@@ -299,25 +441,22 @@ async def create_scene_template(data: SceneTemplateCreate, db: Session = Depends
 async def get_scene_templates(
     page: int = 1,
     page_size: int = 20,
-    db: Session = Depends(get_db)
+    db: Session = Depends(get_db),
 ):
-    """获取场景模板列表(分页)"""
-    # 限制page_size最大值
     if page_size > 100:
         page_size = 100
-    
+
     offset = (page - 1) * page_size
-    
-    # 查询总数
-    total = db.query(SceneTemplate).filter(
-        SceneTemplate.is_deleted == 0
-    ).count()
-    
-    # 查询列表
-    templates = db.query(SceneTemplate).filter(
-        SceneTemplate.is_deleted == 0
-    ).order_by(SceneTemplate.created_at.desc()).offset(offset).limit(page_size).all()
-    
+    total = db.query(SceneTemplate).filter(SceneTemplate.is_deleted == 0).count()
+    templates = (
+        db.query(SceneTemplate)
+        .filter(SceneTemplate.is_deleted == 0)
+        .order_by(SceneTemplate.created_at.desc())
+        .offset(offset)
+        .limit(page_size)
+        .all()
+    )
+
     return {
         "statusCode": 200,
         "msg": "success",
@@ -325,16 +464,16 @@ async def get_scene_templates(
             "total": total,
             "items": [
                 {
-                    "id": t.id,
-                    "scene_name": t.scene_name,
-                    "scene_type": t.scene_type,
-                    "scene_desc": t.scene_desc,
-                    "model_name": t.model_name,
-                    "created_at": t.created_at
+                    "id": template.id,
+                    "scene_name": template.scene_name,
+                    "scene_type": template.scene_type,
+                    "scene_desc": template.scene_desc,
+                    "model_name": template.model_name,
+                    "created_at": template.created_at,
                 }
-                for t in templates
-            ]
-        }
+                for template in templates
+            ],
+        },
     }
 
 
@@ -344,36 +483,32 @@ async def get_recognition_records(
     scene_type: str = "",
     page: int = 1,
     page_size: int = 20,
-    db: Session = Depends(get_db)
+    db: Session = Depends(get_db),
 ):
-    """获取识别记录列表(分页+筛选)- 符合REST规范"""
     user = request.state.user
     if not user:
         return {"statusCode": 401, "msg": "未授权"}
-    
-    # 限制page_size最大值
+
     if page_size > 100:
         page_size = 100
-    
-    # 构建查询条件
+
     query = db.query(RecognitionRecord).filter(
-        RecognitionRecord.user_id == user.userCode,
-        RecognitionRecord.is_deleted == 0
+        RecognitionRecord.user_id == _get_user_code(user),
+        RecognitionRecord.is_deleted == 0,
     )
-    
-    # 场景类型筛选
+
     if scene_type:
         query = query.filter(RecognitionRecord.scene_type == scene_type)
-    
-    # 查询总数
+
     total = query.count()
-    
-    # 分页查询
     offset = (page - 1) * page_size
-    records = query.order_by(
-        RecognitionRecord.created_at.desc()
-    ).offset(offset).limit(page_size).all()
-    
+    records = (
+        query.order_by(RecognitionRecord.created_at.desc())
+        .offset(offset)
+        .limit(page_size)
+        .all()
+    )
+
     return {
         "statusCode": 200,
         "msg": "success",
@@ -381,15 +516,15 @@ async def get_recognition_records(
             "total": total,
             "items": [
                 {
-                    "id": r.id,
-                    "scene_type": r.scene_type,
-                    "original_image_url": r.original_image_url,
-                    "result_image_url": r.recognition_image_url,
-                    "hazard_count": r.hazard_count,
-                    "current_step": r.current_step,
-                    "created_at": r.created_at
+                    "id": record.id,
+                    "scene_type": record.scene_type,
+                    "original_image_url": record.original_image_url,
+                    "result_image_url": record.recognition_image_url,
+                    "hazard_count": record.hazard_count,
+                    "current_step": record.current_step,
+                    "created_at": record.created_at,
                 }
-                for r in records
-            ]
-        }
+                for record in records
+            ],
+        },
     }

+ 23 - 6
shudao-chat-py/routers/total.py

@@ -67,9 +67,16 @@ async def get_policy_file(
 
 
 @router.get("/get_function_card")
-async def get_function_card(db: Session = Depends(get_db)):
+async def get_function_card(
+    function_type: Optional[int] = None,
+    db: Session = Depends(get_db)
+):
     """获取功能卡片"""
-    cards = db.query(FunctionCard).limit(4).all()
+    query = db.query(FunctionCard).filter(FunctionCard.is_deleted == 0)
+    if function_type is not None:
+        query = query.filter(FunctionCard.function_type == function_type)
+
+    cards = query.order_by(FunctionCard.id.asc()).limit(4).all()
     return {
         "statusCode": 200,
         "msg": "success",
@@ -87,10 +94,19 @@ async def get_function_card(db: Session = Depends(get_db)):
 
 
 @router.get("/get_hot_question")
-async def get_hot_question(db: Session = Depends(get_db)):
+async def get_hot_question(
+    question_type: Optional[int] = None,
+    db: Session = Depends(get_db)
+):
     """获取热点问题(按点击量排序)"""
-    questions = db.query(HotQuestion).order_by(
-        HotQuestion.click_count.desc()).limit(3).all()
+    query = db.query(HotQuestion).filter(HotQuestion.is_deleted == 0)
+    if question_type is not None:
+        query = query.filter(HotQuestion.question_type == question_type)
+
+    questions = query.order_by(
+        HotQuestion.click_count.desc(),
+        HotQuestion.id.asc()
+    ).limit(3).all()
     return {
         "statusCode": 200,
         "msg": "success",
@@ -98,7 +114,8 @@ async def get_hot_question(db: Session = Depends(get_db)):
             {
                 "id": q.id,
                 "question": q.question,
-                "click_count": q.click_count or 0
+                "click_count": q.click_count or 0,
+                "question_type": q.question_type
             }
             for q in questions
         ]

+ 1 - 1
shudao-chat-py/services/aichat_proxy.py

@@ -80,7 +80,7 @@ class AIChatProxy:
                 yield b"data: {\"type\": \"completed\"}\n\n"
             except Exception as e:
                 logger.error(f"[AIChat代理] SSE 请求异常: {e}")
-                yield f'data: {{"type": "online_error", "message": "AIChat服务请求超时"}}\n\n'.encode('utf-8')
+                error_msg = f'data: {{"type": "online_error", "message": "AIChat服务请求异常"}}\n\n'
                 yield error_msg.encode('utf-8')
                 yield b"data: {\"type\": \"completed\"}\n\n"
 

+ 1 - 1
shudao-chat-py/services/deepseek_service.py

@@ -28,7 +28,7 @@ class DeepSeekService:
         }
         
         try:
-            async with httpx.AsyncClient(timeout=60.0) as client:
+            async with httpx.AsyncClient(timeout=120.0) as client:
                 response = await client.post(
                     f"{self.base_url}/chat/completions",
                     headers=headers,

+ 37 - 3
shudao-chat-py/services/qwen_service.py

@@ -7,6 +7,7 @@ from typing import AsyncGenerator
 from utils.config import settings
 from utils.logger import logger
 from utils.prompt_loader import load_prompt
+from services.deepseek_service import deepseek_service
 
 
 class QwenService:
@@ -20,6 +21,18 @@ class QwenService:
         intent_base_url = settings.intent.api_url.rstrip('/')
         self.intent_api_url = f"{intent_base_url}/v1/chat/completions"
         self.intent_model = settings.intent.model
+
+    def _should_fallback(self, status_code: int) -> bool:
+        return status_code in (429, 500, 502, 503, 504)
+
+    async def _fallback_deepseek(self, messages: list) -> str:
+        try:
+            logger.warning("[Qwen API] Falling back to DeepSeek due to upstream error")
+            return await deepseek_service.chat(messages)
+        except Exception as e:
+            error_msg = str(e).strip() if str(e).strip() else type(e).__name__
+            logger.error(f"[Qwen API] DeepSeek fallback failed: {type(e).__name__}: {error_msg}")
+            raise RuntimeError(f"AI服务暂时不可用,主模型和备用模型均无法响应({type(e).__name__}),请稍后重试") from e
     
     async def extract_keywords(self, question: str) -> str:
         """从问题中提炼搜索关键词"""
@@ -103,6 +116,8 @@ class QwenService:
         
         # 使用指定的API URL,默认使用qwen3的URL
         target_url = api_url or self.api_url
+        normalized_target = target_url.rstrip("/")
+        is_qwen3_target = normalized_target == self.api_url.rstrip("/")
         
         # 详细请求日志
         logger.info(f"[Qwen API] 请求 URL: {target_url}")
@@ -116,16 +131,16 @@ class QwenService:
             }
             
             # 如果配置中有 token,添加到请求头(兼容需要认证的场景)
-            if hasattr(settings, 'intent') and hasattr(settings.intent, 'token') and api_url == self.intent_api_url:
+            if hasattr(settings, 'intent') and hasattr(settings.intent, 'token') and normalized_target == self.intent_api_url.rstrip("/"):
                 if settings.intent.token:
                     headers["Authorization"] = f"Bearer {settings.intent.token}"
                     logger.info("[Qwen API] 已添加 Intent API Authorization header")
-            elif hasattr(settings, 'qwen3') and hasattr(settings.qwen3, 'token') and api_url == self.api_url:
+            elif hasattr(settings, 'qwen3') and hasattr(settings.qwen3, 'token') and normalized_target == self.api_url.rstrip("/"):
                 if settings.qwen3.token:
                     headers["Authorization"] = f"Bearer {settings.qwen3.token}"
                     logger.info("[Qwen API] 已添加 Qwen3 API Authorization header")
             
-            async with httpx.AsyncClient(timeout=60.0) as client:
+            async with httpx.AsyncClient(timeout=120.0) as client:
                 response = await client.post(
                     target_url,
                     json=data,
@@ -179,9 +194,13 @@ class QwenService:
         except httpx.HTTPStatusError as e:
             logger.error(f"[Qwen API] HTTP 错误 - 状态码: {e.response.status_code}, URL: {target_url}")
             logger.error(f"[Qwen API] HTTP 错误响应: {e.response.text[:500]}")
+            if is_qwen3_target and self._should_fallback(e.response.status_code):
+                return await self._fallback_deepseek(messages)
             raise
         except httpx.RequestError as e:
             logger.error(f"[Qwen API] 请求错误 - URL: {target_url}, 错误: {type(e).__name__}: {str(e)}")
+            if is_qwen3_target:
+                return await self._fallback_deepseek(messages)
             raise
         except Exception as e:
             logger.error(f"[Qwen API] 未知错误 - URL: {target_url}, 模型: {data['model']}, 错误: {type(e).__name__}: {str(e)}")
@@ -220,6 +239,21 @@ class QwenService:
                                     yield content
                             except json.JSONDecodeError:
                                 continue
+        except httpx.HTTPStatusError as e:
+            status_code = e.response.status_code if e.response else 0
+            logger.error(f"Qwen stream HTTP error: {status_code}")
+            if self._should_fallback(status_code):
+                logger.warning("[Qwen API] Stream fallback to DeepSeek")
+                async for chunk in deepseek_service.stream_chat(messages):
+                    yield chunk
+                return
+            raise
+        except httpx.RequestError as e:
+            logger.error(f"Qwen stream request error: {type(e).__name__}: {e}")
+            logger.warning("[Qwen API] Stream fallback to DeepSeek")
+            async for chunk in deepseek_service.stream_chat(messages):
+                yield chunk
+            return
         except Exception as e:
             logger.error(f"Qwen 流式 API 调用失败: {e}")
             raise

+ 2 - 2
shudao-vue-frontend/src/components/CategoryTitle.vue

@@ -32,7 +32,7 @@ const props = defineProps({
 })
 
 const emit = defineEmits(['toggle'])
-const isExpanded = ref(true)
+const isExpanded = ref(false)
 
 const toggleExpand = () => {
   isExpanded.value = !isExpanded.value
@@ -40,7 +40,7 @@ const toggleExpand = () => {
 }
 
 onMounted(() => {
-  emit('toggle', { category: props.category, expanded: true })
+  emit('toggle', { category: props.category, expanded: false })
 })
 </script>
 

Разница между файлами не показана из-за своего большого размера
+ 639 - 95
shudao-vue-frontend/src/views/Chat.vue


+ 128 - 165
shudao-vue-frontend/src/views/ExamWorkshop.vue

@@ -69,13 +69,14 @@
         </div>
         <!-- 考试工坊主界面 -->
         <div v-if="!showExamDetail" class="exam-workshop-card app-container">
-            <!-- 中间主操作区 -->
-            <main class="main-content" style="padding-top: 36px;">
+                <!-- 中间主操作区 -->
+            <main class="main-content" style="padding-top: 36px; position: relative;">
+                <!-- 返回AI问答按钮 -->
+                <button v-if="!showExamDetail" class="return-ai-btn has-before" @click="handleReturnToAI">
+                  返回AI问答
+                </button>
+                
                 <div class="form-group" style="position: relative;">
-                    <!-- 返回AI问答按钮 -->
-                    <button v-if="hideSidebar && !showExamDetail" class="return-ai-btn" @click="handleReturnToAI">
-                      返回AI问答
-                    </button>
                     <label class="form-label">试卷名称</label>
                     <input type="text" class="form-control" v-model="examName" maxlength="32" placeholder="请输入试卷名称..." :disabled="isGenerating">
                     <div class="char-count">{{ examName?.length || 0 }}/32</div>
@@ -83,23 +84,27 @@
 
                 <div class="form-group">
                     <label class="form-label">出题依据内容</label>
-                    <textarea class="form-control" v-model="questionBasis" placeholder="在此输入知识点、章节或培训内容..." :disabled="isGenerating || selectedFile"></textarea>
+                    <textarea class="form-control" v-model="questionBasis" placeholder="在此输入知识点、章节或培训内容..." :disabled="isGenerating || uploadedFiles.length > 0"></textarea>
                     
-                    <div class="ppt-upload-section" @click="!isGenerating && !selectedFile ? triggerFileUpload() : null">
-                        <div class="ppt-upload-content">
-                            <div class="ppt-upload-icon-wrapper">
-                                <el-icon style="font-size: 28px; color: #4b5563;"><UploadFilled /></el-icon>
-                            </div>
-                            <div class="ppt-upload-text-wrapper">
-                                <div class="ppt-upload-title">从PPT生成考题</div>
-                                <div class="ppt-upload-hint">上传培训PPT,智能提取关键内容生成考题(单个文件可上传20M内)</div>
+                    <div class="ppt-upload-section" style="flex-direction: column; align-items: flex-start;" @click="!isGenerating ? triggerFileUpload() : null">
+                        <div style="display: flex; width: 100%; justify-content: space-between; align-items: center;">
+                            <div class="ppt-upload-content">
+                                <div class="ppt-upload-icon-wrapper">
+                                    <el-icon style="font-size: 28px; color: #4b5563;"><UploadFilled /></el-icon>
+                                </div>
+                                <div class="ppt-upload-text-wrapper">
+                                    <div class="ppt-upload-title">从PPT生成考题</div>
+                                    <div class="ppt-upload-hint">上传培训PPT,智能提取关键内容生成考题(支持多文件,单文件20M内)</div>
+                                </div>
                             </div>
+                            <el-icon class="ppt-arrow"><ArrowRight /></el-icon>
                         </div>
-                        <el-icon class="ppt-arrow"><ArrowRight /></el-icon>
                         
-                        <div v-if="selectedFile" class="file-status-badge" @click.stop>
-                          <span class="file-name truncate">已上传: {{ selectedFile.name }}</span>
-                          <span @click.stop="removeSelectedFile" class="remove-btn">×</span>
+                        <div v-if="uploadedFiles.length > 0" class="files-list" @click.stop style="width: 100%; display: flex; flex-wrap: wrap; gap: 8px;">
+                          <div v-for="(file, index) in uploadedFiles" :key="index" class="file-status-badge">
+                            <span class="file-name truncate">已上传: {{ file.name }}</span>
+                            <span @click.stop="removeSelectedFile(index)" class="remove-btn">×</span>
+                          </div>
                         </div>
                     </div>
                 </div>
@@ -209,12 +214,11 @@
           <!-- 详情页头部 -->
           <div class="detail-header">
             <div class="header-left">
-              <button class="back-btn" @click="backToConfig" :disabled="isGenerating">
-                <span class="back-arrow">←</span>
+            </div>
+            <div class="header-right" style="display: flex; align-items: center; gap: 12px;">
+              <button class="return-ai-btn has-before" style="position: static;" @click="backToConfig" :disabled="isGenerating">
                 返回修改
               </button>
-            </div>
-            <div class="header-right">
               <!-- <button class="save-btn" @click="saveExam" :disabled="isGenerating">
                 <img :src="saveIcon" alt="保存试卷" class="save-icon" />
               </button> -->
@@ -565,6 +569,7 @@
       ref="fileInput"
       type="file"
       accept=".ppt,.pptx"
+      multiple
       style="display: none"
       @change="handleFileSelect"
     />
@@ -583,6 +588,7 @@
 
 <script setup>
 import { ref, computed, onMounted, onUnmounted, reactive, watch, defineProps, defineEmits } from "vue";
+import { useRoute, useRouter } from "vue-router";
 import Sidebar from "@/components/Sidebar.vue";
 import DeleteConfirmModal from "@/components/DeleteConfirmModal.vue";
 import { UploadFilled, ArrowRight, Delete, MagicStick, Loading } from '@element-plus/icons-vue';
@@ -596,6 +602,9 @@ const props = defineProps({
 
 const emit = defineEmits(['return-to-ai']);
 
+const route = useRoute();
+const router = useRouter();
+
 const handleReturnToAI = () => {
   emit('return-to-ai');
 };
@@ -655,7 +664,7 @@ const editModalData = ref({
 
 // PPT文件上传相关
 const fileInput = ref(null);
-const selectedFile = ref(null);
+const uploadedFiles = ref([]);
 const isUploadingFile = ref(false);
 const fileContent = ref(''); // 存储文件内容
 const pptContentDescription = ref(''); // 存储用户输入的PPT内容描述
@@ -693,22 +702,6 @@ const currentExam = ref(null);
 const historyData = ref([])
 const historyTotal = ref(0) // 历史记录总数
 
-
-const isExamWorkshopConversation = (conversation = {}) => {
-  const content = String(conversation.content || '')
-  const title = String(conversation.title || '')
-  const examName = String(conversation.exam_name || '')
-
-  return (
-    Number(conversation.business_type) === 3 ||
-    !!examName.trim() ||
-    title.includes('技术考核') ||
-    content.includes('请根据以下要求直接生成一份完整试卷') ||
-    content.includes('"singleChoice"') ||
-    content.includes('"totalQuestions"')
-  )
-}
-
 // 获取历史记录列表
 const getHistoryRecordList = async () => {
   try {
@@ -717,9 +710,8 @@ const getHistoryRecordList = async () => {
     const startTime = performance.now()
     
     const response = await apis.getHistoryRecord({ 
-      // ===== 已删除:user_id - 后端从token解析 =====
-      ai_conversation_id: 0, // 0表示获取对话列表
-      business_type: 3 // 考试工坊类型
+      ai_conversation_id: 0,
+      business_type: 3
     })
     
     const endTime = performance.now()
@@ -727,46 +719,19 @@ const getHistoryRecordList = async () => {
     console.log('📋 考试工坊历史记录列表响应:', response)
     
     if (response.statusCode === 200) {
-      const directConversations = Array.isArray(response.data) ? response.data : []
-      let conversations = [...directConversations]
-
-      const fallbackResponse = await apis.getHistoryRecord({
-        ai_conversation_id: 0
-      })
-
-      if (fallbackResponse.statusCode === 200 && Array.isArray(fallbackResponse.data)) {
-        const inferredExamConversations = fallbackResponse.data.filter(isExamWorkshopConversation)
-        const conversationMap = new Map()
-
-        directConversations.concat(inferredExamConversations).forEach((conversation) => {
-          if (!conversation?.id) return
-          conversationMap.set(conversation.id, conversation)
-        })
-
-        conversations = Array.from(conversationMap.values()).sort((a, b) => {
-          return Number(b.updated_at || 0) - Number(a.updated_at || 0)
-        })
-      }
-
-      // 设置历史记录总数
-      historyTotal.value = conversations.length
-      
-      // 转换后端数据为前端格式
-      historyData.value = conversations.map(conversation => ({
+      historyTotal.value = response.total || 0
+      historyData.value = response.data.map(conversation => ({
         id: conversation.id,
-        title: generateConversationTitle(conversation.exam_name || conversation.title || conversation.content),
+        title: generateConversationTitle(conversation.exam_name),
         time: formatTime(conversation.updated_at),
         businessType: conversation.business_type,
         isActive: false,
-        // 保存原始数据用于后续查询
         rawData: conversation
       }))
       console.log(`✅ 考试工坊历史记录列表已设置: ${historyData.value.length}条记录,总数: ${historyTotal.value}`)
     } else {
       console.error('❌ 获取考试工坊历史记录列表失败:', response.statusCode)
     }
-  } catch (error) {
-    console.error('❌ 获取考试工坊历史记录列表失败:', error)
   } finally {
     isLoadingHistory.value = false
   }
@@ -902,7 +867,7 @@ const confirmDeleteHistory = async () => {
     
     if (response.statusCode === 200) {
       // 删除成功,从列表中移除
-      historyData.value.splice(index, 1)
+      removeExamWorkshopHistory(historyItem.id)
       
       // 如果删除的是当前激活的历史记录,需要清空界面并调用新建任务
       if (historyItem.isActive) {
@@ -980,7 +945,7 @@ const createNewChat = async () => {
   isRefreshing.value = {};
   
   // 清理文件
-  selectedFile.value = null;
+  uploadedFiles.value = [];
   pptContentDescription.value = '';
   
   // 清除所有历史记录的选中状态
@@ -1000,25 +965,20 @@ const handleHistoryItem = async (historyItem) => {
   isLoadingHistoryItem.value = true;
   
   try {
-    // 设置当前点击的历史记录为激活状态
     historyData.value.forEach((item) => {
       item.isActive = item.id === historyItem.id;
     });
     
-    // 获取该历史记录的详细内容
     const response = await apis.getHistoryRecord({ 
-      // ===== 已删除:user_id - 后端从token解析 =====
-      ai_conversation_id: historyItem.id, // 使用历史记录的ID作为ai_conversation_id
-      business_type: 3 // 考试工坊类型
+      ai_conversation_id: historyItem.id,
+      business_type: 3
     });
     console.log(response.data)
     if (response.statusCode === 200 && response.data && response.data.length > 0) {
-      // 获取最新的试卷数据(取最新的AI消息)
-      const latestRecord = response.data[response.data.length - 1]; // 获取最新记录
+      const latestRecord = response.data[response.data.length - 1];
       console.log('获取到的试卷数据:', latestRecord);
       console.log('试卷数据结构:', JSON.stringify(latestRecord, null, 2));
       currentTime.value = formatTime(latestRecord.created_at)
-      // 解析试卷数据并恢复
       if (latestRecord && latestRecord.content) {
         try {
           const examData = extractExamDataFromContent(latestRecord.content);
@@ -1026,24 +986,20 @@ const handleHistoryItem = async (historyItem) => {
           showExamDetail.value = true;
         } catch (error) {
           console.error('解析试卷数据失败:', error);
-          // 如果解析失败,显示默认详情页
           showExamDetail.value = true;
           currentTime.value = historyItem.time;
         }
       } else {
-        // 如果没有内容,显示默认详情页
         showExamDetail.value = true;
         currentTime.value = historyItem.time;
       }
     } else {
       console.error('获取历史记录详情失败:', response);
-      // 显示默认详情页
       showExamDetail.value = true;
       currentTime.value = historyItem.time;
     }
   } catch (error) {
     console.error('获取历史记录详情失败:', error);
-    // 显示默认详情页
     showExamDetail.value = true;
     currentTime.value = historyItem.time;
   } finally {
@@ -1183,7 +1139,7 @@ const generateExam = async () => {
     await generateAIExam();
   } else {
     // PPT生成方式
-    if (!selectedFile.value) {
+    if (uploadedFiles.value.length === 0) {
       ElMessage.warning("请先上传PPT文件");
       return;
     }
@@ -1229,17 +1185,14 @@ const generatePPTExam = async () => {
       showExamDetail.value = true;
       ElMessage.success("PPT试卷生成完成!");
       
-      // AI回复完成后,获取最新的历史记录
       await getHistoryRecordList();
       
-      // 如果是新对话,将最新的历史记录设为激活状态
       if (ai_conversation_id.value > 0) {
         historyData.value.forEach((item) => {
           item.isActive = item.id === ai_conversation_id.value;
         });
         console.log('设置最新历史记录为激活状态,conversationId:', ai_conversation_id.value);
       } else {
-        // 如果没有对话ID,选中第一条记录
         selectLatestHistoryRecord();
       }
     } else {
@@ -1296,17 +1249,14 @@ const generateAIExam = async () => {
       showExamDetail.value = true;
       ElMessage.success("AI试卷生成完成!");
       
-      // AI回复完成后,获取最新的历史记录
       await getHistoryRecordList();
       
-      // 如果是新对话,将最新的历史记录设为激活状态
       if (ai_conversation_id.value > 0) {
         historyData.value.forEach((item) => {
           item.isActive = item.id === ai_conversation_id.value;
         });
         console.log('设置最新历史记录为激活状态,conversationId:', ai_conversation_id.value);
       } else {
-        // 如果没有对话ID,选中第一条记录
         selectLatestHistoryRecord();
       }
     } else {
@@ -1334,6 +1284,9 @@ const fetchExamPrompt = async (mode = 'ai') => {
     scorePerQuestion: Number(type.scorePerQuestion) || 0,
   }));
 
+  const pptContents = uploadedFiles.value.map(file => file.content).join('\n\n');
+  const finalContentBasis = pptContents || questionBasis.value || '';
+
   const payload = {
     mode,
     client: 'pc',
@@ -1341,7 +1294,7 @@ const fetchExamPrompt = async (mode = 'ai') => {
     examTitle: examName.value,
     totalScore: totalScore.value,
     questionTypes: normalizedQuestionTypes,
-    pptContent: selectedFile.value?.content || questionBasis.value || ''
+    pptContent: finalContentBasis
   };
 
   try {
@@ -2893,16 +2846,11 @@ const saveExam = async () => {
     
     console.log('准备保存的试卷数据:', examData);
     
-    // 调用后端保存接口
     const response = await apis.saveExam(examData);
     
     if (response.statusCode === 200) {
       ElMessage.success("试卷保存成功!");
-      
-      // 更新历史记录
       updateHistoryData(examData);
-      
-      // 可以在这里刷新历史记录列表
       console.log('试卷已保存到历史记录');
     } else {
       throw new Error('保存失败');
@@ -2987,17 +2935,15 @@ const prepareExamDataForSave = () => {
 // 更新历史记录数据
 const updateHistoryData = (examData) => {
   const newHistoryItem = {
-    id: Date.now(), // 使用时间戳作为临时ID
+    id: Date.now(),
     title: examData.exam_name,
     time: examData.generation_time,
     isActive: false,
-    examData: examData // 保存完整的试卷数据
+    examData: examData
   };
   
-  // 添加到历史记录开头
   historyData.value.unshift(newHistoryItem);
   
-  // 限制历史记录数量(比如最多保存20条)
   if (historyData.value.length > 20) {
     historyData.value = historyData.value.slice(0, 20);
   }
@@ -3449,65 +3395,68 @@ PPT文件处理失败,请手动描述PPT的主要内容、关键知识点、
 
 // 处理文件选择
 const handleFileSelect = async (event) => {
-  const file = event.target.files[0]
-  if (!file) return
+  const files = Array.from(event.target.files)
+  if (!files || files.length === 0) return
   
-  try {
-    // 验证文件
-    const fileExtension = validateFile(file)
-    
-    isUploadingFile.value = true
-    console.log('开始读取文件内容:', file.name)
-    
-    // 处理PPT文档
-    const extractedContent = await readPPTFile(file)
-    
-    // 创建文件信息对象
-    selectedFile.value = {
-      file,
-      name: file.name,
-      size: file.size,
-      type: fileExtension,
-      icon: getFileIcon(fileExtension),
-      content: extractedContent // 存储提取的内容
+  isUploadingFile.value = true
+  let successCount = 0;
+  
+  for (const file of files) {
+    try {
+      // 验证文件
+      const fileExtension = validateFile(file)
+      console.log('开始读取文件内容:', file.name)
+      
+      // 处理PPT文档
+      const extractedContent = await readPPTFile(file)
+      
+      // 创建文件信息对象
+      uploadedFiles.value.push({
+        file,
+        name: file.name,
+        size: file.size,
+        type: fileExtension,
+        icon: getFileIcon(fileExtension),
+        content: extractedContent // 存储提取的内容
+      })
+      successCount++;
+      
+      // 如果是第一个上传的文件,且当前试卷名称还是默认状态或为空,使用该文件名作为试卷名称
+      if (uploadedFiles.value.length === 1 && (!examName.value || examName.value.includes('工程施工技术考核'))) {
+        const fileNameWithoutExt = file.name.replace(/\.(ppt|pptx)$/i, '')
+        examName.value = `${fileNameWithoutExt}考试试卷`
+      }
+      
+    } catch (error) {
+      console.error(`文件 ${file.name} 读取失败:`, error)
+      ElMessage.error(`${file.name}读取失败: ${error.message || '请重试'}`)
     }
-    
-    // 使用文件名作为试卷名称(去掉扩展名)
-    const fileNameWithoutExt = file.name.replace(/\.(ppt|pptx)$/i, '')
-    examName.value = `${fileNameWithoutExt}考试试卷`
-    
-    // 显示提取的内容长度
-    const contentLength = extractedContent.length
-    console.log('文件内容提取完成,字符数:', contentLength)
-    ElMessage.success(`PPT文件读取成功,提取了${contentLength}个字符的内容`)
-    
-  } catch (error) {
-    console.error('文件读取失败:', error)
-    ElMessage.error(error.message || '文件读取失败,请重试')
-  } finally {
-    isUploadingFile.value = false
-    event.target.value = ''
   }
+  
+  if (successCount > 0) {
+    ElMessage.success(`成功读取了 ${successCount} 个文件`)
+  }
+  
+  isUploadingFile.value = false
+  event.target.value = ''
 }
 
 // 删除选中的文件
-const removeSelectedFile = () => {
-  if (selectedFile.value) {
-    selectedFile.value = null
-    // 清空PPT内容描述
-    pptContentDescription.value = ''
-    // 重置试卷名称为默认值
-    const projectTypeName = projectTypes[selectedProjectType.value].name
-    examName.value = `${projectTypeName}工程施工技术考核`
+const removeSelectedFile = (index) => {
+  if (index >= 0 && index < uploadedFiles.value.length) {
+    uploadedFiles.value.splice(index, 1)
+    
+    // 如果全部删除了,重置相关状态
+    if (uploadedFiles.value.length === 0) {
+      pptContentDescription.value = ''
+      const projectTypeName = projectTypes[selectedProjectType.value]?.name || '桥梁'
+      examName.value = `${projectTypeName}工程施工技术考核`
+    }
   }
 }
 
 // 触发文件上传
 const triggerFileUpload = () => {
-  if (selectedFile.value) {
-    ElMessage.warning('只能上传一个文件,请先删除当前文件')
-    return
-  }
   fileInput.value?.click()
 }
 
@@ -3538,6 +3487,15 @@ onMounted(async () => {
   // 获取历史记录列表
   await getHistoryRecordList()
   
+  // 检查URL参数是否有historyId需要加载
+  const historyId = route.query.historyId
+  if (historyId) {
+    const targetItem = historyData.value.find(item => String(item.id) === String(historyId))
+    if (targetItem) {
+      await handleHistoryItem(targetItem)
+    }
+  }
+  
   // 添加全局点击事件监听器
   document.addEventListener('click', handleClickOutside);
   
@@ -3731,11 +3689,11 @@ onUnmounted(() => {
 /* 工作头部 */
 .work-header {
   background: transparent;
-  padding: 30px 0px 0px 18px;
+  padding: 40px 0px 0px 18px;
   
   h2 {
     margin: 0;
-    font-size: 20px;
+    font-size: 25px;
     font-weight: 600;
     color: #2c3e50;
   }
@@ -3910,19 +3868,17 @@ onUnmounted(() => {
     }
 
     .file-status-badge {
-        position: absolute;
-        bottom: -40px;
-        left: 0;
         background: #ebf3ff;
         color: var(--primary-color);
-        padding: 8px 16px;
+        padding: 5px 12px;
         border-radius: 8px;
-        font-size: 13px;
+        font-size: 10px;
         display: flex;
         align-items: center;
-        gap: 12px;
+        gap: 8px;
         border: 1px solid rgba(13, 110, 253, 0.1);
-        max-width: 300px;
+        max-width: 100%;
+        margin-top: 12px;
     }
 
     .file-name {
@@ -5807,8 +5763,8 @@ onUnmounted(() => {
 
 .return-ai-btn {
   position: absolute;
-  top: -15px;
-  right: 0;
+  top: 10px;
+  right: 20px;
   z-index: 100;
   background: white;
   border: 1px solid rgba(0, 0, 0, 0.06);
@@ -5823,15 +5779,22 @@ onUnmounted(() => {
   align-items: center;
   gap: 5px;
   transition: all 0.3s ease;
+  height: 36px;
+  box-sizing: border-box;
+}
+
+.return-ai-btn:disabled {
+  opacity: 0.5;
+  cursor: not-allowed;
 }
 
-.return-ai-btn:hover {
+.return-ai-btn:hover:not(:disabled) {
   box-shadow: 0 8px 24px rgba(13, 110, 253, 0.12);
   color: #0d6efd;
   border-color: rgba(13, 110, 253, 0.2);
 }
 
-.return-ai-btn::before {
+.return-ai-btn.has-before::before {
   content: '←';
   font-size: 16px;
   font-weight: bold;

+ 53 - 9
shudao-vue-frontend/src/views/mobile/m-Chat.vue

@@ -1389,7 +1389,7 @@ const getConversationMessages = async (conversationId) => {
               .map(r => r.category)
             
             categories.forEach(category => {
-              categoryExpandStates.value[index][category] = true
+              categoryExpandStates.value[index][category] = false
             })
           }
         
@@ -1496,6 +1496,15 @@ const handleHistoryItem = async (historyItem) => {
   
   // 关闭历史记录弹窗(移动端特有)
   showHistory.value = false
+
+  // 根据业务类型跳转到对应模块
+  const bType = Number(historyItem.businessType)
+  if (bType === 3) {
+    router.push({ path: '/mobile/exam-workshop', query: { historyId: historyItem.id } })
+    return
+  } else if (bType !== 0) {
+    // 其他非0业务类型也可以在这里做重定向或提示,但目前仅处理考试工坊
+  }
   
   // 清空当前消息
     chatMessages.value = []
@@ -2273,16 +2282,49 @@ const handleCategoryToggle = (messageIndex, data) => {
   if (!categoryExpandStates.value[messageIndex]) {
     categoryExpandStates.value[messageIndex] = {}
   }
-  categoryExpandStates.value[messageIndex][data.category] = data.expanded
+  const matchedCategory = findCategoryStateKey(messageIndex, data.category)
+  categoryExpandStates.value[messageIndex][matchedCategory || data.category] = data.expanded
+}
+
+const normalizeCategoryName = (category) => {
+  return typeof category === 'string' ? category.trim() : ''
+}
+
+const findCategoryStateKey = (messageIndex, category) => {
+  const stateMap = categoryExpandStates.value[messageIndex]
+  if (!stateMap) return ''
+
+  const normalizedCategory = normalizeCategoryName(category)
+  if (!normalizedCategory) return ''
+
+  if (Object.prototype.hasOwnProperty.call(stateMap, category)) {
+    return category
+  }
+
+  const keys = Object.keys(stateMap)
+  const exactKey = keys.find(key => normalizeCategoryName(key) === normalizedCategory)
+  if (exactKey) return exactKey
+
+  return keys.find(key => {
+    const normalizedKey = normalizeCategoryName(key)
+    return normalizedKey && (
+      normalizedKey.includes(normalizedCategory) ||
+      normalizedCategory.includes(normalizedKey)
+    )
+  }) || ''
 }
 
 const isCategoryExpanded = (messageIndex, category) => {
   if (!category) return true
   if (!categoryExpandStates.value[messageIndex]) {
     categoryExpandStates.value[messageIndex] = {}
-    return true
+    return false
   }
-  return categoryExpandStates.value[messageIndex][category] !== false
+
+  const matchedCategory = findCategoryStateKey(messageIndex, category)
+  if (!matchedCategory) return false
+
+  return categoryExpandStates.value[messageIndex][matchedCategory] === true
 }
 
 // 检查reports数组是否只包含分类标题,没有实际报告
@@ -2572,7 +2614,7 @@ const handleSSEMessage = (data, aiMessageIndex) => {
       if (!categoryExpandStates.value[aiMessageIndex]) {
         categoryExpandStates.value[aiMessageIndex] = {}
       }
-      categoryExpandStates.value[aiMessageIndex][data.category] = true
+      categoryExpandStates.value[aiMessageIndex][data.category] = false
       
       // 保存当前分类名称,用于后续报告匹配
       aiMessage.currentCategory = data.category
@@ -2596,7 +2638,7 @@ const handleSSEMessage = (data, aiMessageIndex) => {
         similarity: data.similarity,
         metadata: {
           ...data.metadata,
-          _displayCategory: aiMessage.currentCategory // 存储当前显示的分类名
+          _displayCategory: data.metadata?.primary_category || aiMessage.currentCategory // 存储当前显示的分类名
         },
         report: {
           display_name: '',
@@ -2636,7 +2678,9 @@ const handleSSEMessage = (data, aiMessageIndex) => {
       
       let targetReport
       if (idx !== undefined) {
-        const displayCategory = aiMessage.reports[idx].metadata?._displayCategory
+        const displayCategory = reportData.metadata?.primary_category ||
+          aiMessage.reports[idx].metadata?._displayCategory ||
+          aiMessage.currentCategory
         const fullSummary = reportData.report?.summary || ''
         const fullAnalysis = reportData.report?.analysis || ''
         const fullClauses = reportData.report?.clauses || ''
@@ -2654,7 +2698,7 @@ const handleSSEMessage = (data, aiMessageIndex) => {
           status: 'completed',
           metadata: {
             ...reportData.metadata, // 保留所有metadata字段
-            _displayCategory: displayCategory || aiMessage.currentCategory
+            _displayCategory: displayCategory
           },
           _fullContent: {
             display_name: fullDisplayName,
@@ -2682,7 +2726,7 @@ const handleSSEMessage = (data, aiMessageIndex) => {
           status: 'completed',
           metadata: {
             ...reportData.metadata, // 保留所有metadata字段
-            _displayCategory: aiMessage.currentCategory
+            _displayCategory: reportData.metadata?.primary_category || aiMessage.currentCategory
           },
           _fullContent: {
             display_name: fullDisplayName,

+ 20 - 63
shudao-vue-frontend/src/views/mobile/m-ExamWorkshop.vue

@@ -481,7 +481,7 @@
 </template>
 
 <script setup>
-import { useRouter } from 'vue-router'
+import { useRouter, useRoute } from 'vue-router'
 import MobileHeader from '@/components/MobileHeader.vue'
 import MobileHistoryDrawer from '@/components/MobileHistoryDrawer.vue'
 import { ref, onMounted, onUnmounted, watch, computed } from 'vue'
@@ -492,6 +492,7 @@ import { initNativeNavForSubPage } from '@/utils/nativeBridge.js'
 // import { getUserId } from '@/utils/userManager.js'
 
 const router = useRouter()
+const route = useRoute()
 
 const goBack = () => {
   router.go(-1)
@@ -776,25 +777,21 @@ const handleHistoryItem = async (historyItem) => {
 const deleteHistoryItem = async (historyItem, index) => {
   try {
     console.log('开始删除移动端历史记录:', historyItem)
-    
+
     const response = await apis.deleteHistoryRecord({
-      // ===== 已删除:user_id - 后端从token解析 =====
       ai_conversation_id: historyItem.id
     })
     
     if (response.statusCode === 200) {
-      // 从本地数据中移除
       historyData.value.splice(index, 1)
       historyTotal.value = Math.max(0, historyTotal.value - 1)
       
-      // 如果删除的是当前激活的历史记录,执行新建任务
       if (historyItem.isActive) {
         console.log('删除激活的历史记录,执行新建任务')
         createNewTask()
       }
       
       console.log('✅ 移动端历史记录删除成功')
-      // 轻提示
       showToast('删除成功')
     } else {
       console.error('❌ 删除移动端历史记录失败:', response)
@@ -840,22 +837,6 @@ const formatHistoryTime = (timestamp) => {
   return `${month}月${day}日 ${time}`
 }
 
-// 获取历史记录列表
-const isExamWorkshopConversation = (conversation = {}) => {
-  const content = String(conversation.content || '')
-  const title = String(conversation.title || '')
-  const examName = String(conversation.exam_name || '')
-
-  return (
-    Number(conversation.business_type) === 3 ||
-    !!examName.trim() ||
-    title.includes('技术考核') ||
-    content.includes('请根据以下要求直接生成一份完整试卷') ||
-    content.includes('"singleChoice"') ||
-    content.includes('"totalQuestions"')
-  )
-}
-
 const getHistoryRecordList = async () => {
   try {
     console.log('📋 开始获取移动端考试工坊历史记录列表...')
@@ -863,9 +844,8 @@ const getHistoryRecordList = async () => {
     const startTime = performance.now()
     
     const response = await apis.getHistoryRecord({ 
-      // ===== 已删除:user_id - 后端从token解析 =====
-      ai_conversation_id: 0, // 0表示获取对话列表
-      business_type: 3 // 考试工坊类型
+      ai_conversation_id: 0,
+      business_type: 3
     })
     
     const endTime = performance.now()
@@ -873,41 +853,16 @@ const getHistoryRecordList = async () => {
     console.log('📋 移动端历史记录列表响应:', response)
     
     if (response.statusCode === 200) {
-      const directConversations = Array.isArray(response.data) ? response.data : []
-      let conversations = [...directConversations]
-
-      const fallbackResponse = await apis.getHistoryRecord({
-        ai_conversation_id: 0
-      })
-
-      if (fallbackResponse.statusCode === 200 && Array.isArray(fallbackResponse.data)) {
-        const inferredExamConversations = fallbackResponse.data.filter(isExamWorkshopConversation)
-        const conversationMap = new Map()
-
-        directConversations.concat(inferredExamConversations).forEach((conversation) => {
-          if (!conversation?.id) return
-          conversationMap.set(conversation.id, conversation)
-        })
-
-        conversations = Array.from(conversationMap.values()).sort((a, b) => {
-          return Number(b.updated_at || 0) - Number(a.updated_at || 0)
-        })
-      }
-
-      // 设置历史记录总数
-      historyTotal.value = conversations.length
+      historyTotal.value = response.total || 0
       
-      // 转换后端数据为前端格式
-      historyData.value = conversations.map(conversation => ({
+      historyData.value = response.data.map(conversation => ({
         id: conversation.id,
-        title: generateConversationTitle(conversation.exam_name || conversation.title || conversation.content),
+        title: generateConversationTitle(conversation.content),
         time: formatHistoryTime(conversation.updated_at),
         businessType: conversation.business_type,
         isActive: false,
-        // 保存原始数据用于后续查询
         rawData: conversation
       }))
-      // 高亮当前对话
       if (ai_conversation_id.value) {
         historyData.value.forEach(item => { item.isActive = item.id === ai_conversation_id.value })
       }
@@ -915,8 +870,6 @@ const getHistoryRecordList = async () => {
     } else {
       console.error('❌ 获取移动端历史记录列表失败:', response.statusCode)
     }
-  } catch (error) {
-    console.error('❌ 获取移动端历史记录列表失败:', error)
   } finally {
     isLoadingHistory.value = false
   }
@@ -1095,14 +1048,6 @@ const generateExam = async () => {
       // 显示考试详情页
       showExamDetail.value = true;
 
-      // 刷新历史记录,确保新生成的试卷详情可被立即查看
-      await getHistoryRecordList();
-      if (ai_conversation_id.value > 0) {
-        historyData.value.forEach((item) => {
-          item.isActive = item.id === ai_conversation_id.value;
-        });
-      }
-
       console.log('✅ 移动端试卷生成完成!');
 
     } else {
@@ -1361,6 +1306,7 @@ const restoreExamFromHistory = (examData) => {
 
   examName.value = normalizedExam.title || examName.value
   totalScore.value = normalizedExam.totalScore || totalScore.value
+  currentTime.value = exam.generation_time || currentTime.value
 
   questionTypes.value = [
     { name: "单选题", scorePerQuestion: normalizedExam.singleChoice.scorePerQuestion, questionCount: normalizedExam.singleChoice.count, romanNumeral: "一" },
@@ -2270,6 +2216,17 @@ onMounted(async () => {
     
     // 初始化原生导航栏(子页面模式:返回按钮执行路由后退)
     initNativeNavForSubPage(() => router.back())
+    
+    // 检查URL参数是否有historyId需要加载
+    const historyId = route.query.historyId
+    if (historyId) {
+      // 需要先加载历史记录列表才能找到对应项
+      await getHistoryRecordList()
+      const targetItem = historyData.value.find(item => String(item.id) === String(historyId))
+      if (targetItem) {
+        await handleHistoryItem(targetItem)
+      }
+    }
   } catch (error) {
     console.error('❌ 移动端考试工坊页面初始化失败:', error)
   }

Некоторые файлы не были показаны из-за большого количества измененных файлов