فهرست منبع

修复后端JSON提取BUG

FanHong 3 هفته پیش
والد
کامیت
411840d09b
2فایلهای تغییر یافته به همراه478 افزوده شده و 29 حذف شده
  1. 356 29
      shudao-chat-py/routers/chat.py
  2. 122 0
      shudao-chat-py/tests/test_exam_response_sanitizer.py

+ 356 - 29
shudao-chat-py/routers/chat.py

@@ -39,6 +39,306 @@ def _build_conversation_title(conversation: AIConversation) -> str:
     return _build_conversation_preview(conversation.content or "", limit=30)
 
 
+def _extract_json_object_from_index(source: str, start_idx: int) -> str:
+    if start_idx < 0 or start_idx >= len(source) or source[start_idx] != "{":
+        return ""
+
+    depth = 0
+    in_string = False
+    escaped = False
+
+    for idx in range(start_idx, len(source)):
+        ch = source[idx]
+
+        if escaped:
+            escaped = False
+            continue
+
+        if in_string:
+            if ch == "\\":
+                escaped = True
+            elif ch == '"':
+                in_string = False
+            continue
+
+        if ch == '"':
+            in_string = True
+            continue
+
+        if ch == "{":
+            depth += 1
+        elif ch == "}":
+            depth -= 1
+            if depth == 0:
+                return source[start_idx: idx + 1]
+
+    return ""
+
+
+def _extract_balanced_json_objects(text: str) -> list[str]:
+    source = (text or "").strip()
+    if not source:
+        return []
+
+    objects = []
+    seen = set()
+    for idx, ch in enumerate(source):
+        if ch != "{":
+            continue
+        candidate = _extract_json_object_from_index(source, idx)
+        if candidate and candidate not in seen:
+            objects.append(candidate)
+            seen.add(candidate)
+
+    return objects
+
+
+def _extract_trailing_json_candidates(text: str) -> list[str]:
+    source = (text or "").strip()
+    if not source:
+        return []
+
+    candidates = []
+    seen = set()
+    line_start_indexes = [
+        match.start()
+        for match in re.finditer(r"(?m)^[ \t]*\{", source)
+    ]
+
+    for start_idx in reversed(line_start_indexes):
+        candidate = source[start_idx:].strip()
+        if candidate and candidate not in seen:
+            candidates.append(candidate)
+            seen.add(candidate)
+
+    return candidates
+
+
+def _looks_like_exam_payload(payload: object) -> bool:
+    if not isinstance(payload, dict):
+        return False
+
+    questions = payload.get("questions")
+    return any(
+        key in payload
+        for key in (
+            "singleChoice",
+            "single_choice",
+            "单选题",
+            "judge",
+            "判断题",
+            "multiple",
+            "multiple_choice",
+            "multipleChoice",
+            "多选题",
+            "short",
+            "short_answer",
+            "shortAnswer",
+            "简答题",
+        )
+    ) or (
+        isinstance(questions, dict)
+        and any(
+            key in questions
+            for key in (
+                "singleChoice",
+                "single_choice",
+                "单选题",
+                "judge",
+                "判断题",
+                "multiple",
+                "multiple_choice",
+                "multipleChoice",
+                "多选题",
+                "short",
+                "short_answer",
+                "shortAnswer",
+                "简答题",
+            )
+        )
+    )
+
+
+def _score_exam_payload_candidate(payload: object) -> int:
+    if not isinstance(payload, dict):
+        return 0
+
+    score = 0
+    questions = payload.get("questions") if isinstance(
+        payload.get("questions"), dict) else {}
+
+    strong_keys = (
+        "singleChoice",
+        "single_choice",
+        "单选题",
+        "judge",
+        "判断题",
+        "multiple",
+        "multiple_choice",
+        "multipleChoice",
+        "多选题",
+        "short",
+        "short_answer",
+        "shortAnswer",
+        "简答题",
+    )
+    weak_keys = (
+        "title",
+        "exam_name",
+        "examTitle",
+        "试卷标题",
+        "总分",
+        "totalScore",
+        "totalQuestions",
+    )
+
+    score += sum(10 for key in strong_keys if key in payload)
+    score += sum(8 for key in strong_keys if key in questions)
+    score += sum(2 for key in weak_keys if key in payload)
+
+    section_candidates = []
+    for _, value in payload.items():
+        if isinstance(value, dict):
+            section_candidates.append(value)
+    section_candidates.extend(
+        value for value in questions.values() if isinstance(value, dict))
+
+    for section in section_candidates:
+        if "questions" in section and isinstance(section.get("questions"), list):
+            score += 6
+            question_list = section.get("questions") or []
+            if question_list and isinstance(question_list[0], dict):
+                first_question = question_list[0]
+                if any(k in first_question for k in ("text", "question_text", "question", "title", "content", "题干", "题目")):
+                    score += 4
+                if "options" in first_question:
+                    score += 3
+                if any(k in first_question for k in ("answer", "answers", "correct_answer", "correct_answers", "答案", "正确答案")):
+                    score += 3
+                if any(k in first_question for k in ("analysis", "explanation", "解析")):
+                    score += 2
+        if any(k in section for k in ("count", "question_count", "数量")):
+            score += 2
+        if any(k in section for k in ("scorePerQuestion", "score_per_question", "每题分值")):
+            score += 1
+
+    return score
+
+
+def _escape_inner_quotes_in_json(text: str) -> str:
+    chars = []
+    in_string = False
+    escaped = False
+
+    for idx, ch in enumerate(text):
+        if not in_string:
+            chars.append(ch)
+            if ch == '"':
+                in_string = True
+                escaped = False
+            continue
+
+        if escaped:
+            chars.append(ch)
+            escaped = False
+            continue
+
+        if ch == "\\":
+            chars.append(ch)
+            escaped = True
+            continue
+
+        if ch == '"':
+            next_non_space = ""
+            for next_idx in range(idx + 1, len(text)):
+                if not text[next_idx].isspace():
+                    next_non_space = text[next_idx]
+                    break
+
+            if next_non_space in {",", "}", "]", ":"}:
+                chars.append(ch)
+                in_string = False
+            else:
+                chars.append('\\"')
+            continue
+
+        chars.append(ch)
+
+    return "".join(chars)
+
+
+def _try_parse_exam_json(candidate: str) -> Optional[dict]:
+    text = (candidate or "").strip()
+    if not text:
+        return None
+
+    text = (
+        text.replace("\ufeff", "")
+        .replace("```json", "")
+        .replace("```JSON", "")
+        .replace("```", "")
+        .replace("“", '"')
+        .replace("”", '"')
+    ).strip()
+
+    try:
+        parsed = json.loads(text)
+    except Exception:
+        repaired_text = _escape_inner_quotes_in_json(text)
+        repaired_text = re.sub(r",\s*([}\]])", r"\1", repaired_text)
+        try:
+            parsed = json.loads(repaired_text)
+        except Exception:
+            return None
+
+    return parsed if _looks_like_exam_payload(parsed) else None
+
+
+def _sanitize_exam_response(raw_response: str) -> str:
+    """考试工坊只向前端/数据库透传可 JSON.parse 的试卷 JSON。"""
+    raw_text = (raw_response or "").strip()
+    if not raw_text:
+        return ""
+
+    _, answer = split_thinking_and_answer(raw_text)
+    for candidate in (answer, raw_text):
+        parsed = _try_parse_exam_json(candidate)
+        if parsed:
+            return json.dumps(parsed, ensure_ascii=False)
+
+    parsed_candidates = []
+    for candidate in _extract_balanced_json_objects(raw_text):
+        parsed = _try_parse_exam_json(candidate)
+        if parsed:
+            parsed_candidates.append((parsed, candidate))
+
+    for candidate in _extract_trailing_json_candidates(raw_text):
+        parsed = _try_parse_exam_json(candidate)
+        if parsed:
+            parsed_candidates.append((parsed, candidate))
+
+    if parsed_candidates:
+        parsed_candidates.sort(
+            key=lambda item: (
+                _score_exam_payload_candidate(item[0]),
+                len(json.dumps(item[0], ensure_ascii=False)),
+            ),
+            reverse=True,
+        )
+        best_payload, best_raw_candidate = parsed_candidates[0]
+        if _score_exam_payload_candidate(best_payload) > 0:
+            return json.dumps(best_payload, ensure_ascii=False)
+        logger.warning(
+            "[exam] 已提取到JSON对象但试卷特征较弱,选择最大候选兜底: score=%s snippet=%s",
+            _score_exam_payload_candidate(best_payload),
+            (best_raw_candidate or "")[:200],
+        )
+        return json.dumps(best_payload, ensure_ascii=False)
+
+    logger.warning("[exam] 未能从模型响应中提取试卷 JSON,保留原始响应供前端兜底解析")
+    return raw_text
+
+
 def _normalize_related_question(question: str) -> str:
     if not isinstance(question, str):
         return ""
@@ -316,7 +616,8 @@ def _clean_safety_training_topic(message: str) -> str:
 def _parse_json_object(text: str) -> dict:
     if not text:
         return {}
-    cleaned = re.sub(r"```(?:json)?\s*", "", str(text)).replace("```", "").strip()
+    cleaned = re.sub(r"```(?:json)?\s*", "", str(text)
+                     ).replace("```", "").strip()
     match = re.search(r"\{.*\}", cleaned, re.DOTALL)
     if not match:
         return {}
@@ -355,11 +656,13 @@ def _normalize_safety_training_plan(message: str, raw_plan: dict) -> dict:
 
     focus = raw_plan.get("content_focus")
     if isinstance(focus, list):
-        normalized_focus = [str(item).strip() for item in focus if str(item).strip()]
+        normalized_focus = [str(item).strip()
+                            for item in focus if str(item).strip()]
         if normalized_focus:
             plan["content_focus"] = normalized_focus
     elif isinstance(focus, str) and focus.strip():
-        plan["content_focus"] = [item.strip() for item in re.split(r"[、,,;\n]", focus) if item.strip()]
+        plan["content_focus"] = [item.strip()
+                                 for item in re.split(r"[、,,;\n]", focus) if item.strip()]
 
     if "培训" not in plan["topic"]:
         plan["topic"] = f"{plan['topic']}安全培训"
@@ -412,7 +715,8 @@ async def _infer_safety_training_plan(message: str) -> dict:
         ])
         return _normalize_safety_training_plan(message, _parse_json_object(response))
     except Exception as e:
-        logger.warning(f"[safety_training] 需求整理失败,使用兜底结构: {type(e).__name__}: {e}")
+        logger.warning(
+            f"[safety_training] 需求整理失败,使用兜底结构: {type(e).__name__}: {e}")
         return _build_fallback_safety_training_plan(message)
 
 
@@ -421,9 +725,11 @@ def _clean_ai_writing_response(content: str) -> str:
     if not text:
         return ""
 
-    text = re.sub(r"```(?:html)?\s*", "", text, flags=re.IGNORECASE).replace("```", "").strip()
+    text = re.sub(r"```(?:html)?\s*", "", text,
+                  flags=re.IGNORECASE).replace("```", "").strip()
 
-    body_match = re.search(r"<body[^>]*>(.*?)</body>", text, re.IGNORECASE | re.DOTALL)
+    body_match = re.search(
+        r"<body[^>]*>(.*?)</body>", text, re.IGNORECASE | re.DOTALL)
     if body_match:
         text = body_match.group(1).strip()
 
@@ -494,7 +800,8 @@ async def _generate_ai_writing_response(message: str) -> str:
 
 async def _generate_ppt_outline_response(message: str) -> str:
     training_plan = await _infer_safety_training_plan(message)
-    generation_message = _build_safety_training_generation_message(message, training_plan)
+    generation_message = _build_safety_training_generation_message(
+        message, training_plan)
     rag_context = await _rag_search(generation_message, top_k=10)
     system_content = load_prompt(
         "ppt_outline",
@@ -676,7 +983,8 @@ async def send_deepseek_message(
                 ]
 
                 qwen_response = await qwen_service.chat(messages)
-                raw_thinking, raw_answer = split_thinking_and_answer(qwen_response)
+                raw_thinking, raw_answer = split_thinking_and_answer(
+                    qwen_response)
                 answer_source = raw_answer or qwen_response
 
                 # 兼容模型直接返回 JSON 的场景
@@ -706,8 +1014,10 @@ async def send_deepseek_message(
                 else:
                     response_text = answer_text
             except Exception as e:
-                error_detail = str(e).strip() if str(e).strip() else f"未知错误({type(e).__name__})"
-                logger.error(f"[send_deepseek_message] AI问答异常: {type(e).__name__}: {error_detail}")
+                error_detail = str(e).strip() if str(
+                    e).strip() else f"未知错误({type(e).__name__})"
+                logger.error(
+                    f"[send_deepseek_message] AI问答异常: {type(e).__name__}: {error_detail}")
                 response_text = f"处理失败: {error_detail}"
 
         elif data.business_type == 1:
@@ -740,8 +1050,10 @@ async def send_deepseek_message(
                     },
                 }
             except Exception as e:
-                error_detail = str(e).strip() if str(e).strip() else f"未知错误({type(e).__name__})"
-                logger.error(f"[send_deepseek_message] PPT大纲生成异常: {type(e).__name__}: {error_detail}")
+                error_detail = str(e).strip() if str(
+                    e).strip() else f"未知错误({type(e).__name__})"
+                logger.error(
+                    f"[send_deepseek_message] PPT大纲生成异常: {type(e).__name__}: {error_detail}")
                 response_text = f"处理失败: {error_detail}"
 
         elif data.business_type == 2:
@@ -774,8 +1086,10 @@ async def send_deepseek_message(
                     },
                 }
             except Exception as e:
-                error_detail = str(e).strip() if str(e).strip() else f"未知错误({type(e).__name__})"
-                logger.error(f"[send_deepseek_message] AI写作异常: {type(e).__name__}: {error_detail}")
+                error_detail = str(e).strip() if str(
+                    e).strip() else f"未知错误({type(e).__name__})"
+                logger.error(
+                    f"[send_deepseek_message] AI写作异常: {type(e).__name__}: {error_detail}")
                 response_text = f"处理失败: {error_detail}"
 
         elif data.business_type == 3:
@@ -800,7 +1114,8 @@ async def send_deepseek_message(
                     {"role": "user", "content": message},
                 ]
 
-                response_text = await qwen_service.chat(messages)
+                raw_response_text = await qwen_service.chat(messages)
+                response_text = _sanitize_exam_response(raw_response_text)
 
                 now_ts = int(time.time())
                 user_message = AIMessage(
@@ -839,8 +1154,10 @@ async def send_deepseek_message(
                     )
                     db.commit()
             except Exception as e:
-                error_detail = str(e).strip() if str(e).strip() else f"未知错误({type(e).__name__})"
-                logger.error(f"[send_deepseek_message] 考试工坊异常: {type(e).__name__}: {error_detail}")
+                error_detail = str(e).strip() if str(
+                    e).strip() else f"未知错误({type(e).__name__})"
+                logger.error(
+                    f"[send_deepseek_message] 考试工坊异常: {type(e).__name__}: {error_detail}")
                 response_text = f"处理失败: {error_detail}"
 
         else:
@@ -1084,7 +1401,8 @@ async def stream_chat(request: Request, data: StreamChatRequest):
             thinking_buf = ""
             in_think = False
             thinking_done = False
-            max_input_chars = getattr(settings.thinking_summary, "max_input_chars", 1500)
+            max_input_chars = getattr(
+                settings.thinking_summary, "max_input_chars", 1500)
 
             async for chunk in qwen_service.stream_chat(messages):
                 buffer += chunk
@@ -1107,13 +1425,15 @@ async def stream_chat(request: Request, data: StreamChatRequest):
                         end_idx = lower.find("</think>")
                         if end_idx == -1:
                             if max_input_chars and len(thinking_buf) < max_input_chars:
-                                thinking_buf += buffer[: max_input_chars - len(thinking_buf)]
+                                thinking_buf += buffer[: max_input_chars -
+                                                       len(thinking_buf)]
                             buffer = ""
                             break
 
                         if max_input_chars and len(thinking_buf) < max_input_chars:
                             thinking_part = buffer[:end_idx]
-                            thinking_buf += thinking_part[: max_input_chars - len(thinking_buf)]
+                            thinking_buf += thinking_part[: max_input_chars - len(
+                                thinking_buf)]
 
                         buffer = buffer[end_idx + len("</think>"):]
                         in_think = False
@@ -1329,7 +1649,8 @@ async def stream_chat_with_db(request: Request, data: StreamChatWithDBRequest):
                 if rag_context:
                     context_parts.append(f"??????\n{rag_context}")
                 if data.online_search_content:
-                    context_parts.append(f"???????\n{data.online_search_content}")
+                    context_parts.append(
+                        f"???????\n{data.online_search_content}")
 
                 context_json = "\n\n".join(
                     context_parts) if context_parts else "?????????"
@@ -1348,8 +1669,10 @@ async def stream_chat_with_db(request: Request, data: StreamChatWithDBRequest):
             # 8. 流式输出并收集完整回复
             full_response = ""
             try:
-                summary_enabled = getattr(settings.thinking_summary, "enabled", True)
-                max_input_chars = getattr(settings.thinking_summary, "max_input_chars", 1500)
+                summary_enabled = getattr(
+                    settings.thinking_summary, "enabled", True)
+                max_input_chars = getattr(
+                    settings.thinking_summary, "max_input_chars", 1500)
 
                 buffer = ""
                 pre_answer = ""
@@ -1378,22 +1701,24 @@ async def stream_chat_with_db(request: Request, data: StreamChatWithDBRequest):
                                     break
 
                                 pre_answer += buffer[:start_idx]
-                                buffer = buffer[start_idx + len("<think>") :]
+                                buffer = buffer[start_idx + len("<think>"):]
                                 in_think = True
                                 continue
 
                             end_idx = lower.find("</think>")
                             if end_idx == -1:
                                 if max_input_chars and len(thinking_buf) < max_input_chars:
-                                    thinking_buf += buffer[: max_input_chars - len(thinking_buf)]
+                                    thinking_buf += buffer[: max_input_chars -
+                                                           len(thinking_buf)]
                                 buffer = ""
                                 break
 
                             if max_input_chars and len(thinking_buf) < max_input_chars:
                                 thinking_part = buffer[:end_idx]
-                                thinking_buf += thinking_part[: max_input_chars - len(thinking_buf)]
+                                thinking_buf += thinking_part[: max_input_chars - len(
+                                    thinking_buf)]
 
-                            buffer = buffer[end_idx + len("</think>") :]
+                            buffer = buffer[end_idx + len("</think>"):]
                             in_think = False
                             thinking_done = True
 
@@ -1413,7 +1738,8 @@ async def stream_chat_with_db(request: Request, data: StreamChatWithDBRequest):
                             answer_chunk = (pre_answer + buffer).lstrip()
                             if answer_chunk:
                                 full_response += answer_chunk
-                                escaped_answer = answer_chunk.replace('\n', '\\n')
+                                escaped_answer = answer_chunk.replace(
+                                    '\n', '\\n')
                                 yield f"data: {escaped_answer}\n\n"
 
                             pre_answer = ""
@@ -1537,7 +1863,8 @@ async def guess_you_want(
             if not questions:
                 questions = ["该话题的具体应用场景?", "有哪些注意事项?", "相关案例分析?"]
 
-        questions = _finalize_related_questions(questions, ai_msg.content, limit=3)
+        questions = _finalize_related_questions(
+            questions, ai_msg.content, limit=3)
 
         guess_json = json.dumps({"questions": questions}, ensure_ascii=False)
 

+ 122 - 0
shudao-chat-py/tests/test_exam_response_sanitizer.py

@@ -0,0 +1,122 @@
+import importlib.util
+import json
+import unittest
+from pathlib import Path
+
+
+CHAT_PATH = Path(__file__).resolve().parents[1] / "routers" / "chat.py"
+spec = importlib.util.spec_from_file_location(
+    "chat_under_test_exam", CHAT_PATH)
+chat = importlib.util.module_from_spec(spec)
+spec.loader.exec_module(chat)
+
+
+def exam_payload(title="桩基础施工技术考核"):
+    return {
+        "title": title,
+        "totalScore": 100,
+        "totalQuestions": 1,
+        "singleChoice": {
+            "scorePerQuestion": 2,
+            "totalScore": 2,
+            "count": 1,
+            "questions": [
+                {
+                    "text": "钻孔灌注桩清孔完成后应重点检查哪项指标?",
+                    "options": [
+                        {"key": "A", "text": "孔底沉渣厚度"},
+                        {"key": "B", "text": "施工便道宽度"},
+                        {"key": "C", "text": "钢筋棚颜色"},
+                        {"key": "D", "text": "围挡广告内容"},
+                    ],
+                    "answer": "A",
+                    "analysis": "孔底沉渣厚度直接影响桩端承载力。",
+                }
+            ],
+        },
+        "judge": {"scorePerQuestion": 3, "totalScore": 0, "count": 0, "questions": []},
+        "multiple": {"scorePerQuestion": 5, "totalScore": 0, "count": 0, "questions": []},
+        "short": {"scorePerQuestion": 10, "totalScore": 0, "count": 0, "questions": []},
+    }
+
+
+class ExamResponseSanitizerTests(unittest.TestCase):
+    def test_removes_thinking_process_prefix(self):
+        raw = "Thinking Process:\n\n1. Analyze the Request.\n\n" + \
+            json.dumps(exam_payload(), ensure_ascii=False)
+
+        cleaned = chat._sanitize_exam_response(raw)
+        parsed = json.loads(cleaned)
+
+        self.assertEqual(parsed["title"], "桩基础施工技术考核")
+        self.assertIn("singleChoice", parsed)
+        self.assertNotIn("Thinking Process", cleaned)
+
+    def test_extracts_json_from_markdown_code_block(self):
+        raw = "下面是生成结果:\n```json\n" + \
+            json.dumps(exam_payload("桥梁考试"), ensure_ascii=False) + "\n```"
+
+        cleaned = chat._sanitize_exam_response(raw)
+        parsed = json.loads(cleaned)
+
+        self.assertEqual(parsed["title"], "桥梁考试")
+
+    def test_prefers_exam_payload_over_other_json_noise(self):
+        raw = (
+            "Thinking Process:\n"
+            '{"note":"not exam"}\n'
+            "Final Answer:\n"
+            + json.dumps(exam_payload("最终试卷"), ensure_ascii=False)
+        )
+
+        cleaned = chat._sanitize_exam_response(raw)
+        parsed = json.loads(cleaned)
+
+        self.assertEqual(parsed["title"], "最终试卷")
+        self.assertIn("singleChoice", parsed)
+
+    def test_extracts_exam_payload_when_reasoning_contains_quotes_and_examples(self):
+        raw = (
+            'Thinking Process:\n'
+            'The output must contain "title", "totalScore", "singleChoice".\n'
+            'Use {"key": "A", "text": "..."} as the option shape example.\n'
+            'Section example: {"scorePerQuestion": 2, "totalScore": 20, "count": 10, "questions": [...]}.\n\n'
+            + json.dumps(exam_payload("带说明的最终试卷"), ensure_ascii=False)
+        )
+
+        cleaned = chat._sanitize_exam_response(raw)
+        parsed = json.loads(cleaned)
+
+        self.assertEqual(parsed["title"], "带说明的最终试卷")
+        self.assertIn("singleChoice", parsed)
+        self.assertFalse(cleaned.startswith("Thinking Process"))
+
+    def test_extracts_trailing_exam_json_after_think_suffix(self):
+        raw = (
+            "Thinking Process:\n"
+            'Use {"key": "A", "text": "..."} as example.\n'
+            "</think>\n\n"
+            + json.dumps(exam_payload("尾部试卷"), ensure_ascii=False)
+        )
+
+        cleaned = chat._sanitize_exam_response(raw)
+        parsed = json.loads(cleaned)
+
+        self.assertEqual(parsed["title"], "尾部试卷")
+        self.assertEqual(parsed["totalQuestions"], 1)
+
+    def test_repairs_unescaped_quotes_inside_string_values(self):
+        payload = json.dumps(exam_payload("引号容错"), ensure_ascii=False)
+        payload = payload.replace(
+            "钻孔灌注桩清孔完成后应重点检查哪项指标?", '钻孔灌注桩必须实行"一炮三检"制度吗?')
+        payload = payload.replace("孔底沉渣厚度直接影响桩端承载力。", '"一炮三检"是爆破作业的常见安全检查制度。')
+
+        cleaned = chat._sanitize_exam_response(payload)
+        parsed = json.loads(cleaned)
+
+        self.assertEqual(parsed["title"], "引号容错")
+        self.assertIn('"一炮三检"', parsed["singleChoice"]["questions"][0]["text"])
+
+
+if __name__ == "__main__":
+    unittest.main()