|
|
@@ -0,0 +1,901 @@
|
|
|
+from fastapi import APIRouter, Depends, Request
|
|
|
+from fastapi.responses import StreamingResponse, JSONResponse
|
|
|
+from sqlalchemy.orm import Session
|
|
|
+from pydantic import BaseModel
|
|
|
+from typing import Optional
|
|
|
+from database import get_db, SessionLocal
|
|
|
+from models.chat import AIConversation, AIMessage
|
|
|
+from models.total import RecommendQuestion
|
|
|
+from utils.config import settings
|
|
|
+from utils.logger import logger
|
|
|
+from services.qwen_service import qwen_service
|
|
|
+from utils.prompt_loader import load_prompt
|
|
|
+import time
|
|
|
+import json
|
|
|
+import httpx
|
|
|
+
|
|
|
+router = APIRouter()
|
|
|
+
|
|
|
+
|
|
|
+# ─────────────────────────────────────────────────────────────────────────
|
|
|
+# 辅助函数
|
|
|
+# ─────────────────────────────────────────────────────────────────────────
|
|
|
+
|
|
|
+async def _rag_search(message: str, top_k: int = 5) -> str:
|
|
|
+ """调用 search API 做 RAG 检索,返回上下文文本"""
|
|
|
+ try:
|
|
|
+ search_cfg = getattr(settings, 'search', None)
|
|
|
+ if not search_cfg or not hasattr(search_cfg, 'api_url'):
|
|
|
+ return ""
|
|
|
+ search_url = search_cfg.api_url
|
|
|
+ if not search_url:
|
|
|
+ return ""
|
|
|
+ async with httpx.AsyncClient(timeout=10.0) as client:
|
|
|
+ resp = await client.post(
|
|
|
+ search_url,
|
|
|
+ json={"query": message, "n_results": top_k},
|
|
|
+ )
|
|
|
+ if resp.status_code == 200:
|
|
|
+ data = resp.json()
|
|
|
+ docs = data.get("results") or data.get("documents") or []
|
|
|
+ return "\n\n".join(
|
|
|
+ d.get("content") or d.get("text") or str(d)
|
|
|
+ for d in docs[:top_k]
|
|
|
+ if d.get("content") or d.get("text")
|
|
|
+ )
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"[RAG] 检索失败(可忽略): {e}")
|
|
|
+ return ""
|
|
|
+
|
|
|
+
|
|
|
+def _build_history_messages(conv_id: int, limit: int = 10) -> list:
|
|
|
+ """从数据库读取最近对话历史,构建 messages 列表"""
|
|
|
+ db = SessionLocal()
|
|
|
+ try:
|
|
|
+ msgs = (
|
|
|
+ db.query(AIMessage)
|
|
|
+ .filter(AIMessage.ai_conversation_id == conv_id, AIMessage.is_deleted == 0)
|
|
|
+ .order_by(AIMessage.id.desc())
|
|
|
+ .limit(limit)
|
|
|
+ .all()
|
|
|
+ )
|
|
|
+ msgs.reverse()
|
|
|
+ result = []
|
|
|
+ for m in msgs:
|
|
|
+ role = "user" if m.type == "user" else "assistant"
|
|
|
+ if m.content:
|
|
|
+ result.append({"role": role, "content": m.content})
|
|
|
+ return result
|
|
|
+ finally:
|
|
|
+ db.close()
|
|
|
+
|
|
|
+
|
|
|
+# ─────────────────────────────────────────────────────────────────────────
|
|
|
+# 非流式接口
|
|
|
+# ─────────────────────────────────────────────────────────────────────────
|
|
|
+
|
|
|
+class SendMessageRequest(BaseModel):
|
|
|
+ message: str
|
|
|
+ conversation_id: Optional[int] = None
|
|
|
+ business_type: int = 0 # 0=AI问答, 1=PPT大纲, 2=AI写作, 3=考试工坊
|
|
|
+ exam_name: str = ""
|
|
|
+ ai_message_id: int = 0
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/send_deepseek_message")
|
|
|
+async def send_deepseek_message(
|
|
|
+ request: Request,
|
|
|
+ data: SendMessageRequest,
|
|
|
+ db: Session = Depends(get_db),
|
|
|
+):
|
|
|
+ """
|
|
|
+ 发送消息(非流式)
|
|
|
+ 支持多种业务类型:
|
|
|
+ - 0: AI问答(意图识别 + RAG)
|
|
|
+ - 1: PPT大纲生成
|
|
|
+ - 2: AI写作
|
|
|
+ - 3: 考试工坊
|
|
|
+ """
|
|
|
+ user = request.state.user
|
|
|
+ if not user:
|
|
|
+ return {"statusCode": 401, "msg": "未授权"}
|
|
|
+
|
|
|
+ try:
|
|
|
+ message = data.message.strip()
|
|
|
+ if not message:
|
|
|
+ return {"statusCode": 400, "msg": "消息不能为空"}
|
|
|
+
|
|
|
+ # 创建或获取对话
|
|
|
+ if not data.conversation_id:
|
|
|
+ conversation = AIConversation(
|
|
|
+ user_id=user.user_id,
|
|
|
+ content=message[:100],
|
|
|
+ business_type=data.business_type,
|
|
|
+ exam_name=data.exam_name if data.business_type == 3 else "",
|
|
|
+ created_at=int(time.time()),
|
|
|
+ updated_at=int(time.time()),
|
|
|
+ is_deleted=0,
|
|
|
+ )
|
|
|
+ db.add(conversation)
|
|
|
+ db.commit()
|
|
|
+ db.refresh(conversation)
|
|
|
+ conv_id = conversation.id
|
|
|
+ else:
|
|
|
+ conv_id = data.conversation_id
|
|
|
+
|
|
|
+ response_text = ""
|
|
|
+
|
|
|
+ if data.business_type == 0:
|
|
|
+ # AI问答:意图识别 + RAG
|
|
|
+ try:
|
|
|
+ intent_result = await qwen_service.intent_recognition(message)
|
|
|
+ intent_type = ""
|
|
|
+ if isinstance(intent_result, dict):
|
|
|
+ intent_type = (
|
|
|
+ intent_result.get("intent_type") or intent_result.get("intent") or ""
|
|
|
+ ).lower()
|
|
|
+
|
|
|
+ rag_context = ""
|
|
|
+ if intent_type in ("query_knowledge_base", "知识库查询", "技术咨询"):
|
|
|
+ rag_context = await _rag_search(message, top_k=10)
|
|
|
+
|
|
|
+ # 使用prompt加载器加载最终回答prompt
|
|
|
+ system_content = load_prompt(
|
|
|
+ "final_answer",
|
|
|
+ userMessage=message,
|
|
|
+ contextJSON=rag_context if rag_context else "暂无相关知识库内容"
|
|
|
+ )
|
|
|
+
|
|
|
+ messages = [
|
|
|
+ {"role": "user", "content": system_content},
|
|
|
+ ]
|
|
|
+
|
|
|
+ qwen_response = await qwen_service.chat(messages)
|
|
|
+
|
|
|
+ try:
|
|
|
+ if isinstance(qwen_response, str) and qwen_response.strip().startswith("{"):
|
|
|
+ response_json = json.loads(qwen_response)
|
|
|
+ response_text = response_json.get("natural_language_answer", qwen_response)
|
|
|
+ else:
|
|
|
+ response_text = qwen_response
|
|
|
+ except Exception:
|
|
|
+ response_text = qwen_response
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"[send_deepseek_message] AI问答异常: {e}")
|
|
|
+ response_text = f"处理失败: {str(e)}"
|
|
|
+
|
|
|
+ elif data.business_type == 1:
|
|
|
+ # PPT大纲生成
|
|
|
+ try:
|
|
|
+ rag_context = await _rag_search(message, top_k=10)
|
|
|
+
|
|
|
+ # 使用prompt加载器加载PPT大纲生成prompt
|
|
|
+ system_content = load_prompt(
|
|
|
+ "ppt_outline",
|
|
|
+ userMessage=message,
|
|
|
+ contextJSON=rag_context if rag_context else "暂无相关知识库内容"
|
|
|
+ )
|
|
|
+
|
|
|
+ messages = [
|
|
|
+ {"role": "user", "content": system_content},
|
|
|
+ ]
|
|
|
+
|
|
|
+ response_text = await qwen_service.chat(messages)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"[send_deepseek_message] PPT大纲生成异常: {e}")
|
|
|
+ response_text = f"处理失败: {str(e)}"
|
|
|
+
|
|
|
+ elif data.business_type == 2:
|
|
|
+ # AI写作
|
|
|
+ try:
|
|
|
+ rag_context = await _rag_search(message, top_k=10)
|
|
|
+
|
|
|
+ # 使用prompt加载器加载公文写作prompt
|
|
|
+ system_content = load_prompt(
|
|
|
+ "document_writing",
|
|
|
+ userMessage=message,
|
|
|
+ contextJSON=rag_context if rag_context else "暂无相关知识库内容"
|
|
|
+ )
|
|
|
+
|
|
|
+ messages = [
|
|
|
+ {"role": "user", "content": system_content},
|
|
|
+ ]
|
|
|
+
|
|
|
+ response_text = await qwen_service.chat(messages)
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"[send_deepseek_message] AI写作异常: {e}")
|
|
|
+ response_text = f"处理失败: {str(e)}"
|
|
|
+
|
|
|
+ elif data.business_type == 3:
|
|
|
+ # 考试工坊:生成题目
|
|
|
+ try:
|
|
|
+ system_content = (
|
|
|
+ "你是一个专业的考试题目生成助手,专注于路桥隧轨施工安全领域。\n"
|
|
|
+ "请根据用户需求生成专业的考试题目,包括单选题、多选题、判断题等。\n"
|
|
|
+ "每道题目应包含:题目内容、选项(如适用)、正确答案、解析。\n"
|
|
|
+ "输出格式应为结构化的 JSON。"
|
|
|
+ )
|
|
|
+
|
|
|
+ messages = [
|
|
|
+ {"role": "system", "content": system_content},
|
|
|
+ {"role": "user", "content": message},
|
|
|
+ ]
|
|
|
+
|
|
|
+ response_text = await qwen_service.chat(messages)
|
|
|
+
|
|
|
+ if data.exam_name:
|
|
|
+ db.query(AIConversation).filter(AIConversation.id == conv_id).update(
|
|
|
+ {"exam_name": data.exam_name, "updated_at": int(time.time())}
|
|
|
+ )
|
|
|
+ db.commit()
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"[send_deepseek_message] 考试工坊异常: {e}")
|
|
|
+ response_text = f"处理失败: {str(e)}"
|
|
|
+
|
|
|
+ else:
|
|
|
+ return {"statusCode": 400, "msg": f"不支持的业务类型: {data.business_type}"}
|
|
|
+
|
|
|
+ return {
|
|
|
+ "statusCode": 200,
|
|
|
+ "msg": "success",
|
|
|
+ "data": {
|
|
|
+ "conversation_id": conv_id,
|
|
|
+ "response": response_text,
|
|
|
+ "user_id": user.user_id,
|
|
|
+ "business_type": data.business_type,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"[send_deepseek_message] 异常: {e}")
|
|
|
+ return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}
|
|
|
+
|
|
|
+
|
|
|
+@router.get("/get_history_record")
|
|
|
+async def get_history_record(request: Request, db: Session = Depends(get_db)):
|
|
|
+ """获取对话历史记录列表"""
|
|
|
+ user = request.state.user
|
|
|
+ if not user:
|
|
|
+ return {"statusCode": 401, "msg": "未授权"}
|
|
|
+ conversations = (
|
|
|
+ db.query(AIConversation)
|
|
|
+ .filter(
|
|
|
+ AIConversation.user_id == user.user_id,
|
|
|
+ AIConversation.is_deleted == 0,
|
|
|
+ )
|
|
|
+ .order_by(AIConversation.created_at.desc())
|
|
|
+ .limit(50)
|
|
|
+ .all()
|
|
|
+ )
|
|
|
+ return {
|
|
|
+ "statusCode": 200,
|
|
|
+ "msg": "success",
|
|
|
+ "data": [
|
|
|
+ {
|
|
|
+ "id": conv.id,
|
|
|
+ "content": (conv.content or "")[:50]
|
|
|
+ + ("..." if len(conv.content or "") > 50 else ""),
|
|
|
+ "business_type": conv.business_type,
|
|
|
+ "exam_name": conv.exam_name,
|
|
|
+ "created_at": conv.created_at,
|
|
|
+ }
|
|
|
+ for conv in conversations
|
|
|
+ ],
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+class DeleteConversationRequest(BaseModel):
|
|
|
+ ai_conversation_id: int
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/delete_conversation")
|
|
|
+async def delete_conversation(
|
|
|
+ request: Request, data: DeleteConversationRequest, db: Session = Depends(get_db)
|
|
|
+):
|
|
|
+ """
|
|
|
+ 删除对话(软删除)
|
|
|
+ 同时软删除对话记录和所有关联的消息
|
|
|
+ """
|
|
|
+ user = request.state.user
|
|
|
+ if not user:
|
|
|
+ return {"statusCode": 401, "msg": "未授权"}
|
|
|
+
|
|
|
+ db.query(AIConversation).filter(
|
|
|
+ AIConversation.id == data.ai_conversation_id,
|
|
|
+ AIConversation.user_id == user.user_id,
|
|
|
+ ).update({"is_deleted": 1, "updated_at": int(time.time())})
|
|
|
+
|
|
|
+ db.query(AIMessage).filter(
|
|
|
+ AIMessage.ai_conversation_id == data.ai_conversation_id
|
|
|
+ ).update({"is_deleted": 1, "updated_at": int(time.time())})
|
|
|
+
|
|
|
+ db.commit()
|
|
|
+ return {"statusCode": 200, "msg": "删除成功"}
|
|
|
+
|
|
|
+
|
|
|
+class DeleteHistoryRequest(BaseModel):
|
|
|
+ ai_conversation_id: int
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/delete_history_record")
|
|
|
+async def delete_history_record(
|
|
|
+ request: Request, data: DeleteHistoryRequest, db: Session = Depends(get_db)
|
|
|
+):
|
|
|
+ """删除历史记录(软删除)"""
|
|
|
+ user = request.state.user
|
|
|
+ if not user:
|
|
|
+ return {"statusCode": 401, "msg": "未授权"}
|
|
|
+ db.query(AIConversation).filter(
|
|
|
+ AIConversation.id == data.ai_conversation_id,
|
|
|
+ AIConversation.user_id == user.user_id,
|
|
|
+ ).update({"is_deleted": 1, "updated_at": int(time.time())})
|
|
|
+ db.commit()
|
|
|
+ return {"statusCode": 200, "msg": "删除成功"}
|
|
|
+
|
|
|
+
|
|
|
+# ─────────────────────────────────────────────────────────────────────────
|
|
|
+# 流式接口 /stream/chat(无 DB,意图识别 + RAG)
|
|
|
+# ─────────────────────────────────────────────────────────────────────────
|
|
|
+
|
|
|
+class StreamChatRequest(BaseModel):
|
|
|
+ message: str
|
|
|
+ model: str = ""
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/stream/chat")
|
|
|
+async def stream_chat(request: Request, data: StreamChatRequest):
|
|
|
+ """流式聊天(SSE,不写 DB)"""
|
|
|
+ message = data.message.strip()
|
|
|
+ if not message:
|
|
|
+ return JSONResponse(content={"statusCode": 400, "msg": "消息不能为空"})
|
|
|
+
|
|
|
+ async def event_generator():
|
|
|
+ intent_type = ""
|
|
|
+ try:
|
|
|
+ intent_result = await qwen_service.intent_recognition(message)
|
|
|
+ if isinstance(intent_result, dict):
|
|
|
+ intent_type = (
|
|
|
+ intent_result.get("intent_type") or intent_result.get("intent") or ""
|
|
|
+ ).lower()
|
|
|
+ except Exception as ie:
|
|
|
+ logger.warning(f"[stream/chat] 意图识别异常: {ie}")
|
|
|
+
|
|
|
+ rag_context = ""
|
|
|
+ if intent_type in ("query_knowledge_base", "知识库查询", "技术咨询"):
|
|
|
+ rag_context = await _rag_search(message)
|
|
|
+
|
|
|
+ # 使用prompt加载器加载最终回答prompt
|
|
|
+ system_content = load_prompt(
|
|
|
+ "final_answer",
|
|
|
+ userMessage=message,
|
|
|
+ contextJSON=rag_context if rag_context else "暂无相关知识库内容"
|
|
|
+ )
|
|
|
+
|
|
|
+ messages = [
|
|
|
+ {"role": "user", "content": system_content},
|
|
|
+ ]
|
|
|
+
|
|
|
+ try:
|
|
|
+ async for chunk in qwen_service.stream_chat(messages):
|
|
|
+ yield f"data: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"[stream/chat] 流式输出异常: {e}")
|
|
|
+ yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
|
|
+ finally:
|
|
|
+ yield "data: [DONE]\n\n"
|
|
|
+
|
|
|
+ return StreamingResponse(event_generator(), media_type="text/event-stream")
|
|
|
+
|
|
|
+
|
|
|
+# ─────────────────────────────────────────────────────────────────────────
|
|
|
+# 流式接口 /stream/chat-with-db(前端主聊天接口)
|
|
|
+# ─────────────────────────────────────────────────────────────────────────
|
|
|
+
|
|
|
+class StreamChatWithDBRequest(BaseModel):
|
|
|
+ message: str
|
|
|
+ ai_conversation_id: int = 0
|
|
|
+ business_type: int = 0
|
|
|
+ exam_name: str = ""
|
|
|
+ ai_message_id: int = 0
|
|
|
+ online_search_content: str = ""
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/stream/chat-with-db")
|
|
|
+async def stream_chat_with_db(request: Request, data: StreamChatWithDBRequest):
|
|
|
+ """
|
|
|
+ 带 DB 操作的流式聊天(SSE)
|
|
|
+ 流程:
|
|
|
+ 1. 创建/获取对话
|
|
|
+ 2. 插入用户消息和 AI 占位消息
|
|
|
+ 3. 发送 initial 事件
|
|
|
+ 4. RAG 检索
|
|
|
+ 5. 构建历史上下文
|
|
|
+ 6. 流式输出
|
|
|
+ 7. 更新 AI 消息内容
|
|
|
+ """
|
|
|
+ user = request.state.user
|
|
|
+ if not user:
|
|
|
+ return JSONResponse(content={"statusCode": 401, "msg": "未授权"})
|
|
|
+
|
|
|
+ message = data.message.strip()
|
|
|
+ if not message:
|
|
|
+ return JSONResponse(content={"statusCode": 400, "msg": "消息不能为空"})
|
|
|
+
|
|
|
+ async def event_generator():
|
|
|
+ db = SessionLocal()
|
|
|
+ try:
|
|
|
+ # 1. 创建或获取对话
|
|
|
+ if data.ai_conversation_id == 0:
|
|
|
+ conversation = AIConversation(
|
|
|
+ user_id=user.user_id,
|
|
|
+ content=message[:100],
|
|
|
+ business_type=data.business_type,
|
|
|
+ exam_name=data.exam_name,
|
|
|
+ created_at=int(time.time()),
|
|
|
+ updated_at=int(time.time()),
|
|
|
+ is_deleted=0,
|
|
|
+ )
|
|
|
+ db.add(conversation)
|
|
|
+ db.commit()
|
|
|
+ db.refresh(conversation)
|
|
|
+ conv_id = conversation.id
|
|
|
+ else:
|
|
|
+ conv_id = data.ai_conversation_id
|
|
|
+
|
|
|
+ # 2. 插入用户消息
|
|
|
+ user_msg = AIMessage(
|
|
|
+ ai_conversation_id=conv_id,
|
|
|
+ user_id=user.user_id,
|
|
|
+ type="user",
|
|
|
+ content=message,
|
|
|
+ created_at=int(time.time()),
|
|
|
+ updated_at=int(time.time()),
|
|
|
+ is_deleted=0,
|
|
|
+ )
|
|
|
+ db.add(user_msg)
|
|
|
+ db.commit()
|
|
|
+ db.refresh(user_msg)
|
|
|
+
|
|
|
+ # 3. 插入 AI 占位消息
|
|
|
+ ai_msg = AIMessage(
|
|
|
+ ai_conversation_id=conv_id,
|
|
|
+ user_id=user.user_id,
|
|
|
+ type="ai",
|
|
|
+ content="",
|
|
|
+ prev_user_id=user_msg.id,
|
|
|
+ created_at=int(time.time()),
|
|
|
+ updated_at=int(time.time()),
|
|
|
+ is_deleted=0,
|
|
|
+ )
|
|
|
+ db.add(ai_msg)
|
|
|
+ db.commit()
|
|
|
+ db.refresh(ai_msg)
|
|
|
+
|
|
|
+ # 4. 发送 initial 事件
|
|
|
+ yield f"data: {json.dumps({'type': 'initial', 'ai_conversation_id': conv_id, 'ai_message_id': ai_msg.id}, ensure_ascii=False)}\n\n"
|
|
|
+
|
|
|
+ # 5. RAG 检索
|
|
|
+ rag_context = await _rag_search(message, top_k=10)
|
|
|
+
|
|
|
+ # 6. 获取历史上下文(最近 4 条,2 轮对话)
|
|
|
+ history_msgs = (
|
|
|
+ db.query(AIMessage)
|
|
|
+ .filter(
|
|
|
+ AIMessage.ai_conversation_id == conv_id,
|
|
|
+ AIMessage.id < ai_msg.id,
|
|
|
+ AIMessage.is_deleted == 0,
|
|
|
+ )
|
|
|
+ .order_by(AIMessage.updated_at.desc())
|
|
|
+ .limit(4)
|
|
|
+ .all()
|
|
|
+ )
|
|
|
+ history_msgs.reverse()
|
|
|
+
|
|
|
+ history_context = ""
|
|
|
+ for msg in history_msgs:
|
|
|
+ role = "用户" if msg.type == "user" else "助手"
|
|
|
+ history_context += f"{role}: {msg.content}\n\n"
|
|
|
+
|
|
|
+ # 7. 构建完整 prompt
|
|
|
+ # 构建上下文JSON
|
|
|
+ context_parts = []
|
|
|
+ if rag_context:
|
|
|
+ context_parts.append(f"知识库内容:\n{rag_context}")
|
|
|
+ if data.online_search_content:
|
|
|
+ context_parts.append(f"联网搜索结果:\n{data.online_search_content}")
|
|
|
+
|
|
|
+ context_json = "\n\n".join(context_parts) if context_parts else "暂无相关知识库内容"
|
|
|
+
|
|
|
+ # 使用prompt加载器加载最终回答prompt
|
|
|
+ system_content = load_prompt(
|
|
|
+ "final_answer",
|
|
|
+ userMessage=message,
|
|
|
+ contextJSON=context_json,
|
|
|
+ historyContext=history_context if history_context else ""
|
|
|
+ )
|
|
|
+
|
|
|
+ messages = [
|
|
|
+ {"role": "user", "content": system_content},
|
|
|
+ ]
|
|
|
+
|
|
|
+ # 8. 流式输出并收集完整回复
|
|
|
+ full_response = ""
|
|
|
+ try:
|
|
|
+ async for chunk in qwen_service.stream_chat(messages):
|
|
|
+ escaped_chunk = chunk.replace("\n", "\\n")
|
|
|
+ full_response += chunk
|
|
|
+ yield f"data: {escaped_chunk}\n\n"
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"[stream/chat-with-db] 流式输出异常: {e}")
|
|
|
+ yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
|
|
+
|
|
|
+ # 9. 更新 AI 消息内容
|
|
|
+ if full_response:
|
|
|
+ db.query(AIMessage).filter(AIMessage.id == ai_msg.id).update(
|
|
|
+ {"content": full_response, "updated_at": int(time.time())}
|
|
|
+ )
|
|
|
+ db.commit()
|
|
|
+
|
|
|
+ # 10. 结束标记
|
|
|
+ yield "data: [DONE]\n\n"
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"[stream/chat-with-db] 处理异常: {e}")
|
|
|
+ yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
|
|
|
+ finally:
|
|
|
+ db.close()
|
|
|
+
|
|
|
+ return StreamingResponse(event_generator(), media_type="text/event-stream")
|
|
|
+
|
|
|
+
|
|
|
+# ─────────────────────────────────────────────────────────────────────────
|
|
|
+# 猜你想问
|
|
|
+# ─────────────────────────────────────────────────────────────────────────
|
|
|
+
|
|
|
+class GuessYouWantRequest(BaseModel):
|
|
|
+ ai_message_id: int
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/guess_you_want")
|
|
|
+async def guess_you_want(
|
|
|
+ request: Request,
|
|
|
+ data: GuessYouWantRequest,
|
|
|
+ db: Session = Depends(get_db),
|
|
|
+):
|
|
|
+ """生成"猜你想问"的3个关联问题,保存到 AIMessage.guess_you_want"""
|
|
|
+ user = request.state.user
|
|
|
+ if not user:
|
|
|
+ return {"statusCode": 401, "msg": "未授权"}
|
|
|
+
|
|
|
+ try:
|
|
|
+ ai_msg = (
|
|
|
+ db.query(AIMessage)
|
|
|
+ .filter(AIMessage.id == data.ai_message_id, AIMessage.is_deleted == 0)
|
|
|
+ .first()
|
|
|
+ )
|
|
|
+ if not ai_msg:
|
|
|
+ return {"statusCode": 404, "msg": "消息不存在"}
|
|
|
+
|
|
|
+ # 使用prompt加载器加载猜你想问prompt
|
|
|
+ system_content = load_prompt(
|
|
|
+ "guess_questions",
|
|
|
+ currentContent=ai_msg.content[:500]
|
|
|
+ )
|
|
|
+
|
|
|
+ messages = [
|
|
|
+ {"role": "user", "content": system_content},
|
|
|
+ ]
|
|
|
+
|
|
|
+ response = await qwen_service.chat(messages)
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 尝试从响应中提取 JSON
|
|
|
+ import re
|
|
|
+ json_match = re.search(r'\{[^{}]*"questions"[^{}]*\}', response, re.DOTALL)
|
|
|
+ if json_match:
|
|
|
+ response_json = json.loads(json_match.group())
|
|
|
+ else:
|
|
|
+ response_json = json.loads(response)
|
|
|
+ questions = response_json.get("questions", [])
|
|
|
+ except Exception:
|
|
|
+ lines = [l.strip() for l in response.split("\n") if l.strip()]
|
|
|
+ questions = []
|
|
|
+ for line in lines:
|
|
|
+ clean = line.lstrip("0123456789.-、 ").strip()
|
|
|
+ if clean and len(clean) > 5:
|
|
|
+ questions.append(clean)
|
|
|
+ if not questions:
|
|
|
+ questions = ["该话题的具体应用场景?", "有哪些注意事项?", "相关案例分析?"]
|
|
|
+
|
|
|
+ questions = questions[:3]
|
|
|
+ while len(questions) < 3:
|
|
|
+ questions.append("更多相关问题")
|
|
|
+
|
|
|
+ guess_json = json.dumps({"questions": questions}, ensure_ascii=False)
|
|
|
+
|
|
|
+ db.query(AIMessage).filter(AIMessage.id == data.ai_message_id).update(
|
|
|
+ {"guess_you_want": guess_json, "updated_at": int(time.time())}
|
|
|
+ )
|
|
|
+ db.commit()
|
|
|
+
|
|
|
+ return {
|
|
|
+ "statusCode": 200,
|
|
|
+ "msg": "success",
|
|
|
+ "data": {"ai_message_id": data.ai_message_id, "questions": questions},
|
|
|
+ }
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"[guess_you_want] 处理异常: {e}")
|
|
|
+ return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}
|
|
|
+
|
|
|
+
|
|
|
+# ─────────────────────────────────────────────────────────────────────────
|
|
|
+# 在线搜索(Dify 工作流集成)
|
|
|
+# ─────────────────────────────────────────────────────────────────────────
|
|
|
+
|
|
|
+@router.get("/online_search")
|
|
|
+async def online_search(question: str, request: Request, db: Session = Depends(get_db)):
|
|
|
+ """
|
|
|
+ 在线搜索
|
|
|
+ 流程:Qwen 提炼关键词 → Dify 工作流 → 返回摘要
|
|
|
+ """
|
|
|
+ user = request.state.user
|
|
|
+ if not user:
|
|
|
+ return {"statusCode": 401, "msg": "未授权"}
|
|
|
+
|
|
|
+ try:
|
|
|
+ keywords = await qwen_service.extract_keywords(question)
|
|
|
+
|
|
|
+ dify_config = getattr(settings, "dify", None)
|
|
|
+ if not dify_config or not getattr(dify_config, "workflow_url", None):
|
|
|
+ return {"statusCode": 500, "msg": "Dify 配置未设置"}
|
|
|
+
|
|
|
+ headers = {
|
|
|
+ "Authorization": f"Bearer {dify_config.auth_token}",
|
|
|
+ "Content-Type": "application/json",
|
|
|
+ }
|
|
|
+ payload = {
|
|
|
+ "workflow_id": dify_config.workflow_id,
|
|
|
+ "inputs": {
|
|
|
+ "keywords": keywords,
|
|
|
+ "num": 5, # 搜索结果数量
|
|
|
+ "max_text_len": 4000 # 最大文本长度
|
|
|
+ },
|
|
|
+ "response_mode": "blocking",
|
|
|
+ "user": getattr(user, "account", str(user.user_id)),
|
|
|
+ }
|
|
|
+
|
|
|
+ async with httpx.AsyncClient(timeout=30.0) as client:
|
|
|
+ resp = await client.post(dify_config.workflow_url, headers=headers, json=payload)
|
|
|
+ if resp.status_code != 200:
|
|
|
+ logger.error(f"[online_search] Dify 调用失败: {resp.status_code}, 响应: {resp.text}")
|
|
|
+ return {"statusCode": 500, "msg": f"搜索服务异常: {resp.status_code}"}
|
|
|
+ result = resp.json()
|
|
|
+ search_text = result.get("data", {}).get("outputs", {}).get("text", "")
|
|
|
+
|
|
|
+ return {
|
|
|
+ "statusCode": 200,
|
|
|
+ "msg": "success",
|
|
|
+ "data": {"keywords": keywords, "result": search_text},
|
|
|
+ }
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"[online_search] 处理异常: {e}")
|
|
|
+ return {"statusCode": 500, "msg": f"搜索失败: {str(e)}"}
|
|
|
+
|
|
|
+
|
|
|
+class SaveOnlineSearchResultRequest(BaseModel):
|
|
|
+ ai_message_id: int
|
|
|
+ search_result: str
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/save_online_search_result")
|
|
|
+async def save_online_search_result(
|
|
|
+ request: Request,
|
|
|
+ data: SaveOnlineSearchResultRequest,
|
|
|
+ db: Session = Depends(get_db),
|
|
|
+):
|
|
|
+ """保存联网搜索结果到 AIMessage.search_source"""
|
|
|
+ user = request.state.user
|
|
|
+ if not user:
|
|
|
+ return {"statusCode": 401, "msg": "未授权"}
|
|
|
+
|
|
|
+ try:
|
|
|
+ db.query(AIMessage).filter(AIMessage.id == data.ai_message_id).update(
|
|
|
+ {"search_source": data.search_result, "updated_at": int(time.time())}
|
|
|
+ )
|
|
|
+ db.commit()
|
|
|
+ return {"statusCode": 200, "msg": "保存成功"}
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"[save_online_search_result] 处理异常: {e}")
|
|
|
+ return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}
|
|
|
+
|
|
|
+
|
|
|
+# ─────────────────────────────────────────────────────────────────────────
|
|
|
+# 意图识别独立接口
|
|
|
+# ─────────────────────────────────────────────────────────────────────────
|
|
|
+
|
|
|
+class IntentRecognitionRequest(BaseModel):
|
|
|
+ message: str
|
|
|
+ save_to_db: bool = False
|
|
|
+ ai_conversation_id: int = 0
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/intent_recognition")
|
|
|
+async def intent_recognition(
|
|
|
+ request: Request,
|
|
|
+ data: IntentRecognitionRequest,
|
|
|
+ db: Session = Depends(get_db),
|
|
|
+):
|
|
|
+ """独立意图识别接口;若为 greeting/faq 且 save_to_db=True 则直接存 DB"""
|
|
|
+ user = request.state.user
|
|
|
+ if not user:
|
|
|
+ return {"statusCode": 401, "msg": "未授权"}
|
|
|
+
|
|
|
+ try:
|
|
|
+ intent_result = await qwen_service.intent_recognition(data.message)
|
|
|
+ intent_type = ""
|
|
|
+ response_text = ""
|
|
|
+ if isinstance(intent_result, dict):
|
|
|
+ intent_type = (
|
|
|
+ intent_result.get("intent_type") or intent_result.get("intent") or ""
|
|
|
+ ).lower()
|
|
|
+ response_text = intent_result.get("response", "")
|
|
|
+
|
|
|
+ if data.save_to_db and intent_type in ("greeting", "问候", "faq", "常见问题"):
|
|
|
+ if data.ai_conversation_id == 0:
|
|
|
+ conversation = AIConversation(
|
|
|
+ user_id=user.user_id,
|
|
|
+ content=data.message[:100],
|
|
|
+ business_type=0,
|
|
|
+ created_at=int(time.time()),
|
|
|
+ updated_at=int(time.time()),
|
|
|
+ is_deleted=0,
|
|
|
+ )
|
|
|
+ db.add(conversation)
|
|
|
+ db.commit()
|
|
|
+ db.refresh(conversation)
|
|
|
+ conv_id = conversation.id
|
|
|
+ else:
|
|
|
+ conv_id = data.ai_conversation_id
|
|
|
+
|
|
|
+ user_msg = AIMessage(
|
|
|
+ ai_conversation_id=conv_id,
|
|
|
+ user_id=user.user_id,
|
|
|
+ type="user",
|
|
|
+ content=data.message,
|
|
|
+ created_at=int(time.time()),
|
|
|
+ updated_at=int(time.time()),
|
|
|
+ is_deleted=0,
|
|
|
+ )
|
|
|
+ db.add(user_msg)
|
|
|
+ db.commit()
|
|
|
+
|
|
|
+ ai_msg = AIMessage(
|
|
|
+ ai_conversation_id=conv_id,
|
|
|
+ user_id=user.user_id,
|
|
|
+ type="ai",
|
|
|
+ content=response_text,
|
|
|
+ prev_user_id=user_msg.id,
|
|
|
+ created_at=int(time.time()),
|
|
|
+ updated_at=int(time.time()),
|
|
|
+ is_deleted=0,
|
|
|
+ )
|
|
|
+ db.add(ai_msg)
|
|
|
+ db.commit()
|
|
|
+ db.refresh(ai_msg)
|
|
|
+
|
|
|
+ return {
|
|
|
+ "statusCode": 200,
|
|
|
+ "msg": "success",
|
|
|
+ "data": {
|
|
|
+ "intent_type": intent_type,
|
|
|
+ "response": response_text,
|
|
|
+ "ai_conversation_id": conv_id,
|
|
|
+ "ai_message_id": ai_msg.id,
|
|
|
+ "saved_to_db": True,
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ return {
|
|
|
+ "statusCode": 200,
|
|
|
+ "msg": "success",
|
|
|
+ "data": {
|
|
|
+ "intent_type": intent_type,
|
|
|
+ "response": response_text,
|
|
|
+ "saved_to_db": False,
|
|
|
+ },
|
|
|
+ }
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"[intent_recognition] 处理异常: {e}")
|
|
|
+ return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}
|
|
|
+
|
|
|
+
|
|
|
+# ─────────────────────────────────────────────────────────────────────────
|
|
|
+# 获取用户推荐问题(模糊查询 QA / RecommendQuestion 表)
|
|
|
+# ─────────────────────────────────────────────────────────────────────────
|
|
|
+
|
|
|
+@router.get("/get_user_recommend_question")
|
|
|
+async def get_user_recommend_question(
|
|
|
+ keyword: str = "",
|
|
|
+ limit: int = 10,
|
|
|
+ db: Session = Depends(get_db),
|
|
|
+):
|
|
|
+ """获取推荐问题(支持模糊查询)"""
|
|
|
+ try:
|
|
|
+ query = db.query(RecommendQuestion).filter(RecommendQuestion.is_deleted == 0)
|
|
|
+ if keyword:
|
|
|
+ query = query.filter(RecommendQuestion.question.like(f"%{keyword}%"))
|
|
|
+ questions = query.order_by(RecommendQuestion.id.desc()).limit(limit).all()
|
|
|
+
|
|
|
+ return {
|
|
|
+ "statusCode": 200,
|
|
|
+ "msg": "success",
|
|
|
+ "data": [
|
|
|
+ {"id": q.id, "question": q.question, "created_at": q.created_at}
|
|
|
+ for q in questions
|
|
|
+ ],
|
|
|
+ }
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"[get_user_recommend_question] 处理异常: {e}")
|
|
|
+ return {"statusCode": 500, "msg": f"查询失败: {str(e)}"}
|
|
|
+
|
|
|
+
|
|
|
+# ─────────────────────────────────────────────────────────────────────────
|
|
|
+# PPT 大纲 / 文档编辑保存
|
|
|
+# ─────────────────────────────────────────────────────────────────────────
|
|
|
+
|
|
|
+class SavePPTOutlineRequest(BaseModel):
|
|
|
+ ai_message_id: int
|
|
|
+ content: str
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/save_ppt_outline")
|
|
|
+async def save_ppt_outline(
|
|
|
+ request: Request,
|
|
|
+ data: SavePPTOutlineRequest,
|
|
|
+ db: Session = Depends(get_db),
|
|
|
+):
|
|
|
+ """更新 AIMessage.content 保存 PPT 大纲内容"""
|
|
|
+ user = request.state.user
|
|
|
+ if not user:
|
|
|
+ return {"statusCode": 401, "msg": "未授权"}
|
|
|
+
|
|
|
+ try:
|
|
|
+ db.query(AIMessage).filter(AIMessage.id == data.ai_message_id).update(
|
|
|
+ {"content": data.content, "updated_at": int(time.time())}
|
|
|
+ )
|
|
|
+ db.commit()
|
|
|
+ return {"statusCode": 200, "msg": "保存成功"}
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"[save_ppt_outline] 处理异常: {e}")
|
|
|
+ return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}
|
|
|
+
|
|
|
+
|
|
|
+class SaveEditDocumentRequest(BaseModel):
|
|
|
+ ai_message_id: int
|
|
|
+ content: str
|
|
|
+
|
|
|
+
|
|
|
+@router.post("/save_edit_document")
|
|
|
+async def save_edit_document(
|
|
|
+ request: Request,
|
|
|
+ data: SaveEditDocumentRequest,
|
|
|
+ db: Session = Depends(get_db),
|
|
|
+):
|
|
|
+ """更新 ai 类型 AIMessage.content(AI写作编辑保存)"""
|
|
|
+ user = request.state.user
|
|
|
+ if not user:
|
|
|
+ return {"statusCode": 401, "msg": "未授权"}
|
|
|
+
|
|
|
+ try:
|
|
|
+ db.query(AIMessage).filter(
|
|
|
+ AIMessage.id == data.ai_message_id,
|
|
|
+ AIMessage.type == "ai",
|
|
|
+ ).update({"content": data.content, "updated_at": int(time.time())})
|
|
|
+ db.commit()
|
|
|
+ return {"statusCode": 200, "msg": "保存成功"}
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"[save_edit_document] 处理异常: {e}")
|
|
|
+ return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}
|