from fastapi import APIRouter, Depends, Request from fastapi.responses import StreamingResponse, JSONResponse from sqlalchemy.orm import Session from pydantic import BaseModel from typing import Optional from database import get_db, SessionLocal from models.chat import AIConversation, AIMessage from models.total import RecommendQuestion from utils.config import settings from utils.logger import logger from services.qwen_service import qwen_service from services.deepseek_service import deepseek_service from utils.prompt_loader import load_prompt from utils.thinking_summary import split_thinking_and_answer, summarize_thinking_content import time import json import httpx import re router = APIRouter() def _build_conversation_preview(content: str, limit: int = 50) -> str: content = (content or "").strip() if len(content) <= limit: return content return content[:limit] + "..." def _to_frontend_timestamp(timestamp: Optional[int]) -> Optional[int]: if not timestamp: return None return timestamp if timestamp >= 10**12 else timestamp * 1000 def _build_conversation_title(conversation: AIConversation) -> str: if conversation.business_type == 3 and (conversation.exam_name or "").strip(): return conversation.exam_name.strip() return _build_conversation_preview(conversation.content or "", limit=30) def _extract_json_object_from_index(source: str, start_idx: int) -> str: if start_idx < 0 or start_idx >= len(source) or source[start_idx] != "{": return "" depth = 0 in_string = False escaped = False for idx in range(start_idx, len(source)): ch = source[idx] if escaped: escaped = False continue if in_string: if ch == "\\": escaped = True elif ch == '"': in_string = False continue if ch == '"': in_string = True continue if ch == "{": depth += 1 elif ch == "}": depth -= 1 if depth == 0: return source[start_idx: idx + 1] return "" def _extract_balanced_json_objects(text: str) -> list[str]: source = (text or "").strip() if not source: return [] objects = [] seen = set() for idx, ch in 
enumerate(source): if ch != "{": continue candidate = _extract_json_object_from_index(source, idx) if candidate and candidate not in seen: objects.append(candidate) seen.add(candidate) return objects def _extract_trailing_json_candidates(text: str) -> list[str]: source = (text or "").strip() if not source: return [] candidates = [] seen = set() line_start_indexes = [ match.start() for match in re.finditer(r"(?m)^[ \t]*\{", source) ] for start_idx in reversed(line_start_indexes): candidate = source[start_idx:].strip() if candidate and candidate not in seen: candidates.append(candidate) seen.add(candidate) return candidates def _extract_explicit_answer_segment(text: str) -> str: source = (text or "").strip() if not source: return "" markers = ( "final answer:", "final output:", "answer:", "output:", "json:", ) lowered = source.lower() for marker in markers: idx = lowered.rfind(marker) if idx >= 0: candidate = source[idx + len(marker):].strip() if candidate: return candidate return "" def _extract_brace_sliced_candidates(text: str) -> list[str]: source = (text or "").strip() if not source: return [] candidates = [] seen = set() first_brace = source.find("{") last_brace = source.rfind("}") if first_brace >= 0 and last_brace > first_brace: candidate = source[first_brace:last_brace + 1].strip() if candidate and candidate not in seen: candidates.append(candidate) seen.add(candidate) return candidates def _looks_like_exam_payload(payload: object) -> bool: if not isinstance(payload, dict): return False questions = payload.get("questions") return any( key in payload for key in ( "singleChoice", "single_choice", "单选题", "judge", "判断题", "multiple", "multiple_choice", "multipleChoice", "多选题", "short", "short_answer", "shortAnswer", "简答题", ) ) or ( isinstance(questions, dict) and any( key in questions for key in ( "singleChoice", "single_choice", "单选题", "judge", "判断题", "multiple", "multiple_choice", "multipleChoice", "多选题", "short", "short_answer", "shortAnswer", "简答题", ) ) ) def 
_score_exam_payload_candidate(payload: object) -> int: if not isinstance(payload, dict): return 0 score = 0 questions = payload.get("questions") if isinstance( payload.get("questions"), dict) else {} strong_keys = ( "singleChoice", "single_choice", "单选题", "judge", "判断题", "multiple", "multiple_choice", "multipleChoice", "多选题", "short", "short_answer", "shortAnswer", "简答题", ) weak_keys = ( "title", "exam_name", "examTitle", "试卷标题", "总分", "totalScore", "totalQuestions", ) score += sum(10 for key in strong_keys if key in payload) score += sum(8 for key in strong_keys if key in questions) score += sum(2 for key in weak_keys if key in payload) section_candidates = [] for _, value in payload.items(): if isinstance(value, dict): section_candidates.append(value) section_candidates.extend( value for value in questions.values() if isinstance(value, dict)) for section in section_candidates: if "questions" in section and isinstance(section.get("questions"), list): score += 6 question_list = section.get("questions") or [] if question_list and isinstance(question_list[0], dict): first_question = question_list[0] if any(k in first_question for k in ("text", "question_text", "question", "title", "content", "题干", "题目")): score += 4 if "options" in first_question: score += 3 if any(k in first_question for k in ("answer", "answers", "correct_answer", "correct_answers", "答案", "正确答案")): score += 3 if any(k in first_question for k in ("analysis", "explanation", "解析")): score += 2 if any(k in section for k in ("count", "question_count", "数量")): score += 2 if any(k in section for k in ("scorePerQuestion", "score_per_question", "每题分值")): score += 1 return score def _escape_inner_quotes_in_json(text: str) -> str: chars = [] in_string = False escaped = False for idx, ch in enumerate(text): if not in_string: chars.append(ch) if ch == '"': in_string = True escaped = False continue if escaped: chars.append(ch) escaped = False continue if ch == "\\": chars.append(ch) escaped = True continue if ch 
== '"': next_non_space = "" for next_idx in range(idx + 1, len(text)): if not text[next_idx].isspace(): next_non_space = text[next_idx] break if next_non_space in {",", "}", "]", ":"}: chars.append(ch) in_string = False else: chars.append('\\"') continue chars.append(ch) return "".join(chars) def _try_parse_exam_json(candidate: str) -> Optional[dict]: text = (candidate or "").strip() if not text: return None text = ( text.replace("\ufeff", "") .replace("```json", "") .replace("```JSON", "") .replace("```", "") .replace("“", '"') .replace("”", '"') ).strip() try: parsed = json.loads(text) except Exception: repaired_text = _escape_inner_quotes_in_json(text) repaired_text = re.sub(r",\s*([}\]])", r"\1", repaired_text) try: parsed = json.loads(repaired_text) except Exception: return None return parsed if _looks_like_exam_payload(parsed) else None def _sanitize_exam_response(raw_response: str) -> str: """考试工坊只向前端/数据库透传可 JSON.parse 的试卷 JSON。""" raw_text = (raw_response or "").strip() if not raw_text: return "" _, answer = split_thinking_and_answer(raw_text) explicit_answer = _extract_explicit_answer_segment(raw_text) for candidate in (answer, explicit_answer, raw_text): parsed = _try_parse_exam_json(candidate) if parsed: return json.dumps(parsed, ensure_ascii=False) parsed_candidates = [] for candidate in _extract_balanced_json_objects(raw_text): parsed = _try_parse_exam_json(candidate) if parsed: parsed_candidates.append((parsed, candidate)) for candidate in _extract_trailing_json_candidates(raw_text): parsed = _try_parse_exam_json(candidate) if parsed: parsed_candidates.append((parsed, candidate)) for candidate in _extract_brace_sliced_candidates(raw_text): parsed = _try_parse_exam_json(candidate) if parsed: parsed_candidates.append((parsed, candidate)) if parsed_candidates: parsed_candidates.sort( key=lambda item: ( _score_exam_payload_candidate(item[0]), len(json.dumps(item[0], ensure_ascii=False)), ), reverse=True, ) best_payload, best_raw_candidate = 
parsed_candidates[0] if _score_exam_payload_candidate(best_payload) > 0: return json.dumps(best_payload, ensure_ascii=False) logger.warning( "[exam] 已提取到JSON对象但试卷特征较弱,选择最大候选兜底: score=%s snippet=%s", _score_exam_payload_candidate(best_payload), (best_raw_candidate or "")[:200], ) return json.dumps(best_payload, ensure_ascii=False) logger.warning("[exam] 未能从模型响应中提取试卷 JSON,保留原始响应供前端兜底解析") return raw_text def _normalize_related_question(question: str) -> str: if not isinstance(question, str): return "" text = question.strip().strip('"').strip("'") text = re.sub(r"^[0-9]+[\.\)\]、]\s*", "", text) text = re.sub(r"^[-*]\s*", "", text) return text.strip() def _is_placeholder_related_question(question: str) -> bool: normalized = _normalize_related_question(question).lower() if not normalized: return True placeholder_patterns = ( r"^q\s*\d+$", r"^question\s*\d+$", r"^questions?\s*\d+$", r"^问题\s*\d+$", r"^相关问题\s*\d+$", r"^推荐问题\s*\d+$", r"^更多相关问题$", r"^更多问题$", ) return any(re.fullmatch(pattern, normalized) for pattern in placeholder_patterns) def _contains_chinese(text: str) -> bool: return any("\u4e00" <= char <= "\u9fff" for char in text or "") def _is_invalid_related_question(question: str) -> bool: normalized = _normalize_related_question(question) if ( not normalized or len(normalized) < 4 or _is_placeholder_related_question(normalized) or not _contains_chinese(normalized) ): return True lowered = normalized.lower() blocked_keywords = ( "thinking process", "analyze the request", "role:", "**role", "professional question recommendation", "infrastructure construction technology", "output format", "json", "prompt", "system", "assistant", "角色定义", "任务目标", "输入内容", "生成要求", "输出格式", "开始生成", ) return any(keyword in lowered for keyword in blocked_keywords) def _extract_related_question_topic(content: str) -> str: if not content: return "当前话题" text = re.sub(r"<[^>]+>", " ", str(content)) text = re.sub(r"\s+", " ", text).strip() text = re.sub( r"^(好的[!!,, ]*|我理解您提出的问题[,, ]*|这个问题[,, 
def _finalize_related_questions(questions: list, content: str, limit: int = 3) -> list[str]:
    """Clean model-suggested questions and pad them with topic fallbacks.

    Args:
        questions: Raw suggestions; ``None`` is treated as an empty list.
        content: Answer text the fallback questions are derived from.
        limit: Maximum number of questions to return.

    Returns:
        Up to ``limit`` normalized, valid, case-insensitively unique
        questions, with model suggestions first and fallbacks after.
    """
    picked: list[str] = []
    taken: set[str] = set()  # lower-cased texts already accepted

    for raw in questions or []:
        text = _normalize_related_question(raw)
        key = text.lower()
        if _is_invalid_related_question(text) or key in taken:
            continue
        picked.append(text)
        taken.add(key)
        if len(picked) == limit:
            return picked

    # Not enough usable suggestions: top up from the generated fallbacks.
    for fallback in _build_related_question_fallbacks(content):
        key = fallback.lower()
        if key not in taken:
            picked.append(fallback)
            taken.add(key)
            if len(picked) == limit:
                break

    return picked[:limit]
async def _rag_search(message: str, top_k: int = 5) -> str:
    """Query the configured search API and return the hits as context text.

    The call is best-effort: missing configuration, a non-200 response, an
    unexpected response shape or any exception yields ``""`` so that the
    caller can continue without retrieval context.

    Args:
        message: Query text sent as ``query``.
        top_k: Number of results requested (``n_results``) and the maximum
            number of hits joined into the result.

    Returns:
        The ``content`` (or ``text``) of each hit joined by blank lines, or
        ``""`` when nothing usable was retrieved.
    """
    try:
        search_cfg = getattr(settings, 'search', None)
        search_url = getattr(search_cfg, 'api_url', None) if search_cfg else None
        if not search_url:
            return ""

        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(
                search_url,
                json={"query": message, "n_results": top_k},
            )
        if resp.status_code != 200:
            return ""

        data = resp.json()
        if not isinstance(data, dict):
            return ""
        docs = data.get("results") or data.get("documents") or []
        if not isinstance(docs, list):
            return ""

        snippets = []
        for doc in docs[:top_k]:
            # Hits may be dicts or plain strings. One malformed hit must not
            # discard the others.
            if isinstance(doc, dict):
                snippet = doc.get("content") or doc.get("text")
            elif isinstance(doc, str):
                snippet = doc
            else:
                snippet = None
            if snippet:
                snippets.append(str(snippet))
        return "\n\n".join(snippets)
    except Exception as e:
        # Retrieval is optional, so failures are logged and swallowed.
        logger.warning(f"[RAG] 检索失败(可忽略): {e}")
    return ""
content_focus 至少给出 3 个要点。 JSON 字段: { "topic": "培训主题", "template": "模板名称", "content_focus": ["内容要点1", "内容要点2", "内容要点3"], "audience": "参训对象", "time": "培训时间", "location": "培训地点", "goal": "培训目标", "notes": "其他要求", "normalized_request": "归一化后的安全培训PPT大纲生成需求" } """ def _extract_tag_value(message: str, tag: str) -> str: match = re.search(fr"<{tag}>(.*?)", message or "", re.DOTALL) return match.group(1).strip() if match else "" def _strip_document_tags(message: str) -> str: text = message or "" for tag in ("word", "filename", "filesize"): text = re.sub(fr"<{tag}>.*?", " ", text, flags=re.DOTALL) return re.sub(r"\s+", " ", text).strip() def _extract_safety_training_request_payload(message: str) -> dict: return { "document_content": _extract_tag_value(message, "word"), "filename": _extract_tag_value(message, "filename"), "filesize": _extract_tag_value(message, "filesize"), "request": _strip_document_tags(message), } def _clean_safety_training_topic(message: str) -> str: request_text = _extract_safety_training_request_payload(message)["request"] first_clause = re.split(r"[,。;;,\n]", request_text, maxsplit=1)[0].strip() topic = first_clause or request_text or "安全培训" for token in ("请", "帮我", "帮忙", "生成", "制作", "输出", "一份", "一个", "一下", "PPT大纲", "ppt大纲", "大纲", "通知", "文档", "材料"): topic = topic.replace(token, "") topic = re.sub(r"\s+", "", topic).strip(" ::,,。;;") if not topic: topic = "安全培训" if "培训" not in topic: topic = f"{topic}安全培训" return topic def _parse_json_object(text: str) -> dict: if not text: return {} cleaned = re.sub(r"```(?:json)?\s*", "", str(text) ).replace("```", "").strip() match = re.search(r"\{.*\}", cleaned, re.DOTALL) if not match: return {} try: parsed = json.loads(match.group(0)) return parsed if isinstance(parsed, dict) else {} except json.JSONDecodeError: return {} def _build_fallback_safety_training_plan(message: str) -> dict: topic = _clean_safety_training_topic(message) payload = _extract_safety_training_request_payload(message) return { "topic": topic, 
"template": "标准安全培训PPT大纲", "content_focus": ["安全生产责任", "现场风险识别", "安全意识提升", "培训纪律与行为规范"], "audience": "参训员工", "time": "", "location": "", "goal": "提升参训人员安全意识和施工现场风险防控能力", "notes": payload["request"], "normalized_request": f"围绕{topic}生成安全培训PPT大纲", } def _normalize_safety_training_plan(message: str, raw_plan: dict) -> dict: plan = _build_fallback_safety_training_plan(message) if not isinstance(raw_plan, dict): return plan for key in ("topic", "template", "audience", "time", "location", "goal", "notes", "normalized_request"): value = raw_plan.get(key) if isinstance(value, str) and value.strip(): plan[key] = value.strip() focus = raw_plan.get("content_focus") if isinstance(focus, list): normalized_focus = [str(item).strip() for item in focus if str(item).strip()] if normalized_focus: plan["content_focus"] = normalized_focus elif isinstance(focus, str) and focus.strip(): plan["content_focus"] = [item.strip() for item in re.split(r"[、,,;\n]", focus) if item.strip()] if "培训" not in plan["topic"]: plan["topic"] = f"{plan['topic']}安全培训" if "PPT大纲" not in plan["template"]: plan["template"] = f"{plan['template']}PPT大纲" return plan def _build_safety_training_generation_message(message: str, plan: dict) -> str: payload = _extract_safety_training_request_payload(message) focus_text = "、".join(plan.get("content_focus") or []) lines = [ "输出类型:安全培训PPT大纲", "请基于以下结构化需求生成安全培训PPT大纲,不要生成通知正文,不要切换到其他文档生成任务。", f"主题:{plan.get('topic') or '安全培训'}", f"模板:{plan.get('template') or '标准安全培训PPT大纲'}", f"内容要点:{focus_text or '安全生产责任、风险识别、应急处置、安全意识提升'}", f"参训对象:{plan.get('audience') or '参训员工'}", f"培训时间:{plan.get('time') or '未指定'}", f"培训地点:{plan.get('location') or '未指定'}", f"培训目标:{plan.get('goal') or '提升参训人员安全意识和风险防控能力'}", f"其他要求:{plan.get('notes') or '无'}", f"归一化需求:{plan.get('normalized_request') or ''}", f"原始需求:{payload['request'] or message}", ] if payload["filename"] or payload["document_content"]: lines.extend([ f"上传文档名称:{payload['filename'] or '未命名文档'}", f"上传文档大小:{payload['filesize'] or '未知'}", 
"上传文档内容:", payload["document_content"] or "无", ]) return "\n".join(lines) async def _infer_safety_training_plan(message: str) -> dict: payload = _extract_safety_training_request_payload(message) planning_input = payload["request"] or message if payload["document_content"]: planning_input = ( f"{planning_input}\n\n" f"上传文档名称:{payload['filename'] or '未命名文档'}\n" f"上传文档内容摘要:{payload['document_content'][:3000]}" ) try: response = await qwen_service.chat([ {"role": "system", "content": SAFETY_TRAINING_PLAN_SYSTEM_PROMPT}, {"role": "user", "content": planning_input}, ]) return _normalize_safety_training_plan(message, _parse_json_object(response)) except Exception as e: logger.warning( f"[safety_training] 需求整理失败,使用兜底结构: {type(e).__name__}: {e}") return _build_fallback_safety_training_plan(message) def _clean_ai_writing_response(content: str) -> str: text = str(content or "").strip() if not text: return "" text = re.sub(r"```(?:html)?\s*", "", text, flags=re.IGNORECASE).replace("```", "").strip() body_match = re.search( r"]*>(.*?)", text, re.IGNORECASE | re.DOTALL) if body_match: text = body_match.group(1).strip() first_content_tag = re.search( r"<(?:article|section|main|div|h[1-6]|p|table|ul|ol)\b", text, re.IGNORECASE, ) if first_content_tag and text[:first_content_tag.start()].strip(): text = text[first_content_tag.start():] cleanup_patterns = ( r"]*>", r"]*>", r"", r"]*>.*?", r"]*>", r"", r"]*>.*?", r"]*>.*?", r"]*>", r"]*>.*?", ) for pattern in cleanup_patterns: text = re.sub(pattern, "", text, flags=re.IGNORECASE | re.DOTALL) return text.strip() async def _generate_ai_writing_response(message: str) -> str: rag_context = await _rag_search(message, top_k=10) system_content = load_prompt( "document_writing", userMessage=message, contextJSON=rag_context if rag_context else "暂无相关知识库内容", ) messages = [ {"role": "system", "content": system_content}, { "role": "user", "content": ( "请根据上面的写作规范和我的原始需求,直接生成可放入富文本编辑器的公文正文 HTML 片段。" "不要输出道歉、解释、DOCTYPE、html、head、body、style 或 script 
def _persist_message_pair(db: Session, conv_id: int, user, user_content: str, ai_content: str):
    """Store a user message and its AI reply as one atomic unit.

    The user row is flushed (not committed) to obtain its primary key for the
    reply's ``prev_user_id``. Both rows are then committed together, so a
    failure while storing the reply cannot leave an orphaned user message.

    Args:
        db: Open database session.
        conv_id: Conversation both messages belong to.
        user: Authenticated user object; only ``user.user_id`` is read.
        user_content: Text of the user message.
        ai_content: Text of the AI reply.

    Returns:
        ``(user_message, ai_message)``, both refreshed from the database.

    Raises:
        Exception: Any database error, re-raised after the session has been
            rolled back.
    """
    now_ts = int(time.time())
    user_message = AIMessage(
        ai_conversation_id=conv_id,
        user_id=user.user_id,
        type="user",
        content=user_content,
        created_at=now_ts,
        updated_at=now_ts,
        is_deleted=0,
    )
    try:
        db.add(user_message)
        db.flush()  # assigns user_message.id without ending the transaction

        ai_message = AIMessage(
            ai_conversation_id=conv_id,
            user_id=user.user_id,
            type="ai",
            content=ai_content,
            prev_user_id=user_message.id,
            created_at=now_ts,
            updated_at=now_ts,
            is_deleted=0,
        )
        db.add(ai_message)
        db.commit()
    except Exception:
        db.rollback()
        raise

    db.refresh(user_message)
    db.refresh(ai_message)
    return user_message, ai_message
"""从数据库读取最近对话历史,构建 messages 列表""" db = SessionLocal() try: msgs = ( db.query(AIMessage) .filter(AIMessage.ai_conversation_id == conv_id, AIMessage.is_deleted == 0) .order_by(AIMessage.id.desc()) .limit(limit) .all() ) msgs.reverse() result = [] for m in msgs: role = "user" if m.type == "user" else "assistant" if m.content: result.append({"role": role, "content": m.content}) return result finally: db.close() # ───────────────────────────────────────────────────────────────────────── # 非流式接口 # ───────────────────────────────────────────────────────────────────────── class SendMessageRequest(BaseModel): message: str conversation_id: Optional[int] = None ai_conversation_id: Optional[int] = None business_type: int = 0 # 0=AI助手, 1=PPT大纲, 2=AI写作, 3=考试工坊 exam_name: str = "" ai_message_id: int = 0 @router.post("/send_deepseek_message") async def send_deepseek_message( request: Request, data: SendMessageRequest, db: Session = Depends(get_db), ): """ 发送消息(非流式) 支持多种业务类型: - 0: AI助手(意图识别 + RAG) - 1: PPT大纲生成 - 2: AI写作 - 3: 考试工坊 """ user = request.state.user if not user: return {"statusCode": 401, "msg": "未授权"} try: message = data.message.strip() if not message: return {"statusCode": 400, "msg": "消息不能为空"} conversation_id = data.conversation_id or data.ai_conversation_id # 创建或获取对话 if not conversation_id: conversation = AIConversation( user_id=user.user_id, content=message[:100], business_type=data.business_type, exam_name=data.exam_name if data.business_type == 3 else "", created_at=int(time.time()), updated_at=int(time.time()), is_deleted=0, ) db.add(conversation) db.commit() db.refresh(conversation) conv_id = conversation.id else: conv_id = conversation_id db.query(AIConversation).filter( AIConversation.id == conv_id, AIConversation.user_id == user.user_id, AIConversation.is_deleted == 0, ).update({ "content": message[:100], "business_type": data.business_type, "exam_name": data.exam_name if data.business_type == 3 else "", "updated_at": int(time.time()), }) db.commit() 
response_text = "" ai_message_id = 0 if data.business_type == 0: # AI助手:意图识别 + RAG try: intent_result = await qwen_service.intent_recognition(message) intent_type = "" if isinstance(intent_result, dict): intent_type = ( intent_result.get("intent_type") or intent_result.get( "intent") or "" ).lower() rag_context = "" if intent_type in ("query_knowledge_base", "知识库查询", "技术咨询"): rag_context = await _rag_search(message, top_k=10) # 使用prompt加载器加载最终回答prompt system_content = load_prompt( "final_answer", userMessage=message, contextJSON=rag_context if rag_context else "暂无相关知识库内容" ) messages = [ {"role": "user", "content": system_content}, ] qwen_response = await qwen_service.chat(messages) raw_thinking, raw_answer = split_thinking_and_answer( qwen_response) answer_source = raw_answer or qwen_response # 兼容模型直接返回 JSON 的场景 answer_text = answer_source try: if isinstance(answer_source, str) and answer_source.strip().startswith("{"): response_json = json.loads(answer_source) answer_text = response_json.get( "natural_language_answer", answer_source ) except Exception: answer_text = answer_source if raw_thinking: thinking_summary = await summarize_thinking_content( user_question=message, raw_thinking=raw_thinking, final_answer=answer_text, chat_service=qwen_service, context="send_message", ) response_text = ( f"思考过程:\n{thinking_summary}\n\n回答:\n{answer_text}" if thinking_summary else answer_text ) else: response_text = answer_text except Exception as e: error_detail = str(e).strip() if str( e).strip() else f"未知错误({type(e).__name__})" logger.error( f"[send_deepseek_message] AI助手异常: {type(e).__name__}: {error_detail}") response_text = f"处理失败: {error_detail}" elif data.business_type == 1: # PPT大纲生成 try: response_text = await _generate_ppt_outline_response(message) _, ai_message = _persist_message_pair( db=db, conv_id=conv_id, user=user, user_content=message, ai_content=response_text, ) ai_message_id = ai_message.id _refresh_conversation_snapshot(db, conv_id, user.user_id) db.commit() 
return { "statusCode": 200, "msg": "success", "data": { "conversation_id": conv_id, "ai_conversation_id": conv_id, "response": response_text, "reply": response_text, "content": response_text, "message": response_text, "ai_message_id": ai_message_id, "user_id": user.user_id, "business_type": data.business_type, }, } except Exception as e: error_detail = str(e).strip() if str( e).strip() else f"未知错误({type(e).__name__})" logger.error( f"[send_deepseek_message] PPT大纲生成异常: {type(e).__name__}: {error_detail}") response_text = f"处理失败: {error_detail}" elif data.business_type == 2: # AI写作 try: response_text = await _generate_ai_writing_response(message) _, ai_message = _persist_message_pair( db=db, conv_id=conv_id, user=user, user_content=message, ai_content=response_text, ) ai_message_id = ai_message.id _refresh_conversation_snapshot(db, conv_id, user.user_id) db.commit() return { "statusCode": 200, "msg": "success", "data": { "conversation_id": conv_id, "ai_conversation_id": conv_id, "response": response_text, "reply": response_text, "content": response_text, "message": response_text, "ai_message_id": ai_message_id, "user_id": user.user_id, "business_type": data.business_type, }, } except Exception as e: error_detail = str(e).strip() if str( e).strip() else f"未知错误({type(e).__name__})" logger.error( f"[send_deepseek_message] AI写作异常: {type(e).__name__}: {error_detail}") response_text = f"处理失败: {error_detail}" elif data.business_type == 3: # 考试工坊:生成题目 try: system_content = ( "你是一个专业的考试题目生成助手,专注于路桥隧轨施工安全领域。\n" "请根据用户需求生成专业的考试题目,包括单选题、多选题、判断题、简答题等。\n" "用户消息中已经包含考试标题、题型要求和出题依据内容,必须以其中的出题依据内容为核心生成题目,不能脱离依据内容自由发挥。\n" "题干、选项、答案和解析都要与出题依据内容中的知识点、专业术语、操作流程、规范要求或培训主题直接相关。\n" "输出必须是可直接 JSON.parse 的纯 JSON,不要包含 markdown 代码块、解释文字或额外前后缀。\n" "JSON 顶层结构必须包含 singleChoice、judge、multiple、short 四个字段。\n" "singleChoice.questions 和 multiple.questions 中每道题必须包含 text、options、answer、analysis。\n" "options 必须是数组,元素格式为 {\"key\":\"A\",\"text\":\"具体选项内容\"},禁止输出“选项A”这类占位文本。\n" "judge.questions 中每道题必须包含 
text、answer、analysis。\n" "short.questions 中每道题必须包含 text、outline,其中 outline 至少包含 keyFactors。\n" "所有题目内容、选项内容、答案和解析都要结合用户给出的工程类型、题型数量、分值和课件内容具体生成。" ) messages = [ {"role": "system", "content": system_content}, {"role": "user", "content": message}, ] raw_response_text = await qwen_service.chat(messages) response_text = _sanitize_exam_response(raw_response_text) now_ts = int(time.time()) user_message = AIMessage( ai_conversation_id=conv_id, user_id=user.user_id, type="user", content=message, created_at=now_ts, updated_at=now_ts, is_deleted=0, ) db.add(user_message) db.commit() db.refresh(user_message) ai_message = AIMessage( ai_conversation_id=conv_id, user_id=user.user_id, type="ai", content=response_text, prev_user_id=user_message.id, created_at=now_ts, updated_at=now_ts, is_deleted=0, ) db.add(ai_message) db.commit() _refresh_conversation_snapshot(db, conv_id, user.user_id) db.commit() generated_title = data.exam_name if not generated_title: try: import json exam_data = json.loads(response_text) generated_title = exam_data.get("title") or exam_data.get("exam_name") or exam_data.get( "examTitle") or exam_data.get("试卷标题") or exam_data.get("标题") or "" except Exception: pass if generated_title: db.query(AIConversation).filter(AIConversation.id == conv_id).update( {"exam_name": generated_title, "updated_at": int(time.time())} ) db.commit() except Exception as e: error_detail = str(e).strip() if str( e).strip() else f"未知错误({type(e).__name__})" logger.error( f"[send_deepseek_message] 考试工坊异常: {type(e).__name__}: {error_detail}") response_text = f"处理失败: {error_detail}" else: return {"statusCode": 400, "msg": f"不支持的业务类型: {data.business_type}"} return { "statusCode": 200, "msg": "success", "data": { "conversation_id": conv_id, "ai_conversation_id": conv_id, "response": response_text, "reply": response_text, "content": response_text, "message": response_text, "user_id": user.user_id, "business_type": data.business_type, }, } except Exception as e: 
logger.error(f"[send_deepseek_message] 异常: {e}") return {"statusCode": 500, "msg": f"处理失败: {str(e)}"} @router.get("/get_history_record") async def get_history_record( request: Request, ai_conversation_id: int = 0, business_type: Optional[int] = None, db: Session = Depends(get_db), ): """兼容前端的历史记录查询:ai_conversation_id=0 返回对话列表,否则返回消息详情。""" user = request.state.user if not user: return {"statusCode": 401, "msg": "未授权"} if ai_conversation_id > 0: messages = ( db.query(AIMessage) .filter( AIMessage.ai_conversation_id == ai_conversation_id, AIMessage.user_id == user.user_id, AIMessage.is_deleted == 0, ) .order_by(AIMessage.id.asc()) .all() ) return { "statusCode": 200, "msg": "success", "total": len(messages), "data": [ { "id": message.id, "ai_conversation_id": message.ai_conversation_id, "user_id": message.user_id, "type": message.type, "content": message.content, "user_feedback": message.user_feedback, "prev_user_id": message.prev_user_id, "search_source": message.search_source or "", "guess_you_want": message.guess_you_want or "", "created_at": _to_frontend_timestamp(message.created_at), "updated_at": _to_frontend_timestamp(message.updated_at), } for message in messages ], } conversations_query = db.query(AIConversation).filter( AIConversation.user_id == user.user_id, AIConversation.is_deleted == 0, ) if business_type is not None: conversations_query = conversations_query.filter( AIConversation.business_type == business_type ) total = conversations_query.count() conversations = ( conversations_query .order_by(AIConversation.updated_at.desc(), AIConversation.id.desc()) .limit(50) .all() ) return { "statusCode": 200, "msg": "success", "total": total, "data": [ { "id": conv.id, "title": _build_conversation_title(conv), "content": conv.content or "", "business_type": conv.business_type, "exam_name": conv.exam_name or "", "created_at": _to_frontend_timestamp(conv.created_at), "updated_at": _to_frontend_timestamp(conv.updated_at), } for conv in conversations ], } class 
@router.post("/delete_conversation")
async def delete_conversation(
    request: Request,
    data: DeleteConversationRequest,
    db: Session = Depends(get_db)
):
    """Soft-delete one AI message (with its user message) or a whole conversation.

    When ``ai_message_id`` is given, that AI message and the user message it
    answers are marked deleted and the conversation snapshot is refreshed.
    Otherwise ``ai_conversation_id`` is required, and the conversation plus
    all of the caller's messages in it are marked deleted.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}

    owner_id = user.user_id
    soft_delete = {"is_deleted": 1, "updated_at": int(time.time())}
    deleted_ok = {"statusCode": 200, "msg": "删除成功"}

    if not data.ai_message_id:
        # Conversation-level deletion.
        if not data.ai_conversation_id:
            return {"statusCode": 400, "msg": "缺少删除参数"}

        db.query(AIConversation).filter(
            AIConversation.id == data.ai_conversation_id,
            AIConversation.user_id == owner_id,
        ).update(soft_delete)
        db.query(AIMessage).filter(
            AIMessage.ai_conversation_id == data.ai_conversation_id,
            AIMessage.user_id == owner_id,
        ).update(soft_delete)
        db.commit()
        return deleted_ok

    # Message-level deletion: only a live AI message owned by the caller.
    target = (
        db.query(AIMessage)
        .filter(
            AIMessage.id == data.ai_message_id,
            AIMessage.user_id == owner_id,
            AIMessage.type == "ai",
            AIMessage.is_deleted == 0,
        )
        .first()
    )
    if not target:
        return {"statusCode": 404, "msg": "消息不存在"}

    db.query(AIMessage).filter(
        AIMessage.id == target.id,
        AIMessage.user_id == owner_id,
    ).update(soft_delete)

    if target.prev_user_id:
        # Also remove the user message this reply answered.
        db.query(AIMessage).filter(
            AIMessage.id == target.prev_user_id,
            AIMessage.user_id == owner_id,
            AIMessage.ai_conversation_id == target.ai_conversation_id,
        ).update(soft_delete)

    _refresh_conversation_snapshot(db, target.ai_conversation_id, owner_id)
    db.commit()
    return deleted_ok
@router.post("/stream/chat")
async def stream_chat(request: Request, data: StreamChatRequest):
    """Stream a chat answer as Server-Sent Events without touching the DB.

    The generator runs intent recognition, performs a RAG search for
    knowledge-base intents, builds the ``final_answer`` prompt and relays the
    model output. A ``<think>...</think>`` section in the model output is not
    forwarded verbatim; it is collected (up to
    ``settings.thinking_summary.max_input_chars`` characters) and replaced by
    the result of ``summarize_thinking_content``.

    Every event is ``data: <json>`` where the JSON object carries either a
    ``content`` or an ``error`` key; the stream ends with ``data: [DONE]``.

    Returns:
        A ``StreamingResponse``, or a ``JSONResponse`` with ``statusCode`` 400
        in its body when the message is blank.
    """
    message = data.message.strip()
    if not message:
        return JSONResponse(content={"statusCode": 400, "msg": "消息不能为空"})

    # Delimiters of the model's reasoning section. They were empty strings
    # before, which made str.find() always return 0 and broke the parser.
    # NOTE(review): a delimiter split across two chunks is still not detected.
    think_open = "<think>"
    think_close = "</think>"

    def sse(payload: dict) -> str:
        """Format one SSE data frame with a JSON body."""
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    async def event_generator():
        intent_type = ""
        try:
            intent_result = await qwen_service.intent_recognition(message)
            if isinstance(intent_result, dict):
                intent_type = (
                    intent_result.get("intent_type")
                    or intent_result.get("intent")
                    or ""
                ).lower()
        except Exception as ie:
            # Intent recognition is best-effort; fall through without RAG.
            logger.warning(f"[stream/chat] 意图识别异常: {ie}")

        rag_context = ""
        if intent_type in ("query_knowledge_base", "知识库查询", "技术咨询"):
            rag_context = await _rag_search(message)

        # Load the final-answer prompt through the prompt loader.
        system_content = load_prompt(
            "final_answer",
            userMessage=message,
            contextJSON=rag_context if rag_context else "暂无相关知识库内容",
        )
        messages = [
            {"role": "user", "content": system_content},
        ]

        try:
            buffer = ""
            pre_answer = ""      # text seen before the opening tag
            thinking_buf = ""    # capped copy of the reasoning text
            in_think = False
            thinking_done = False
            max_input_chars = getattr(
                settings.thinking_summary, "max_input_chars", 1500)

            async for chunk in qwen_service.stream_chat(messages):
                buffer += chunk
                while buffer:
                    lower = buffer.lower()
                    if thinking_done:
                        # Reasoning already handled: plain pass-through.
                        yield sse({"content": buffer})
                        buffer = ""
                        continue

                    if not in_think:
                        start_idx = lower.find(think_open)
                        if start_idx == -1:
                            yield sse({"content": buffer})
                            buffer = ""
                            break
                        pre_answer += buffer[:start_idx]
                        buffer = buffer[start_idx + len(think_open):]
                        in_think = True
                        continue

                    end_idx = lower.find(think_close)
                    if end_idx == -1:
                        # Still inside the reasoning section: keep a capped copy.
                        if max_input_chars and len(thinking_buf) < max_input_chars:
                            thinking_buf += buffer[: max_input_chars - len(thinking_buf)]
                        buffer = ""
                        break

                    if max_input_chars and len(thinking_buf) < max_input_chars:
                        thinking_part = buffer[:end_idx]
                        thinking_buf += thinking_part[: max_input_chars - len(thinking_buf)]
                    buffer = buffer[end_idx + len(think_close):]
                    in_think = False
                    thinking_done = True

                    thinking_summary = await summarize_thinking_content(
                        user_question=message,
                        raw_thinking=thinking_buf,
                        final_answer="",
                        chat_service=qwen_service,
                        context="stream_chat",
                    )
                    if thinking_summary:
                        prefix = f"思考过程:\n{thinking_summary}\n\n回答:\n"
                        yield sse({"content": prefix})

                    answer_chunk = (pre_answer + buffer).lstrip()
                    if answer_chunk:
                        yield sse({"content": answer_chunk})
                    pre_answer = ""
                    buffer = ""
                    break

            # Stream ended without a closing </think>: only try to produce the
            # summary; the raw reasoning text is never sent to the client.
            if in_think and not thinking_done and thinking_buf:
                thinking_summary = await summarize_thinking_content(
                    user_question=message,
                    raw_thinking=thinking_buf,
                    final_answer="",
                    chat_service=qwen_service,
                    context="stream_chat_eof",
                )
                if thinking_summary:
                    prefix = f"思考过程:\n{thinking_summary}\n\n回答:\n"
                    yield sse({"content": prefix})
            if pre_answer:
                yield sse({"content": pre_answer})
        except Exception as e:
            logger.error(f"[stream/chat] 流式输出异常: {e}")
            yield sse({"error": str(e)})

        # Emitted after the try/except rather than from a ``finally``: yielding
        # while the generator is being closed (client disconnect) raises
        # RuntimeError. Normal completion and handled errors still reach this.
        yield "data: [DONE]\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")
@router.post("/stream/chat-with-db")
async def stream_chat_with_db(request: Request, data: StreamChatWithDBRequest):
    """Stream a chat answer over SSE while persisting the exchange.

    Steps performed by the generator:
      1. Reuse the caller's conversation or create a new one.
      2. Insert the user message and an empty AI placeholder message.
      3. Emit an ``initial`` JSON event carrying both ids.
      4. Run the RAG search.
      5. Build the prompt: a dedicated prompt for business types 1 and 2,
         otherwise ``final_answer`` with the previous messages as history.
      6. Stream the model output; a ``<think>...</think>`` section is
         replaced by the result of ``summarize_thinking_content`` when
         ``settings.thinking_summary.enabled`` is true.
      7. Store the full reply on the placeholder and touch the conversation.

    Text events are sent as ``data: <text>`` with line feeds backslash-escaped;
    the ``initial`` and ``error`` events carry JSON. The stream normally ends
    with ``data: [DONE]``.

    Returns:
        A ``StreamingResponse``, or a ``JSONResponse`` whose body carries
        ``statusCode`` 401 (no user) or 400 (blank message).
    """
    user = request.state.user
    if not user:
        return JSONResponse(content={"statusCode": 401, "msg": "未授权"})

    message = data.message.strip()
    if not message:
        return JSONResponse(content={"statusCode": 400, "msg": "消息不能为空"})

    # Delimiters of the model's reasoning section. They were empty strings
    # before, which made str.find() always return 0 and broke the parser.
    # NOTE(review): a delimiter split across two chunks is still not detected.
    think_open = "<think>"
    think_close = "</think>"

    # Same fallback text the /stream/chat endpoint passes for this parameter;
    # this handler previously sent a run of "?" placeholders instead.
    no_context_text = "暂无相关知识库内容"

    # exam_name is only kept for business_type 3; applied uniformly on create
    # and update.
    exam_name = data.exam_name if data.business_type == 3 else ""

    def sse_text(text: str) -> str:
        """Format a plain-text SSE frame, escaping line feeds."""
        escaped = text.replace("\n", "\\n")
        return f"data: {escaped}\n\n"

    def sse_json(payload: dict) -> str:
        """Format an SSE frame with a JSON body."""
        return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"

    async def event_generator():
        db = SessionLocal()
        try:
            # 1. Reuse the caller's conversation or create a new one.
            existing_conversation = None
            if data.ai_conversation_id != 0:
                existing_conversation = (
                    db.query(AIConversation)
                    .filter(
                        AIConversation.id == data.ai_conversation_id,
                        AIConversation.user_id == user.user_id,
                        AIConversation.is_deleted == 0,
                    )
                    .first()
                )

            now_ts = int(time.time())
            if existing_conversation is None:
                conversation = AIConversation(
                    user_id=user.user_id,
                    content=_build_conversation_preview(message, limit=100),
                    business_type=data.business_type,
                    exam_name=exam_name,
                    created_at=now_ts,
                    updated_at=now_ts,
                    is_deleted=0,
                )
                db.add(conversation)
                db.commit()
                db.refresh(conversation)
                conv_id = conversation.id
            else:
                conv_id = existing_conversation.id
                db.query(AIConversation).filter(
                    AIConversation.id == conv_id,
                    AIConversation.user_id == user.user_id,
                ).update(
                    {
                        "content": _build_conversation_preview(message, limit=100),
                        "business_type": data.business_type,
                        "exam_name": exam_name,
                        "updated_at": now_ts,
                    }
                )
                db.commit()

            # 2. Insert the user message.
            now_ts = int(time.time())
            user_msg = AIMessage(
                ai_conversation_id=conv_id,
                user_id=user.user_id,
                type="user",
                content=message,
                created_at=now_ts,
                updated_at=now_ts,
                is_deleted=0,
            )
            db.add(user_msg)
            db.commit()
            db.refresh(user_msg)

            # 3. Insert the AI placeholder that is filled in after streaming.
            now_ts = int(time.time())
            ai_msg = AIMessage(
                ai_conversation_id=conv_id,
                user_id=user.user_id,
                type="ai",
                content="",
                prev_user_id=user_msg.id,
                created_at=now_ts,
                updated_at=now_ts,
                is_deleted=0,
            )
            db.add(ai_msg)
            db.commit()
            db.refresh(ai_msg)

            # 4. Tell the client which rows this stream belongs to.
            yield sse_json(
                {
                    "type": "initial",
                    "ai_conversation_id": conv_id,
                    "ai_message_id": ai_msg.id,
                }
            )

            # 5. RAG search.
            rag_context = await _rag_search(message, top_k=10)

            if data.business_type in (1, 2):
                # PPT outline / AI writing use their dedicated prompts.
                prompt_name = (
                    "ppt_outline" if data.business_type == 1 else "document_writing"
                )
                system_content = load_prompt(
                    prompt_name,
                    userMessage=message,
                    contextJSON=rag_context if rag_context else no_context_text,
                )
            else:
                # 6. History context: the last 4 earlier messages (2 turns).
                # Ordered by id because other endpoints bump updated_at on old
                # messages; bounded by user_msg.id so the question being
                # answered is not repeated inside its own history.
                history_msgs = (
                    db.query(AIMessage)
                    .filter(
                        AIMessage.ai_conversation_id == conv_id,
                        AIMessage.id < user_msg.id,
                        AIMessage.is_deleted == 0,
                    )
                    .order_by(AIMessage.id.desc())
                    .limit(4)
                    .all()
                )
                history_msgs.reverse()

                # NOTE(review): the role and section labels below replace
                # literals that had been reduced to "?" characters (both roles
                # were the same string). Confirm the wording against the
                # prompt template.
                history_context = "".join(
                    f"{'用户' if msg.type == 'user' else '助手'}: {msg.content}\n\n"
                    for msg in history_msgs
                )

                # 7. Build the final prompt.
                context_parts = []
                if rag_context:
                    context_parts.append(f"知识库内容:\n{rag_context}")
                if data.online_search_content:
                    context_parts.append(
                        f"联网搜索结果:\n{data.online_search_content}")
                context_json = (
                    "\n\n".join(context_parts) if context_parts else no_context_text
                )

                system_content = load_prompt(
                    "final_answer",
                    userMessage=message,
                    contextJSON=context_json,
                    historyContext=history_context if history_context else "",
                )

            messages = [
                {"role": "user", "content": system_content},
            ]

            # 8. Stream the output and collect the complete reply.
            # Settings and parser state are bound before the try so the code
            # after the except block can never hit an unbound name.
            summary_enabled = getattr(
                settings.thinking_summary, "enabled", True)
            max_input_chars = getattr(
                settings.thinking_summary, "max_input_chars", 1500)
            full_response = ""
            buffer = ""
            pre_answer = ""      # text seen before the opening tag
            thinking_buf = ""    # capped copy of the reasoning text
            in_think = False
            thinking_done = False

            try:
                async for chunk in qwen_service.stream_chat(messages):
                    if not summary_enabled:
                        full_response += chunk
                        yield sse_text(chunk)
                        continue

                    buffer += chunk
                    while buffer:
                        lower = buffer.lower()
                        if thinking_done:
                            # Reasoning already handled: plain pass-through.
                            full_response += buffer
                            yield sse_text(buffer)
                            buffer = ""
                            continue

                        if not in_think:
                            start_idx = lower.find(think_open)
                            if start_idx == -1:
                                full_response += buffer
                                yield sse_text(buffer)
                                buffer = ""
                                break
                            pre_answer += buffer[:start_idx]
                            buffer = buffer[start_idx + len(think_open):]
                            in_think = True
                            continue

                        end_idx = lower.find(think_close)
                        if end_idx == -1:
                            # Still inside the reasoning: keep a capped copy.
                            if max_input_chars and len(thinking_buf) < max_input_chars:
                                thinking_buf += buffer[: max_input_chars - len(thinking_buf)]
                            buffer = ""
                            break

                        if max_input_chars and len(thinking_buf) < max_input_chars:
                            thinking_part = buffer[:end_idx]
                            thinking_buf += thinking_part[: max_input_chars - len(thinking_buf)]
                        buffer = buffer[end_idx + len(think_close):]
                        in_think = False
                        thinking_done = True

                        thinking_summary = await summarize_thinking_content(
                            user_question=message,
                            raw_thinking=thinking_buf,
                            final_answer="",
                            chat_service=qwen_service,
                            context="stream_chat_with_db",
                        )
                        if thinking_summary:
                            prefix = f"思考过程:\n{thinking_summary}\n\n回答:\n"
                            full_response += prefix
                            yield sse_text(prefix)

                        answer_chunk = (pre_answer + buffer).lstrip()
                        if answer_chunk:
                            full_response += answer_chunk
                            yield sse_text(answer_chunk)
                        pre_answer = ""
                        buffer = ""
                        break
            except Exception as e:
                # Report the failure but still persist what was produced.
                logger.error(f"[stream/chat-with-db] 流式输出异常: {e}")
                yield sse_json({"error": str(e)})

            # Stream ended without a closing </think>: only try to produce the
            # summary; the raw reasoning text is never sent to the client.
            if summary_enabled and in_think and not thinking_done and thinking_buf:
                thinking_summary = await summarize_thinking_content(
                    user_question=message,
                    raw_thinking=thinking_buf,
                    final_answer="",
                    chat_service=qwen_service,
                    context="stream_chat_with_db_eof",
                )
                if thinking_summary:
                    prefix = f"思考过程:\n{thinking_summary}\n\n回答:\n"
                    full_response += prefix
                    yield sse_text(prefix)
            if pre_answer:
                full_response += pre_answer
                yield sse_text(pre_answer)

            # 9. Persist the reply and refresh the conversation snapshot.
            if full_response:
                now_ts = int(time.time())
                db.query(AIMessage).filter(AIMessage.id == ai_msg.id).update(
                    {"content": full_response, "updated_at": now_ts}
                )
                db.query(AIConversation).filter(
                    AIConversation.id == conv_id,
                    AIConversation.user_id == user.user_id,
                ).update(
                    {
                        "content": _build_conversation_preview(message, limit=100),
                        "business_type": data.business_type,
                        "exam_name": exam_name,
                        "updated_at": now_ts,
                    }
                )
                db.commit()

            # 10. End-of-stream marker.
            yield "data: [DONE]\n\n"

        except Exception as e:
            logger.error(f"[stream/chat-with-db] 处理异常: {e}")
            yield sse_json({"error": str(e)})
        finally:
            # The session is created by the generator, so it is closed here.
            db.close()

    return StreamingResponse(event_generator(), media_type="text/event-stream")
@router.get("/online_search")
async def online_search(question: str, request: Request, db: Session = Depends(get_db)):
    """Run an online search through the configured Dify workflow.

    Flow: check the Dify configuration, let Qwen extract keywords from
    ``question``, call the workflow in blocking mode, and return the text
    found under ``data.outputs.text`` of the workflow response.

    Args:
        question: The user's question.
        request: Incoming request; ``request.state.user`` must be set.
        db: Injected session; not used by this handler.

    Returns:
        ``{"statusCode": 200, "data": {"keywords", "result"}}`` on success,
        401 without a user, and 500 when the configuration is missing, the
        HTTP call fails, the workflow reports a failure, or an exception is
        raised.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}

    try:
        # Validate configuration first so a misconfigured deployment does not
        # spend a model call on keyword extraction.
        dify_config = getattr(settings, "dify", None)
        if not dify_config or not getattr(dify_config, "workflow_url", None):
            return {"statusCode": 500, "msg": "Dify 配置未设置"}

        keywords = await qwen_service.extract_keywords(question)

        headers = {
            "Authorization": f"Bearer {dify_config.auth_token}",
            "Content-Type": "application/json",
        }
        payload = {
            "workflow_id": dify_config.workflow_id,
            "inputs": {
                "keywords": keywords,
                "num": 5,              # number of search results
                "max_text_len": 4000,  # maximum text length
            },
            "response_mode": "blocking",
            # Fall back to the user id when the account is missing or empty.
            "user": getattr(user, "account", None) or str(user.user_id),
        }

        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(
                dify_config.workflow_url, headers=headers, json=payload)

        if resp.status_code != 200:
            logger.error(
                f"[online_search] Dify 调用失败: {resp.status_code}, 响应: {resp.text}")
            return {"statusCode": 500, "msg": f"搜索服务异常: {resp.status_code}"}

        result = resp.json()
        # "data" and "outputs" may be null (e.g. a failed run), so neither is
        # assumed to be a dict.
        run_data = result.get("data") if isinstance(result, dict) else None
        if not isinstance(run_data, dict):
            run_data = {}

        if run_data.get("status") == "failed":
            failure = run_data.get("error") or "workflow failed"
            logger.error(f"[online_search] Dify 调用失败: {failure}")
            return {"statusCode": 500, "msg": f"搜索服务异常: {failure}"}

        outputs = run_data.get("outputs")
        if not isinstance(outputs, dict):
            outputs = {}
        search_text = outputs.get("text") or ""

        return {
            "statusCode": 200,
            "msg": "success",
            "data": {"keywords": keywords, "result": search_text},
        }

    except Exception as e:
        logger.error(f"[online_search] 处理异常: {e}")
        return {"statusCode": 500, "msg": f"搜索失败: {str(e)}"}
@router.post("/intent_recognition")
async def intent_recognition(
    request: Request,
    data: IntentRecognitionRequest,
    db: Session = Depends(get_db),
):
    """Standalone intent-recognition endpoint.

    With ``scene == "module_dispatch"`` the message is routed through
    ``qwen_service.module_dispatch_recognition`` and nothing is stored.
    Otherwise ``qwen_service.intent_recognition`` is called; when
    ``save_to_db`` is true and the intent is a greeting or FAQ, the user
    message and the canned response are stored as a message pair.

    The pair is attached to ``data.ai_conversation_id`` only if that
    conversation belongs to the caller and is not deleted; otherwise a new
    conversation is created, matching the /stream/chat-with-db endpoint.

    Returns:
        A ``statusCode``/``msg``/``data`` dict; 401 without a user and 500
        when an exception is raised.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}

    try:
        if (data.scene or "").strip().lower() == "module_dispatch":
            dispatch_result = await qwen_service.module_dispatch_recognition(data.message)
            return {
                "statusCode": 200,
                "msg": "success",
                "data": {
                    "route_mode": dispatch_result.get("route_mode", "ai-qa"),
                    "business_type": dispatch_result.get("business_type", 0),
                    "confidence": dispatch_result.get("confidence", 0.5),
                    "reason": dispatch_result.get("reason", ""),
                    "saved_to_db": False,
                },
            }

        intent_result = await qwen_service.intent_recognition(data.message)

        intent_type = ""
        response_text = ""
        if isinstance(intent_result, dict):
            intent_type = (
                intent_result.get("intent_type")
                or intent_result.get("intent")
                or ""
            ).lower()
            # A present-but-null "response" must not be stored as NULL content.
            response_text = intent_result.get("response") or ""

        if data.save_to_db and intent_type in ("greeting", "问候", "faq", "常见问题"):
            # Never trust the client-supplied id: only reuse a conversation
            # that exists, is not deleted and belongs to the caller.
            owned_conversation = None
            if data.ai_conversation_id:
                owned_conversation = (
                    db.query(AIConversation)
                    .filter(
                        AIConversation.id == data.ai_conversation_id,
                        AIConversation.user_id == user.user_id,
                        AIConversation.is_deleted == 0,
                    )
                    .first()
                )

            if owned_conversation is not None:
                conv_id = owned_conversation.id
            else:
                now_ts = int(time.time())
                conversation = AIConversation(
                    user_id=user.user_id,
                    content=data.message[:100],
                    business_type=0,
                    created_at=now_ts,
                    updated_at=now_ts,
                    is_deleted=0,
                )
                db.add(conversation)
                db.commit()
                db.refresh(conversation)
                conv_id = conversation.id

            now_ts = int(time.time())
            user_msg = AIMessage(
                ai_conversation_id=conv_id,
                user_id=user.user_id,
                type="user",
                content=data.message,
                created_at=now_ts,
                updated_at=now_ts,
                is_deleted=0,
            )
            db.add(user_msg)
            db.commit()

            now_ts = int(time.time())
            ai_msg = AIMessage(
                ai_conversation_id=conv_id,
                user_id=user.user_id,
                type="ai",
                content=response_text,
                prev_user_id=user_msg.id,
                created_at=now_ts,
                updated_at=now_ts,
                is_deleted=0,
            )
            db.add(ai_msg)
            db.commit()
            db.refresh(ai_msg)

            return {
                "statusCode": 200,
                "msg": "success",
                "data": {
                    "intent_type": intent_type,
                    "response": response_text,
                    "ai_conversation_id": conv_id,
                    "ai_message_id": ai_msg.id,
                    "saved_to_db": True,
                },
            }

        return {
            "statusCode": 200,
            "msg": "success",
            "data": {
                "intent_type": intent_type,
                "response": response_text,
                "saved_to_db": False,
            },
        }

    except Exception as e:
        logger.error(f"[intent_recognition] 处理异常: {e}")
        return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}
class SavePPTOutlineRequest(BaseModel):
    """Request body for ``/save_ppt_outline``."""

    # Id of the AIMessage whose content is replaced.
    ai_message_id: int
    # New content stored on the message.
    content: str


@router.post("/save_ppt_outline")
async def save_ppt_outline(
    request: Request,
    data: SavePPTOutlineRequest,
    db: Session = Depends(get_db),
):
    """Overwrite ``AIMessage.content`` with an edited PPT outline.

    Only a non-deleted message owned by the caller can be updated; previously
    the row was selected by id alone, letting any authenticated user overwrite
    another user's message, and a non-matching id still reported success.

    Returns:
        A ``statusCode``/``msg`` dict: 401 without a user, 404 when no
        matching message exists, 500 on an exception, 200 otherwise.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}

    try:
        target = (
            db.query(AIMessage)
            .filter(
                AIMessage.id == data.ai_message_id,
                AIMessage.user_id == user.user_id,
                AIMessage.is_deleted == 0,
            )
            .first()
        )
        if not target:
            return {"statusCode": 404, "msg": "消息不存在"}

        db.query(AIMessage).filter(
            AIMessage.id == target.id,
            AIMessage.user_id == user.user_id,
        ).update({"content": data.content, "updated_at": int(time.time())})
        db.commit()
        return {"statusCode": 200, "msg": "保存成功"}
    except Exception as e:
        logger.error(f"[save_ppt_outline] 处理异常: {e}")
        return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}


class SaveEditDocumentRequest(BaseModel):
    """Request body for ``/save_edit_document``."""

    # Id of the AI-type AIMessage whose content is replaced.
    ai_message_id: int
    # New content stored on the message.
    content: str


@router.post("/save_edit_document")
async def save_edit_document(
    request: Request,
    data: SaveEditDocumentRequest,
    db: Session = Depends(get_db),
):
    """Overwrite the content of an ``"ai"``-type message (AI writing editor).

    Only a non-deleted ``"ai"`` message owned by the caller can be updated;
    previously ownership was not checked and a non-matching id still reported
    success.

    Returns:
        A ``statusCode``/``msg`` dict: 401 without a user, 404 when no
        matching message exists, 500 on an exception, 200 otherwise.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}

    try:
        target = (
            db.query(AIMessage)
            .filter(
                AIMessage.id == data.ai_message_id,
                AIMessage.user_id == user.user_id,
                AIMessage.type == "ai",
                AIMessage.is_deleted == 0,
            )
            .first()
        )
        if not target:
            return {"statusCode": 404, "msg": "消息不存在"}

        db.query(AIMessage).filter(
            AIMessage.id == target.id,
            AIMessage.user_id == user.user_id,
            AIMessage.type == "ai",
        ).update({"content": data.content, "updated_at": int(time.time())})
        db.commit()
        return {"statusCode": 200, "msg": "保存成功"}
    except Exception as e:
        logger.error(f"[save_edit_document] 处理异常: {e}")
        return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}