import json
import time
from typing import Optional

import httpx
from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse, JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session

# Project-local imports keep their original relative order.
from database import get_db, SessionLocal
from models.chat import AIConversation, AIMessage
from models.total import RecommendQuestion
from utils.config import settings
from utils.logger import logger
from services.qwen_service import qwen_service
from utils.prompt_loader import load_prompt

router = APIRouter()


def _build_conversation_preview(content: str, limit: int = 50) -> str:
    """Return ``content`` stripped of surrounding whitespace and cut to ``limit`` chars.

    ``None`` is treated as an empty string. When the text is longer than
    ``limit`` it is truncated and ``"..."`` is appended, so the result can be
    up to ``limit + 3`` characters long.
    """
    content = (content or "").strip()
    if len(content) <= limit:
        return content
    return content[:limit] + "..."


def _to_frontend_timestamp(timestamp: Optional[int]) -> Optional[int]:
    """Normalize a Unix timestamp to milliseconds for the frontend.

    Falsy values (``None``/``0``) become ``None``. Values below ``10**12`` are
    treated as seconds and multiplied by 1000; larger values are assumed to
    already be milliseconds and are returned unchanged.
    """
    if not timestamp:
        return None
    return timestamp if timestamp >= 10**12 else timestamp * 1000


def _build_conversation_title(conversation: AIConversation) -> str:
    """Return a display title for a conversation.

    Exam-workshop conversations (``business_type == 3``) with a non-blank exam
    name use that name; everything else uses a 30-character preview of the
    stored conversation content.
    """
    if conversation.business_type == 3 and (conversation.exam_name or "").strip():
        return conversation.exam_name.strip()
    return _build_conversation_preview(conversation.content or "", limit=30)


def _refresh_conversation_snapshot(db: Session, conversation_id: int, user_id: int) -> None:
    """Recompute a conversation's preview after some of its messages were deleted.

    If no live (``is_deleted == 0``) message of this user remains, the
    conversation itself is soft-deleted. Otherwise its ``content`` is set to a
    100-character preview of the latest live user message (falling back to the
    latest live message of any type) and ``updated_at`` is bumped.

    The caller is responsible for committing the session.
    """
    latest_message = (
        db.query(AIMessage)
        .filter(
            AIMessage.ai_conversation_id == conversation_id,
            AIMessage.user_id == user_id,
            AIMessage.is_deleted == 0,
        )
        .order_by(AIMessage.id.desc())
        .first()
    )
    if not latest_message:
        db.query(AIConversation).filter(
            AIConversation.id == conversation_id,
            AIConversation.user_id == user_id,
        ).update({"is_deleted": 1, "updated_at": int(time.time())})
        return

    latest_user_message = (
        db.query(AIMessage)
        .filter(
            AIMessage.ai_conversation_id == conversation_id,
            AIMessage.user_id == user_id,
            AIMessage.type == "user",
            AIMessage.is_deleted == 0,
        )
        .order_by(AIMessage.id.desc())
        .first()
    )
    preview_source = (
        latest_user_message.content
        if latest_user_message and latest_user_message.content
        else latest_message.content
    )
    preview_content = _build_conversation_preview(preview_source or "", limit=100)
    db.query(AIConversation).filter(
        AIConversation.id == conversation_id,
        AIConversation.user_id == user_id,
    ).update(
        {
            # NOTE(review): a single space is stored instead of "" — presumably
            # the column must not be empty; confirm against the schema.
            "content": preview_content or " ",
            "updated_at": int(time.time()),
        }
    )


# ─────────────────────────────────────────────────────────────────────────
# Helper functions
# ─────────────────────────────────────────────────────────────────────────

def _extract_doc_text(doc) -> str:
    """Return the text of one search hit.

    Accepts either a plain string or a dict carrying ``content`` or ``text``.
    Anything else yields ``""``.
    """
    if isinstance(doc, str):
        return doc
    if isinstance(doc, dict):
        return doc.get("content") or doc.get("text") or ""
    return ""


async def _rag_search(message: str, top_k: int = 5) -> str:
    """Query the configured search API and return retrieval context as text.

    POSTs ``{"query": message, "n_results": top_k}`` to
    ``settings.search.api_url``. The hits are read from the response's
    ``results`` or ``documents`` list (or from the body itself if it is a
    list). The non-empty texts of the first ``top_k`` hits are joined with
    blank lines.

    Retrieval is best-effort. Any failure (missing config, transport error,
    non-200 status, unexpected payload) is logged as a warning and ``""`` is
    returned, so callers can always proceed without context.
    """
    try:
        search_cfg = getattr(settings, "search", None)
        search_url = getattr(search_cfg, "api_url", None) if search_cfg else None
        if not search_url:
            return ""

        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(
                search_url,
                json={"query": message, "n_results": top_k},
            )
        if resp.status_code != 200:
            logger.warning(f"[RAG] 检索服务返回异常状态码: {resp.status_code}")
            return ""

        payload = resp.json()
        if isinstance(payload, list):
            docs = payload
        elif isinstance(payload, dict):
            docs = payload.get("results") or payload.get("documents") or []
        else:
            docs = []

        texts = (_extract_doc_text(doc) for doc in docs[:top_k])
        return "\n\n".join(text for text in texts if text)
    except Exception as e:
        logger.warning(f"[RAG] 检索失败(可忽略): {e}")
    return ""


def _build_history_messages(conv_id: int, limit: int = 10) -> list:
    """Load a conversation's recent history as chat-completion messages.

    Reads up to ``limit`` most recent non-deleted messages of ``conv_id``
    (not filtered by user), oldest first. It maps ``type == "user"`` to role
    ``"user"`` and everything else to ``"assistant"``, and skips messages
    with empty content. Uses and closes its own session.
    """
    db = SessionLocal()
    try:
        msgs = (
            db.query(AIMessage)
            .filter(AIMessage.ai_conversation_id == conv_id, AIMessage.is_deleted == 0)
            .order_by(AIMessage.id.desc())
            .limit(limit)
            .all()
        )
        msgs.reverse()
        result = []
        for m in msgs:
            role = "user" if m.type == "user" else "assistant"
            if m.content:
                result.append({"role": role, "content": m.content})
        return result
    finally:
        db.close()


# ─────────────────────────────────────────────────────────────────────────
# Non-streaming endpoint
# ─────────────────────────────────────────────────────────────────────────

class SendMessageRequest(BaseModel):
    """Request body for ``/send_deepseek_message``."""

    message: str
    conversation_id: Optional[int] = None
    # 0=AI Q&A, 1=PPT outline, 2=AI writing, 3=exam workshop
    business_type: int = 0
    exam_name: str = ""
    ai_message_id: int = 0
# ─────────────────────────────────────────────────────────────────────────
# Shared constants and helpers for the endpoints below
# ─────────────────────────────────────────────────────────────────────────

# Intent labels (the model may answer in English or Chinese) that trigger a RAG lookup.
_KNOWLEDGE_INTENTS = ("query_knowledge_base", "知识库查询", "技术咨询")
# Intent labels whose canned answer /intent_recognition may persist directly.
_DIRECT_ANSWER_INTENTS = ("greeting", "问候", "faq", "常见问题")
# Context placeholder handed to prompts when no retrieval context is available.
_NO_CONTEXT = "暂无相关知识库内容"
# Business types handled by /send_deepseek_message:
# 0=AI Q&A, 1=PPT outline, 2=AI writing, 3=exam workshop.
_SUPPORTED_BUSINESS_TYPES = (0, 1, 2, 3)
# business_type -> (prompt template name, label used in error logs) for RAG-backed generation.
_RAG_PROMPT_TASKS = {
    1: ("ppt_outline", "PPT大纲生成"),
    2: ("document_writing", "AI写作"),
}
# System prompt for exam-question generation (business_type == 3).
_EXAM_SYSTEM_PROMPT = (
    "你是一个专业的考试题目生成助手,专注于路桥隧轨施工安全领域。\n"
    "请根据用户需求生成专业的考试题目,包括单选题、多选题、判断题等。\n"
    "每道题目应包含:题目内容、选项(如适用)、正确答案、解析。\n"
    "输出格式应为结构化的 JSON。"
)
# "Guess you want" questions used when the model output yields none.
_DEFAULT_GUESS_QUESTIONS = ("该话题的具体应用场景?", "有哪些注意事项?", "相关案例分析?")
# Padding entry used when fewer than three questions are available.
_GUESS_PADDING = "更多相关问题"
# Hard cap on the page size accepted by /get_user_recommend_question.
_MAX_RECOMMEND_LIMIT = 100
# Terminal SSE frame sent at the end of a stream.
_SSE_DONE = "data: [DONE]\n\n"


def _sse(payload: dict) -> str:
    """Serialize ``payload`` as a single JSON ``data:`` SSE frame."""
    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"


def _parse_intent_type(intent_result) -> str:
    """Return the lower-cased intent label of an intent-recognition result.

    Reads ``intent_type`` (falling back to ``intent``). Returns ``""`` when the
    result is not a dict or carries neither key.
    """
    if not isinstance(intent_result, dict):
        return ""
    return (intent_result.get("intent_type") or intent_result.get("intent") or "").lower()


def _exam_name_for(business_type: int, exam_name: str) -> str:
    """Only exam-workshop conversations (business_type == 3) keep an exam name."""
    return exam_name if business_type == 3 else ""


def _rollback_quietly(db: Session) -> None:
    """Roll back the session after a failure without raising from the error path."""
    try:
        db.rollback()
    except Exception as rollback_error:
        logger.warning(f"[db] rollback failed: {rollback_error}")


def _find_active_conversation(
    db: Session, conversation_id: int, user_id: int
) -> Optional[AIConversation]:
    """Return the user's non-deleted conversation with this id, or ``None``."""
    return (
        db.query(AIConversation)
        .filter(
            AIConversation.id == conversation_id,
            AIConversation.user_id == user_id,
            AIConversation.is_deleted == 0,
        )
        .first()
    )


def _create_conversation(
    db: Session,
    user_id: int,
    content: str,
    business_type: int,
    exam_name: Optional[str] = None,
) -> AIConversation:
    """Insert, commit and return (refreshed) a new live conversation row.

    ``exam_name`` is only written when given. Otherwise the column is left to
    whatever default the model defines.
    """
    now_ts = int(time.time())
    fields = {
        "user_id": user_id,
        "content": content,
        "business_type": business_type,
        "created_at": now_ts,
        "updated_at": now_ts,
        "is_deleted": 0,
    }
    if exam_name is not None:
        fields["exam_name"] = exam_name
    conversation = AIConversation(**fields)
    db.add(conversation)
    db.commit()
    db.refresh(conversation)
    return conversation


async def _chat_with_prompt(prompt_name: str, message: str, context: str):
    """Render prompt template ``prompt_name`` and send it as a single user turn.

    The placeholder ``_NO_CONTEXT`` is used when ``context`` is empty. Returns
    whatever ``qwen_service.chat`` returns.
    """
    system_content = load_prompt(
        prompt_name,
        userMessage=message,
        contextJSON=context or _NO_CONTEXT,
    )
    return await qwen_service.chat([{"role": "user", "content": system_content}])


def _unwrap_natural_language_answer(qwen_response):
    """Return ``natural_language_answer`` if the reply is a JSON object, else the reply itself.

    Only strings whose stripped text starts with ``{`` are parsed. Invalid JSON,
    non-object JSON and a missing field all return the raw reply.
    """
    if not (isinstance(qwen_response, str) and qwen_response.strip().startswith("{")):
        return qwen_response
    try:
        response_json = json.loads(qwen_response)
    except ValueError:
        return qwen_response
    if isinstance(response_json, dict):
        return response_json.get("natural_language_answer", qwen_response)
    return qwen_response


def _parse_guess_questions(response) -> list:
    """Extract exactly three follow-up questions from the model reply.

    First tries a ``{"questions": [...]}`` object, either embedded in the reply
    or as the whole reply. If that fails, every line longer than 5 characters
    (after stripping list numbering) becomes a question. Non-string or blank
    entries are dropped. When nothing usable remains, the defaults are used.
    The result is truncated or padded to length 3.
    """
    import re

    try:
        json_match = re.search(r'\{[^{}]*"questions"[^{}]*\}', response, re.DOTALL)
        response_json = json.loads(json_match.group() if json_match else response)
        questions = response_json.get("questions", [])
    except Exception:
        questions = []
        for raw_line in response.split("\n"):
            clean = raw_line.strip().lstrip("0123456789.-、 ").strip()
            if clean and len(clean) > 5:
                questions.append(clean)

    if not isinstance(questions, list):
        questions = []
    questions = [q.strip() for q in questions if isinstance(q, str) and q.strip()]
    if not questions:
        questions = list(_DEFAULT_GUESS_QUESTIONS)
    questions = questions[:3]
    questions.extend([_GUESS_PADDING] * (3 - len(questions)))
    return questions


@router.post("/send_deepseek_message")
async def send_deepseek_message(
    request: Request,
    data: SendMessageRequest,
    db: Session = Depends(get_db),
):
    """
    Send a message and return the complete answer (non-streaming).

    Supported business types:
    - 0: AI Q&A (intent recognition + RAG)
    - 1: PPT outline generation
    - 2: AI writing
    - 3: exam workshop

    Only the conversation row is created or refreshed; no AIMessage rows are
    written here. If ``conversation_id`` does not refer to one of the caller's
    live conversations, a new conversation is created and its id returned.
    Per-type generation failures are reported inside ``data.response``.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}

    try:
        message = data.message.strip()
        if not message:
            return {"statusCode": 400, "msg": "消息不能为空"}
        # Validate before touching the DB so a bad type leaves no orphan conversation.
        if data.business_type not in _SUPPORTED_BUSINESS_TYPES:
            return {"statusCode": 400, "msg": f"不支持的业务类型: {data.business_type}"}

        exam_name = _exam_name_for(data.business_type, data.exam_name)
        existing = (
            _find_active_conversation(db, data.conversation_id, user.user_id)
            if data.conversation_id
            else None
        )
        if existing:
            conv_id = existing.id
            db.query(AIConversation).filter(
                AIConversation.id == conv_id,
                AIConversation.user_id == user.user_id,
            ).update(
                {
                    "content": message[:100],
                    "business_type": data.business_type,
                    "exam_name": exam_name,
                    "updated_at": int(time.time()),
                }
            )
            db.commit()
        else:
            conv_id = _create_conversation(
                db, user.user_id, message[:100], data.business_type, exam_name
            ).id

        response_text = ""
        if data.business_type == 0:
            # AI Q&A: only knowledge-base style intents get RAG context.
            try:
                intent_type = _parse_intent_type(
                    await qwen_service.intent_recognition(message)
                )
                rag_context = ""
                if intent_type in _KNOWLEDGE_INTENTS:
                    rag_context = await _rag_search(message, top_k=10)
                qwen_response = await _chat_with_prompt("final_answer", message, rag_context)
                response_text = _unwrap_natural_language_answer(qwen_response)
            except Exception as e:
                logger.error(f"[send_deepseek_message] AI问答异常: {e}")
                response_text = f"处理失败: {str(e)}"
        elif data.business_type in _RAG_PROMPT_TASKS:
            # PPT outline / AI writing: always retrieve context first.
            prompt_name, task_label = _RAG_PROMPT_TASKS[data.business_type]
            try:
                rag_context = await _rag_search(message, top_k=10)
                response_text = await _chat_with_prompt(prompt_name, message, rag_context)
            except Exception as e:
                logger.error(f"[send_deepseek_message] {task_label}异常: {e}")
                response_text = f"处理失败: {str(e)}"
        else:
            # Exam workshop (business_type == 3): fixed system prompt, no RAG.
            try:
                messages = [
                    {"role": "system", "content": _EXAM_SYSTEM_PROMPT},
                    {"role": "user", "content": message},
                ]
                response_text = await qwen_service.chat(messages)
                if data.exam_name:
                    db.query(AIConversation).filter(
                        AIConversation.id == conv_id,
                        AIConversation.user_id == user.user_id,
                    ).update({"exam_name": data.exam_name, "updated_at": int(time.time())})
                    db.commit()
            except Exception as e:
                logger.error(f"[send_deepseek_message] 考试工坊异常: {e}")
                _rollback_quietly(db)
                response_text = f"处理失败: {str(e)}"

        return {
            "statusCode": 200,
            "msg": "success",
            "data": {
                "conversation_id": conv_id,
                "response": response_text,
                "user_id": user.user_id,
                "business_type": data.business_type,
            },
        }
    except Exception as e:
        logger.error(f"[send_deepseek_message] 异常: {e}")
        _rollback_quietly(db)
        return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}


@router.get("/get_history_record")
async def get_history_record(
    request: Request,
    ai_conversation_id: int = 0,
    business_type: Optional[int] = None,
    db: Session = Depends(get_db),
):
    """Frontend-compatible history query.

    With ``ai_conversation_id == 0`` it returns the caller's live
    conversations, optionally filtered by ``business_type``. They are ordered
    newest first and capped at 50 rows, while ``total`` counts all matches.
    Otherwise it returns the caller's live messages of that conversation in
    insertion order. Timestamps are converted to milliseconds.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}

    if ai_conversation_id > 0:
        messages = (
            db.query(AIMessage)
            .filter(
                AIMessage.ai_conversation_id == ai_conversation_id,
                AIMessage.user_id == user.user_id,
                AIMessage.is_deleted == 0,
            )
            .order_by(AIMessage.id.asc())
            .all()
        )
        return {
            "statusCode": 200,
            "msg": "success",
            "total": len(messages),
            "data": [
                {
                    "id": message.id,
                    "ai_conversation_id": message.ai_conversation_id,
                    "user_id": message.user_id,
                    "type": message.type,
                    "content": message.content,
                    "user_feedback": message.user_feedback,
                    "prev_user_id": message.prev_user_id,
                    "search_source": message.search_source or "",
                    "guess_you_want": message.guess_you_want or "",
                    "created_at": _to_frontend_timestamp(message.created_at),
                    "updated_at": _to_frontend_timestamp(message.updated_at),
                }
                for message in messages
            ],
        }

    conversations_query = db.query(AIConversation).filter(
        AIConversation.user_id == user.user_id,
        AIConversation.is_deleted == 0,
    )
    if business_type is not None:
        conversations_query = conversations_query.filter(
            AIConversation.business_type == business_type
        )

    total = conversations_query.count()
    conversations = (
        conversations_query
        .order_by(AIConversation.updated_at.desc(), AIConversation.id.desc())
        .limit(50)
        .all()
    )
    return {
        "statusCode": 200,
        "msg": "success",
        "total": total,
        "data": [
            {
                "id": conv.id,
                "title": _build_conversation_title(conv),
                "content": conv.content or "",
                "business_type": conv.business_type,
                "exam_name": conv.exam_name or "",
                "created_at": _to_frontend_timestamp(conv.created_at),
                "updated_at": _to_frontend_timestamp(conv.updated_at),
            }
            for conv in conversations
        ],
    }


class DeleteConversationRequest(BaseModel):
    """Request body for ``/delete_conversation``; ``ai_message_id`` takes precedence."""

    ai_conversation_id: int = 0
    ai_message_id: int = 0


@router.post("/delete_conversation")
async def delete_conversation(
    request: Request,
    data: DeleteConversationRequest,
    db: Session = Depends(get_db),
):
    """
    Soft-delete either one Q&A pair or a whole conversation.

    - ``ai_message_id``: soft-deletes that AI message and the user message it
      answers (``prev_user_id``), then refreshes the conversation preview,
      which soft-deletes the conversation if nothing live remains.
    - otherwise ``ai_conversation_id``: soft-deletes the conversation and all
      of its messages.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}

    now_ts = int(time.time())

    if data.ai_message_id:
        ai_message = (
            db.query(AIMessage)
            .filter(
                AIMessage.id == data.ai_message_id,
                AIMessage.user_id == user.user_id,
                AIMessage.type == "ai",
                AIMessage.is_deleted == 0,
            )
            .first()
        )
        if not ai_message:
            return {"statusCode": 404, "msg": "消息不存在"}

        db.query(AIMessage).filter(
            AIMessage.id == ai_message.id,
            AIMessage.user_id == user.user_id,
        ).update({"is_deleted": 1, "updated_at": now_ts})

        if ai_message.prev_user_id:
            db.query(AIMessage).filter(
                AIMessage.id == ai_message.prev_user_id,
                AIMessage.user_id == user.user_id,
                AIMessage.ai_conversation_id == ai_message.ai_conversation_id,
            ).update({"is_deleted": 1, "updated_at": now_ts})

        _refresh_conversation_snapshot(db, ai_message.ai_conversation_id, user.user_id)
        db.commit()
        return {"statusCode": 200, "msg": "删除成功"}

    if not data.ai_conversation_id:
        return {"statusCode": 400, "msg": "缺少删除参数"}

    db.query(AIConversation).filter(
        AIConversation.id == data.ai_conversation_id,
        AIConversation.user_id == user.user_id,
    ).update({"is_deleted": 1, "updated_at": now_ts})

    db.query(AIMessage).filter(
        AIMessage.ai_conversation_id == data.ai_conversation_id,
        AIMessage.user_id == user.user_id,
    ).update({"is_deleted": 1, "updated_at": now_ts})

    db.commit()
    return {"statusCode": 200, "msg": "删除成功"}


class DeleteHistoryRequest(BaseModel):
    """Request body for ``/delete_history_record``."""

    ai_conversation_id: int


@router.post("/delete_history_record")
async def delete_history_record(
    request: Request,
    data: DeleteHistoryRequest,
    db: Session = Depends(get_db),
):
    """Soft-delete a history record: the conversation and all of its messages.

    Messages are deleted too (as in ``/delete_conversation``), so they no longer
    show up via ``/get_history_record?ai_conversation_id=...``.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}

    now_ts = int(time.time())
    db.query(AIConversation).filter(
        AIConversation.id == data.ai_conversation_id,
        AIConversation.user_id == user.user_id,
    ).update({"is_deleted": 1, "updated_at": now_ts})
    db.query(AIMessage).filter(
        AIMessage.ai_conversation_id == data.ai_conversation_id,
        AIMessage.user_id == user.user_id,
    ).update({"is_deleted": 1, "updated_at": now_ts})
    db.commit()
    return {"statusCode": 200, "msg": "删除成功"}


# ─────────────────────────────────────────────────────────────────────────
# Streaming endpoint /stream/chat (no DB; intent recognition + RAG)
# ─────────────────────────────────────────────────────────────────────────

class StreamChatRequest(BaseModel):
    """Request body for ``/stream/chat``."""

    message: str
    model: str = ""


@router.post("/stream/chat")
async def stream_chat(request: Request, data: StreamChatRequest):
    """Streaming chat over SSE without writing to the DB.

    Each model chunk is sent as ``{"content": chunk}``. A streaming failure is
    sent as ``{"error": ...}``. The stream ends with ``[DONE]`` unless the
    generator is closed or cancelled first.
    """
    message = data.message.strip()
    if not message:
        return JSONResponse(content={"statusCode": 400, "msg": "消息不能为空"})

    async def event_generator():
        intent_type = ""
        try:
            intent_type = _parse_intent_type(await qwen_service.intent_recognition(message))
        except Exception as ie:
            logger.warning(f"[stream/chat] 意图识别异常: {ie}")

        rag_context = ""
        if intent_type in _KNOWLEDGE_INTENTS:
            rag_context = await _rag_search(message)

        system_content = load_prompt(
            "final_answer",
            userMessage=message,
            contextJSON=rag_context or _NO_CONTEXT,
        )
        messages = [{"role": "user", "content": system_content}]

        try:
            async for chunk in qwen_service.stream_chat(messages):
                yield _sse({"content": chunk})
        except Exception as e:
            logger.error(f"[stream/chat] 流式输出异常: {e}")
            yield _sse({"error": str(e)})
        # Sent after try/except rather than from `finally`: yielding inside
        # `finally` of an async generator breaks when it is closed/cancelled.
        yield _SSE_DONE

    return StreamingResponse(event_generator(), media_type="text/event-stream")


# ─────────────────────────────────────────────────────────────────────────
# Streaming endpoint /stream/chat-with-db (main chat endpoint of the frontend)
# ─────────────────────────────────────────────────────────────────────────

class StreamChatWithDBRequest(BaseModel):
    """Request body for ``/stream/chat-with-db``."""

    message: str
    ai_conversation_id: int = 0
    business_type: int = 0
    exam_name: str = ""
    ai_message_id: int = 0
    online_search_content: str = ""


@router.post("/stream/chat-with-db")
async def stream_chat_with_db(request: Request, data: StreamChatWithDBRequest):
    """
    Streaming chat over SSE with persistence.

    Flow:
    1. Reuse the caller's live conversation or create a new one.
    2. Insert the user message and an empty AI placeholder message.
    3. Send an ``initial`` event with the conversation and AI message ids.
    4. RAG retrieval.
    5. Build a short history context from earlier messages.
    6. Stream the answer; chunks are sent raw with newlines escaped as ``\\n``.
    7. Store the full answer in the placeholder and refresh the conversation.
    """
    user = request.state.user
    if not user:
        return JSONResponse(content={"statusCode": 401, "msg": "未授权"})

    message = data.message.strip()
    if not message:
        return JSONResponse(content={"statusCode": 400, "msg": "消息不能为空"})

    preview = _build_conversation_preview(message, limit=100)
    # Same rule on every write path: only exam-workshop conversations keep exam_name.
    exam_name = _exam_name_for(data.business_type, data.exam_name)

    async def event_generator():
        # The generator opens its own session and always closes it in `finally`.
        db = SessionLocal()
        try:
            # 1. Reuse the conversation only if it is the caller's and still live.
            existing = (
                _find_active_conversation(db, data.ai_conversation_id, user.user_id)
                if data.ai_conversation_id
                else None
            )
            if existing:
                conv_id = existing.id
                db.query(AIConversation).filter(
                    AIConversation.id == conv_id,
                    AIConversation.user_id == user.user_id,
                ).update(
                    {
                        "content": preview,
                        "business_type": data.business_type,
                        "exam_name": exam_name,
                        "updated_at": int(time.time()),
                    }
                )
                db.commit()
            else:
                conv_id = _create_conversation(
                    db, user.user_id, preview, data.business_type, exam_name
                ).id

            # 2. Persist the user message.
            now_ts = int(time.time())
            user_msg = AIMessage(
                ai_conversation_id=conv_id,
                user_id=user.user_id,
                type="user",
                content=message,
                created_at=now_ts,
                updated_at=now_ts,
                is_deleted=0,
            )
            db.add(user_msg)
            db.commit()
            db.refresh(user_msg)

            # 3. Persist an empty AI placeholder, filled in after streaming.
            now_ts = int(time.time())
            ai_msg = AIMessage(
                ai_conversation_id=conv_id,
                user_id=user.user_id,
                type="ai",
                content="",
                prev_user_id=user_msg.id,
                created_at=now_ts,
                updated_at=now_ts,
                is_deleted=0,
            )
            db.add(ai_msg)
            db.commit()
            db.refresh(ai_msg)

            # 4. Tell the client which conversation/message this stream belongs to.
            yield _sse(
                {
                    "type": "initial",
                    "ai_conversation_id": conv_id,
                    "ai_message_id": ai_msg.id,
                }
            )

            # 5. RAG retrieval (best-effort; "" on failure).
            rag_context = await _rag_search(message, top_k=10)

            # 6. Up to 4 most recent earlier messages, which includes the user
            #    message just inserted. Ordered by id (insertion order) because
            #    updated_at is bumped by later edits and would reorder history.
            history_msgs = (
                db.query(AIMessage)
                .filter(
                    AIMessage.ai_conversation_id == conv_id,
                    AIMessage.id < ai_msg.id,
                    AIMessage.is_deleted == 0,
                )
                .order_by(AIMessage.id.desc())
                .limit(4)
                .all()
            )
            history_parts = []
            for msg in reversed(history_msgs):
                if not msg.content:
                    continue  # skip empty AI placeholders left by failed streams
                role = "用户" if msg.type == "user" else "助手"
                history_parts.append(f"{role}: {msg.content}\n\n")
            history_context = "".join(history_parts)

            # 7. Build the prompt from knowledge-base and online-search context.
            context_parts = []
            if rag_context:
                context_parts.append(f"知识库内容:\n{rag_context}")
            if data.online_search_content:
                context_parts.append(f"联网搜索结果:\n{data.online_search_content}")
            system_content = load_prompt(
                "final_answer",
                userMessage=message,
                contextJSON="\n\n".join(context_parts) or _NO_CONTEXT,
                historyContext=history_context,
            )
            messages = [{"role": "user", "content": system_content}]

            # 8. Stream the answer while collecting it. Chunks are sent raw (not
            #    JSON), so newlines are escaped to keep each chunk on one data line.
            response_parts = []
            try:
                async for chunk in qwen_service.stream_chat(messages):
                    response_parts.append(chunk)
                    escaped_chunk = chunk.replace("\n", "\\n")
                    yield f"data: {escaped_chunk}\n\n"
            except Exception as e:
                logger.error(f"[stream/chat-with-db] 流式输出异常: {e}")
                yield _sse({"error": str(e)})
            full_response = "".join(response_parts)

            # 9. Store the answer and refresh the conversation snapshot.
            if full_response:
                now_ts = int(time.time())
                db.query(AIMessage).filter(AIMessage.id == ai_msg.id).update(
                    {"content": full_response, "updated_at": now_ts}
                )
                db.query(AIConversation).filter(
                    AIConversation.id == conv_id,
                    AIConversation.user_id == user.user_id,
                ).update(
                    {
                        "content": preview,
                        "business_type": data.business_type,
                        "exam_name": exam_name,
                        "updated_at": now_ts,
                    }
                )
                db.commit()

            # 10. End-of-stream marker.
            yield _SSE_DONE
        except Exception as e:
            logger.error(f"[stream/chat-with-db] 处理异常: {e}")
            yield _sse({"error": str(e)})
        finally:
            db.close()

    return StreamingResponse(event_generator(), media_type="text/event-stream")


# ─────────────────────────────────────────────────────────────────────────
# "Guess you want" suggestions
# ─────────────────────────────────────────────────────────────────────────

class GuessYouWantRequest(BaseModel):
    """Request body for ``/guess_you_want``."""

    ai_message_id: int


@router.post("/guess_you_want")
async def guess_you_want(
    request: Request,
    data: GuessYouWantRequest,
    db: Session = Depends(get_db),
):
    """Generate 3 follow-up questions for one of the caller's messages.

    The questions are stored as ``{"questions": [...]}`` JSON in
    ``AIMessage.guess_you_want`` and returned. Only the first 500 characters of
    the message are sent to the model.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}

    try:
        ai_msg = (
            db.query(AIMessage)
            .filter(
                AIMessage.id == data.ai_message_id,
                AIMessage.user_id == user.user_id,
                AIMessage.is_deleted == 0,
            )
            .first()
        )
        if not ai_msg:
            return {"statusCode": 404, "msg": "消息不存在"}

        system_content = load_prompt(
            "guess_questions",
            currentContent=(ai_msg.content or "")[:500],
        )
        response = await qwen_service.chat([{"role": "user", "content": system_content}])
        questions = _parse_guess_questions(response)

        guess_json = json.dumps({"questions": questions}, ensure_ascii=False)
        db.query(AIMessage).filter(
            AIMessage.id == ai_msg.id,
            AIMessage.user_id == user.user_id,
        ).update({"guess_you_want": guess_json, "updated_at": int(time.time())})
        db.commit()

        return {
            "statusCode": 200,
            "msg": "success",
            "data": {"ai_message_id": data.ai_message_id, "questions": questions},
        }
    except Exception as e:
        logger.error(f"[guess_you_want] 处理异常: {e}")
        _rollback_quietly(db)
        return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}


# ─────────────────────────────────────────────────────────────────────────
# Online search (Dify workflow integration)
# ─────────────────────────────────────────────────────────────────────────

@router.get("/online_search")
async def online_search(question: str, request: Request, db: Session = Depends(get_db)):
    """
    Online search.

    Flow: Qwen extracts keywords → Dify workflow (blocking mode) → return the
    workflow's ``outputs.text``. The Dify config is checked before any model
    call so a misconfiguration fails fast.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}

    try:
        dify_config = getattr(settings, "dify", None)
        if not dify_config or not getattr(dify_config, "workflow_url", None):
            return {"statusCode": 500, "msg": "Dify 配置未设置"}

        keywords = await qwen_service.extract_keywords(question)

        headers = {
            "Authorization": f"Bearer {dify_config.auth_token}",
            "Content-Type": "application/json",
        }
        payload = {
            "workflow_id": dify_config.workflow_id,
            "inputs": {
                "keywords": keywords,
                "num": 5,  # number of search results
                "max_text_len": 4000,  # maximum text length
            },
            "response_mode": "blocking",
            "user": getattr(user, "account", str(user.user_id)),
        }

        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(dify_config.workflow_url, headers=headers, json=payload)

        if resp.status_code != 200:
            logger.error(f"[online_search] Dify 调用失败: {resp.status_code}, 响应: {resp.text}")
            return {"statusCode": 500, "msg": f"搜索服务异常: {resp.status_code}"}

        result = resp.json()
        # `or {}` also covers keys that are present but null.
        outputs = (result.get("data") or {}).get("outputs") or {}
        search_text = outputs.get("text", "")

        return {
            "statusCode": 200,
            "msg": "success",
            "data": {"keywords": keywords, "result": search_text},
        }
    except Exception as e:
        logger.error(f"[online_search] 处理异常: {e}")
        return {"statusCode": 500, "msg": f"搜索失败: {str(e)}"}


class SaveOnlineSearchResultRequest(BaseModel):
    """Request body for ``/save_online_search_result``."""

    ai_message_id: int
    search_result: str


@router.post("/save_online_search_result")
async def save_online_search_result(
    request: Request,
    data: SaveOnlineSearchResultRequest,
    db: Session = Depends(get_db),
):
    """Store an online-search result in ``AIMessage.search_source`` (caller's messages only)."""
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}

    try:
        db.query(AIMessage).filter(
            AIMessage.id == data.ai_message_id,
            AIMessage.user_id == user.user_id,
        ).update({"search_source": data.search_result, "updated_at": int(time.time())})
        db.commit()
        return {"statusCode": 200, "msg": "保存成功"}
    except Exception as e:
        logger.error(f"[save_online_search_result] 处理异常: {e}")
        _rollback_quietly(db)
        return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}


# ─────────────────────────────────────────────────────────────────────────
# Standalone intent-recognition endpoint
# ─────────────────────────────────────────────────────────────────────────

class IntentRecognitionRequest(BaseModel):
    """Request body for ``/intent_recognition``."""

    message: str
    save_to_db: bool = False
    ai_conversation_id: int = 0


@router.post("/intent_recognition")
async def intent_recognition(
    request: Request,
    data: IntentRecognitionRequest,
    db: Session = Depends(get_db),
):
    """Classify a message's intent.

    For greeting/FAQ intents with ``save_to_db=True``, the question and the
    model's canned response are persisted directly. They go into the caller's
    live conversation ``ai_conversation_id``, or into a new one when that id is
    0 or not the caller's.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}

    try:
        intent_result = await qwen_service.intent_recognition(data.message)
        intent_type = _parse_intent_type(intent_result)
        response_text = ""
        if isinstance(intent_result, dict):
            response_text = intent_result.get("response") or ""

        if data.save_to_db and intent_type in _DIRECT_ANSWER_INTENTS:
            existing = (
                _find_active_conversation(db, data.ai_conversation_id, user.user_id)
                if data.ai_conversation_id
                else None
            )
            if existing:
                conv_id = existing.id
            else:
                conv_id = _create_conversation(
                    db, user.user_id, data.message[:100], 0
                ).id

            now_ts = int(time.time())
            user_msg = AIMessage(
                ai_conversation_id=conv_id,
                user_id=user.user_id,
                type="user",
                content=data.message,
                created_at=now_ts,
                updated_at=now_ts,
                is_deleted=0,
            )
            db.add(user_msg)
            db.commit()

            now_ts = int(time.time())
            ai_msg = AIMessage(
                ai_conversation_id=conv_id,
                user_id=user.user_id,
                type="ai",
                content=response_text,
                prev_user_id=user_msg.id,
                created_at=now_ts,
                updated_at=now_ts,
                is_deleted=0,
            )
            db.add(ai_msg)
            db.commit()
            db.refresh(ai_msg)

            return {
                "statusCode": 200,
                "msg": "success",
                "data": {
                    "intent_type": intent_type,
                    "response": response_text,
                    "ai_conversation_id": conv_id,
                    "ai_message_id": ai_msg.id,
                    "saved_to_db": True,
                },
            }

        return {
            "statusCode": 200,
            "msg": "success",
            "data": {
                "intent_type": intent_type,
                "response": response_text,
                "saved_to_db": False,
            },
        }
    except Exception as e:
        logger.error(f"[intent_recognition] 处理异常: {e}")
        _rollback_quietly(db)
        return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}


# ─────────────────────────────────────────────────────────────────────────
# Recommended questions (fuzzy search over RecommendQuestion)
# ─────────────────────────────────────────────────────────────────────────

@router.get("/get_user_recommend_question")
async def get_user_recommend_question(
    keyword: str = "",
    limit: int = 10,
    db: Session = Depends(get_db),
):
    """Return live recommended questions, newest first, optionally filtered.

    ``keyword`` is matched as a literal substring. LIKE wildcards such as
    ``%`` and ``_`` in it are escaped. ``limit`` is clamped to
    ``[0, _MAX_RECOMMEND_LIMIT]``.
    """
    try:
        page_size = max(0, min(limit, _MAX_RECOMMEND_LIMIT))
        query = db.query(RecommendQuestion).filter(RecommendQuestion.is_deleted == 0)
        if keyword:
            query = query.filter(
                RecommendQuestion.question.contains(keyword, autoescape=True)
            )
        questions = query.order_by(RecommendQuestion.id.desc()).limit(page_size).all()
        return {
            "statusCode": 200,
            "msg": "success",
            "data": [
                {"id": q.id, "question": q.question, "created_at": q.created_at}
                for q in questions
            ],
        }
    except Exception as e:
        logger.error(f"[get_user_recommend_question] 处理异常: {e}")
        return {"statusCode": 500, "msg": f"查询失败: {str(e)}"}


# ─────────────────────────────────────────────────────────────────────────
# Saving edited PPT outlines / documents
# ─────────────────────────────────────────────────────────────────────────

class SavePPTOutlineRequest(BaseModel):
    """Request body for ``/save_ppt_outline``."""

    ai_message_id: int
    content: str


@router.post("/save_ppt_outline")
async def save_ppt_outline(
    request: Request,
    data: SavePPTOutlineRequest,
    db: Session = Depends(get_db),
):
    """Overwrite ``AIMessage.content`` with an edited PPT outline (caller's messages only)."""
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}

    try:
        db.query(AIMessage).filter(
            AIMessage.id == data.ai_message_id,
            AIMessage.user_id == user.user_id,
        ).update({"content": data.content, "updated_at": int(time.time())})
        db.commit()
        return {"statusCode": 200, "msg": "保存成功"}
    except Exception as e:
        logger.error(f"[save_ppt_outline] 处理异常: {e}")
        _rollback_quietly(db)
        return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}


class SaveEditDocumentRequest(BaseModel):
    """Request body for ``/save_edit_document``."""

    ai_message_id: int
    content: str


@router.post("/save_edit_document")
async def save_edit_document(
    request: Request,
    data: SaveEditDocumentRequest,
    db: Session = Depends(get_db),
):
    """Overwrite the content of one of the caller's ``ai``-type messages (AI-writing edits)."""
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}

    try:
        db.query(AIMessage).filter(
            AIMessage.id == data.ai_message_id,
            AIMessage.user_id == user.user_id,
            AIMessage.type == "ai",
        ).update({"content": data.content, "updated_at": int(time.time())})
        db.commit()
        return {"statusCode": 200, "msg": "保存成功"}
    except Exception as e:
        logger.error(f"[save_edit_document] 处理异常: {e}")
        _rollback_quietly(db)
        return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}