chat.py 41 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116
  1. from fastapi import APIRouter, Depends, Request
  2. from fastapi.responses import StreamingResponse, JSONResponse
  3. from sqlalchemy.orm import Session
  4. from pydantic import BaseModel
  5. from typing import Optional
  6. from database import get_db, SessionLocal
  7. from models.chat import AIConversation, AIMessage
  8. from models.total import RecommendQuestion
  9. from utils.config import settings
  10. from utils.logger import logger
  11. from services.qwen_service import qwen_service
  12. from utils.prompt_loader import load_prompt
  13. import time
  14. import json
  15. import httpx
# FastAPI router that every chat endpoint in this module is registered on.
router = APIRouter()
  17. def _build_conversation_preview(content: str, limit: int = 50) -> str:
  18. content = (content or "").strip()
  19. if len(content) <= limit:
  20. return content
  21. return content[:limit] + "..."
  22. def _to_frontend_timestamp(timestamp: Optional[int]) -> Optional[int]:
  23. if not timestamp:
  24. return None
  25. return timestamp if timestamp >= 10**12 else timestamp * 1000
  26. def _build_conversation_title(conversation: AIConversation) -> str:
  27. if conversation.business_type == 3 and (conversation.exam_name or "").strip():
  28. return conversation.exam_name.strip()
  29. return _build_conversation_preview(conversation.content or "", limit=30)
  30. def _refresh_conversation_snapshot(db: Session, conversation_id: int, user_id: int) -> None:
  31. latest_message = (
  32. db.query(AIMessage)
  33. .filter(
  34. AIMessage.ai_conversation_id == conversation_id,
  35. AIMessage.user_id == user_id,
  36. AIMessage.is_deleted == 0,
  37. )
  38. .order_by(AIMessage.id.desc())
  39. .first()
  40. )
  41. if not latest_message:
  42. db.query(AIConversation).filter(
  43. AIConversation.id == conversation_id,
  44. AIConversation.user_id == user_id,
  45. ).update({"is_deleted": 1, "updated_at": int(time.time())})
  46. return
  47. latest_user_message = (
  48. db.query(AIMessage)
  49. .filter(
  50. AIMessage.ai_conversation_id == conversation_id,
  51. AIMessage.user_id == user_id,
  52. AIMessage.type == "user",
  53. AIMessage.is_deleted == 0,
  54. )
  55. .order_by(AIMessage.id.desc())
  56. .first()
  57. )
  58. preview_source = (
  59. latest_user_message.content
  60. if latest_user_message and latest_user_message.content
  61. else latest_message.content
  62. )
  63. preview_content = _build_conversation_preview(preview_source or "", limit=100)
  64. db.query(AIConversation).filter(
  65. AIConversation.id == conversation_id,
  66. AIConversation.user_id == user_id,
  67. ).update(
  68. {
  69. "content": preview_content or " ",
  70. "updated_at": int(time.time()),
  71. }
  72. )
  73. # ─────────────────────────────────────────────────────────────────────────
  74. # 辅助函数
  75. # ─────────────────────────────────────────────────────────────────────────
  76. async def _rag_search(message: str, top_k: int = 5) -> str:
  77. """调用 search API 做 RAG 检索,返回上下文文本"""
  78. try:
  79. search_cfg = getattr(settings, 'search', None)
  80. if not search_cfg or not hasattr(search_cfg, 'api_url'):
  81. return ""
  82. search_url = search_cfg.api_url
  83. if not search_url:
  84. return ""
  85. async with httpx.AsyncClient(timeout=10.0) as client:
  86. resp = await client.post(
  87. search_url,
  88. json={"query": message, "n_results": top_k},
  89. )
  90. if resp.status_code == 200:
  91. data = resp.json()
  92. docs = data.get("results") or data.get("documents") or []
  93. return "\n\n".join(
  94. d.get("content") or d.get("text") or str(d)
  95. for d in docs[:top_k]
  96. if d.get("content") or d.get("text")
  97. )
  98. except Exception as e:
  99. logger.warning(f"[RAG] 检索失败(可忽略): {e}")
  100. return ""
  101. def _build_history_messages(conv_id: int, limit: int = 10) -> list:
  102. """从数据库读取最近对话历史,构建 messages 列表"""
  103. db = SessionLocal()
  104. try:
  105. msgs = (
  106. db.query(AIMessage)
  107. .filter(AIMessage.ai_conversation_id == conv_id, AIMessage.is_deleted == 0)
  108. .order_by(AIMessage.id.desc())
  109. .limit(limit)
  110. .all()
  111. )
  112. msgs.reverse()
  113. result = []
  114. for m in msgs:
  115. role = "user" if m.type == "user" else "assistant"
  116. if m.content:
  117. result.append({"role": role, "content": m.content})
  118. return result
  119. finally:
  120. db.close()
  121. # ─────────────────────────────────────────────────────────────────────────
  122. # 非流式接口
  123. # ─────────────────────────────────────────────────────────────────────────
  124. class SendMessageRequest(BaseModel):
  125. message: str
  126. conversation_id: Optional[int] = None
  127. business_type: int = 0 # 0=AI问答, 1=PPT大纲, 2=AI写作, 3=考试工坊
  128. exam_name: str = ""
  129. ai_message_id: int = 0
  130. @router.post("/send_deepseek_message")
  131. async def send_deepseek_message(
  132. request: Request,
  133. data: SendMessageRequest,
  134. db: Session = Depends(get_db),
  135. ):
  136. """
  137. 发送消息(非流式)
  138. 支持多种业务类型:
  139. - 0: AI问答(意图识别 + RAG)
  140. - 1: PPT大纲生成
  141. - 2: AI写作
  142. - 3: 考试工坊
  143. """
  144. user = request.state.user
  145. if not user:
  146. return {"statusCode": 401, "msg": "未授权"}
  147. try:
  148. message = data.message.strip()
  149. if not message:
  150. return {"statusCode": 400, "msg": "消息不能为空"}
  151. # 创建或获取对话
  152. if not data.conversation_id:
  153. conversation = AIConversation(
  154. user_id=user.user_id,
  155. content=message[:100],
  156. business_type=data.business_type,
  157. exam_name=data.exam_name if data.business_type == 3 else "",
  158. created_at=int(time.time()),
  159. updated_at=int(time.time()),
  160. is_deleted=0,
  161. )
  162. db.add(conversation)
  163. db.commit()
  164. db.refresh(conversation)
  165. conv_id = conversation.id
  166. else:
  167. conv_id = data.conversation_id
  168. db.query(AIConversation).filter(
  169. AIConversation.id == conv_id,
  170. AIConversation.user_id == user.user_id,
  171. AIConversation.is_deleted == 0,
  172. ).update({
  173. "content": message[:100],
  174. "business_type": data.business_type,
  175. "exam_name": data.exam_name if data.business_type == 3 else "",
  176. "updated_at": int(time.time()),
  177. })
  178. db.commit()
  179. response_text = ""
  180. if data.business_type == 0:
  181. # AI问答:意图识别 + RAG
  182. try:
  183. intent_result = await qwen_service.intent_recognition(message)
  184. intent_type = ""
  185. if isinstance(intent_result, dict):
  186. intent_type = (
  187. intent_result.get("intent_type") or intent_result.get("intent") or ""
  188. ).lower()
  189. rag_context = ""
  190. if intent_type in ("query_knowledge_base", "知识库查询", "技术咨询"):
  191. rag_context = await _rag_search(message, top_k=10)
  192. # 使用prompt加载器加载最终回答prompt
  193. system_content = load_prompt(
  194. "final_answer",
  195. userMessage=message,
  196. contextJSON=rag_context if rag_context else "暂无相关知识库内容"
  197. )
  198. messages = [
  199. {"role": "user", "content": system_content},
  200. ]
  201. qwen_response = await qwen_service.chat(messages)
  202. try:
  203. if isinstance(qwen_response, str) and qwen_response.strip().startswith("{"):
  204. response_json = json.loads(qwen_response)
  205. response_text = response_json.get("natural_language_answer", qwen_response)
  206. else:
  207. response_text = qwen_response
  208. except Exception:
  209. response_text = qwen_response
  210. except Exception as e:
  211. logger.error(f"[send_deepseek_message] AI问答异常: {e}")
  212. response_text = f"处理失败: {str(e)}"
  213. elif data.business_type == 1:
  214. # PPT大纲生成
  215. try:
  216. rag_context = await _rag_search(message, top_k=10)
  217. # 使用prompt加载器加载PPT大纲生成prompt
  218. system_content = load_prompt(
  219. "ppt_outline",
  220. userMessage=message,
  221. contextJSON=rag_context if rag_context else "暂无相关知识库内容"
  222. )
  223. messages = [
  224. {"role": "user", "content": system_content},
  225. ]
  226. response_text = await qwen_service.chat(messages)
  227. except Exception as e:
  228. logger.error(f"[send_deepseek_message] PPT大纲生成异常: {e}")
  229. response_text = f"处理失败: {str(e)}"
  230. elif data.business_type == 2:
  231. # AI写作
  232. try:
  233. rag_context = await _rag_search(message, top_k=10)
  234. # 使用prompt加载器加载公文写作prompt
  235. system_content = load_prompt(
  236. "document_writing",
  237. userMessage=message,
  238. contextJSON=rag_context if rag_context else "暂无相关知识库内容"
  239. )
  240. messages = [
  241. {"role": "user", "content": system_content},
  242. ]
  243. response_text = await qwen_service.chat(messages)
  244. except Exception as e:
  245. logger.error(f"[send_deepseek_message] AI写作异常: {e}")
  246. response_text = f"处理失败: {str(e)}"
  247. elif data.business_type == 3:
  248. # 考试工坊:生成题目
  249. try:
  250. system_content = (
  251. "你是一个专业的考试题目生成助手,专注于路桥隧轨施工安全领域。\n"
  252. "请根据用户需求生成专业的考试题目,包括单选题、多选题、判断题等。\n"
  253. "每道题目应包含:题目内容、选项(如适用)、正确答案、解析。\n"
  254. "输出格式应为结构化的 JSON。"
  255. )
  256. messages = [
  257. {"role": "system", "content": system_content},
  258. {"role": "user", "content": message},
  259. ]
  260. response_text = await qwen_service.chat(messages)
  261. if data.exam_name:
  262. db.query(AIConversation).filter(AIConversation.id == conv_id).update(
  263. {"exam_name": data.exam_name, "updated_at": int(time.time())}
  264. )
  265. db.commit()
  266. except Exception as e:
  267. logger.error(f"[send_deepseek_message] 考试工坊异常: {e}")
  268. response_text = f"处理失败: {str(e)}"
  269. else:
  270. return {"statusCode": 400, "msg": f"不支持的业务类型: {data.business_type}"}
  271. return {
  272. "statusCode": 200,
  273. "msg": "success",
  274. "data": {
  275. "conversation_id": conv_id,
  276. "response": response_text,
  277. "user_id": user.user_id,
  278. "business_type": data.business_type,
  279. },
  280. }
  281. except Exception as e:
  282. logger.error(f"[send_deepseek_message] 异常: {e}")
  283. return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}
  284. @router.get("/get_history_record")
  285. async def get_history_record(
  286. request: Request,
  287. ai_conversation_id: int = 0,
  288. business_type: Optional[int] = None,
  289. db: Session = Depends(get_db),
  290. ):
  291. """兼容前端的历史记录查询:ai_conversation_id=0 返回对话列表,否则返回消息详情。"""
  292. user = request.state.user
  293. if not user:
  294. return {"statusCode": 401, "msg": "未授权"}
  295. if ai_conversation_id > 0:
  296. messages = (
  297. db.query(AIMessage)
  298. .filter(
  299. AIMessage.ai_conversation_id == ai_conversation_id,
  300. AIMessage.user_id == user.user_id,
  301. AIMessage.is_deleted == 0,
  302. )
  303. .order_by(AIMessage.id.asc())
  304. .all()
  305. )
  306. return {
  307. "statusCode": 200,
  308. "msg": "success",
  309. "total": len(messages),
  310. "data": [
  311. {
  312. "id": message.id,
  313. "ai_conversation_id": message.ai_conversation_id,
  314. "user_id": message.user_id,
  315. "type": message.type,
  316. "content": message.content,
  317. "user_feedback": message.user_feedback,
  318. "prev_user_id": message.prev_user_id,
  319. "search_source": message.search_source or "",
  320. "guess_you_want": message.guess_you_want or "",
  321. "created_at": _to_frontend_timestamp(message.created_at),
  322. "updated_at": _to_frontend_timestamp(message.updated_at),
  323. }
  324. for message in messages
  325. ],
  326. }
  327. conversations_query = db.query(AIConversation).filter(
  328. AIConversation.user_id == user.user_id,
  329. AIConversation.is_deleted == 0,
  330. )
  331. if business_type is not None:
  332. conversations_query = conversations_query.filter(
  333. AIConversation.business_type == business_type
  334. )
  335. total = conversations_query.count()
  336. conversations = (
  337. conversations_query
  338. .order_by(AIConversation.updated_at.desc(), AIConversation.id.desc())
  339. .limit(50)
  340. .all()
  341. )
  342. return {
  343. "statusCode": 200,
  344. "msg": "success",
  345. "total": total,
  346. "data": [
  347. {
  348. "id": conv.id,
  349. "title": _build_conversation_title(conv),
  350. "content": conv.content or "",
  351. "business_type": conv.business_type,
  352. "exam_name": conv.exam_name or "",
  353. "created_at": _to_frontend_timestamp(conv.created_at),
  354. "updated_at": _to_frontend_timestamp(conv.updated_at),
  355. }
  356. for conv in conversations
  357. ],
  358. }
  359. class DeleteConversationRequest(BaseModel):
  360. ai_conversation_id: int = 0
  361. ai_message_id: int = 0
  362. @router.post("/delete_conversation")
  363. async def delete_conversation(
  364. request: Request, data: DeleteConversationRequest, db: Session = Depends(get_db)
  365. ):
  366. """
  367. 删除对话(软删除)
  368. 同时软删除对话记录和所有关联的消息
  369. """
  370. user = request.state.user
  371. if not user:
  372. return {"statusCode": 401, "msg": "未授权"}
  373. now_ts = int(time.time())
  374. if data.ai_message_id:
  375. ai_message = (
  376. db.query(AIMessage)
  377. .filter(
  378. AIMessage.id == data.ai_message_id,
  379. AIMessage.user_id == user.user_id,
  380. AIMessage.type == "ai",
  381. AIMessage.is_deleted == 0,
  382. )
  383. .first()
  384. )
  385. if not ai_message:
  386. return {"statusCode": 404, "msg": "消息不存在"}
  387. db.query(AIMessage).filter(
  388. AIMessage.id == ai_message.id,
  389. AIMessage.user_id == user.user_id,
  390. ).update({"is_deleted": 1, "updated_at": now_ts})
  391. if ai_message.prev_user_id:
  392. db.query(AIMessage).filter(
  393. AIMessage.id == ai_message.prev_user_id,
  394. AIMessage.user_id == user.user_id,
  395. AIMessage.ai_conversation_id == ai_message.ai_conversation_id,
  396. ).update({"is_deleted": 1, "updated_at": now_ts})
  397. _refresh_conversation_snapshot(db, ai_message.ai_conversation_id, user.user_id)
  398. db.commit()
  399. return {"statusCode": 200, "msg": "删除成功"}
  400. if not data.ai_conversation_id:
  401. return {"statusCode": 400, "msg": "缺少删除参数"}
  402. db.query(AIConversation).filter(
  403. AIConversation.id == data.ai_conversation_id,
  404. AIConversation.user_id == user.user_id,
  405. ).update({"is_deleted": 1, "updated_at": now_ts})
  406. db.query(AIMessage).filter(
  407. AIMessage.ai_conversation_id == data.ai_conversation_id,
  408. AIMessage.user_id == user.user_id,
  409. ).update({"is_deleted": 1, "updated_at": now_ts})
  410. db.commit()
  411. return {"statusCode": 200, "msg": "删除成功"}
  412. class DeleteHistoryRequest(BaseModel):
  413. ai_conversation_id: int
  414. @router.post("/delete_history_record")
  415. async def delete_history_record(
  416. request: Request, data: DeleteHistoryRequest, db: Session = Depends(get_db)
  417. ):
  418. """删除历史记录(软删除)"""
  419. user = request.state.user
  420. if not user:
  421. return {"statusCode": 401, "msg": "未授权"}
  422. db.query(AIConversation).filter(
  423. AIConversation.id == data.ai_conversation_id,
  424. AIConversation.user_id == user.user_id,
  425. ).update({"is_deleted": 1, "updated_at": int(time.time())})
  426. db.commit()
  427. return {"statusCode": 200, "msg": "删除成功"}
  428. # ─────────────────────────────────────────────────────────────────────────
  429. # 流式接口 /stream/chat(无 DB,意图识别 + RAG)
  430. # ─────────────────────────────────────────────────────────────────────────
  431. class StreamChatRequest(BaseModel):
  432. message: str
  433. model: str = ""
  434. @router.post("/stream/chat")
  435. async def stream_chat(request: Request, data: StreamChatRequest):
  436. """流式聊天(SSE,不写 DB)"""
  437. message = data.message.strip()
  438. if not message:
  439. return JSONResponse(content={"statusCode": 400, "msg": "消息不能为空"})
  440. async def event_generator():
  441. intent_type = ""
  442. try:
  443. intent_result = await qwen_service.intent_recognition(message)
  444. if isinstance(intent_result, dict):
  445. intent_type = (
  446. intent_result.get("intent_type") or intent_result.get("intent") or ""
  447. ).lower()
  448. except Exception as ie:
  449. logger.warning(f"[stream/chat] 意图识别异常: {ie}")
  450. rag_context = ""
  451. if intent_type in ("query_knowledge_base", "知识库查询", "技术咨询"):
  452. rag_context = await _rag_search(message)
  453. # 使用prompt加载器加载最终回答prompt
  454. system_content = load_prompt(
  455. "final_answer",
  456. userMessage=message,
  457. contextJSON=rag_context if rag_context else "暂无相关知识库内容"
  458. )
  459. messages = [
  460. {"role": "user", "content": system_content},
  461. ]
  462. try:
  463. async for chunk in qwen_service.stream_chat(messages):
  464. yield f"data: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
  465. except Exception as e:
  466. logger.error(f"[stream/chat] 流式输出异常: {e}")
  467. yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
  468. finally:
  469. yield "data: [DONE]\n\n"
  470. return StreamingResponse(event_generator(), media_type="text/event-stream")
  471. # ─────────────────────────────────────────────────────────────────────────
  472. # 流式接口 /stream/chat-with-db(前端主聊天接口)
  473. # ─────────────────────────────────────────────────────────────────────────
  474. class StreamChatWithDBRequest(BaseModel):
  475. message: str
  476. ai_conversation_id: int = 0
  477. business_type: int = 0
  478. exam_name: str = ""
  479. ai_message_id: int = 0
  480. online_search_content: str = ""
  481. @router.post("/stream/chat-with-db")
  482. async def stream_chat_with_db(request: Request, data: StreamChatWithDBRequest):
  483. """
  484. 带 DB 操作的流式聊天(SSE)
  485. 流程:
  486. 1. 创建/获取对话
  487. 2. 插入用户消息和 AI 占位消息
  488. 3. 发送 initial 事件
  489. 4. RAG 检索
  490. 5. 构建历史上下文
  491. 6. 流式输出
  492. 7. 更新 AI 消息内容
  493. """
  494. user = request.state.user
  495. if not user:
  496. return JSONResponse(content={"statusCode": 401, "msg": "未授权"})
  497. message = data.message.strip()
  498. if not message:
  499. return JSONResponse(content={"statusCode": 400, "msg": "消息不能为空"})
  500. async def event_generator():
  501. db = SessionLocal()
  502. try:
  503. # 1. 创建或获取对话
  504. if data.ai_conversation_id == 0:
  505. conversation = AIConversation(
  506. user_id=user.user_id,
  507. content=_build_conversation_preview(message, limit=100),
  508. business_type=data.business_type,
  509. exam_name=data.exam_name,
  510. created_at=int(time.time()),
  511. updated_at=int(time.time()),
  512. is_deleted=0,
  513. )
  514. db.add(conversation)
  515. db.commit()
  516. db.refresh(conversation)
  517. conv_id = conversation.id
  518. else:
  519. existing_conversation = (
  520. db.query(AIConversation)
  521. .filter(
  522. AIConversation.id == data.ai_conversation_id,
  523. AIConversation.user_id == user.user_id,
  524. AIConversation.is_deleted == 0,
  525. )
  526. .first()
  527. )
  528. if existing_conversation:
  529. conv_id = existing_conversation.id
  530. db.query(AIConversation).filter(
  531. AIConversation.id == conv_id,
  532. AIConversation.user_id == user.user_id,
  533. ).update(
  534. {
  535. "content": _build_conversation_preview(message, limit=100),
  536. "business_type": data.business_type,
  537. "exam_name": data.exam_name if data.business_type == 3 else "",
  538. "updated_at": int(time.time()),
  539. }
  540. )
  541. db.commit()
  542. else:
  543. conversation = AIConversation(
  544. user_id=user.user_id,
  545. content=_build_conversation_preview(message, limit=100),
  546. business_type=data.business_type,
  547. exam_name=data.exam_name if data.business_type == 3 else "",
  548. created_at=int(time.time()),
  549. updated_at=int(time.time()),
  550. is_deleted=0,
  551. )
  552. db.add(conversation)
  553. db.commit()
  554. db.refresh(conversation)
  555. conv_id = conversation.id
  556. # 2. 插入用户消息
  557. user_msg = AIMessage(
  558. ai_conversation_id=conv_id,
  559. user_id=user.user_id,
  560. type="user",
  561. content=message,
  562. created_at=int(time.time()),
  563. updated_at=int(time.time()),
  564. is_deleted=0,
  565. )
  566. db.add(user_msg)
  567. db.commit()
  568. db.refresh(user_msg)
  569. # 3. 插入 AI 占位消息
  570. ai_msg = AIMessage(
  571. ai_conversation_id=conv_id,
  572. user_id=user.user_id,
  573. type="ai",
  574. content="",
  575. prev_user_id=user_msg.id,
  576. created_at=int(time.time()),
  577. updated_at=int(time.time()),
  578. is_deleted=0,
  579. )
  580. db.add(ai_msg)
  581. db.commit()
  582. db.refresh(ai_msg)
  583. # 4. 发送 initial 事件
  584. yield f"data: {json.dumps({'type': 'initial', 'ai_conversation_id': conv_id, 'ai_message_id': ai_msg.id}, ensure_ascii=False)}\n\n"
  585. # 5. RAG 检索
  586. rag_context = await _rag_search(message, top_k=10)
  587. # 6. 获取历史上下文(最近 4 条,2 轮对话)
  588. history_msgs = (
  589. db.query(AIMessage)
  590. .filter(
  591. AIMessage.ai_conversation_id == conv_id,
  592. AIMessage.id < ai_msg.id,
  593. AIMessage.is_deleted == 0,
  594. )
  595. .order_by(AIMessage.updated_at.desc())
  596. .limit(4)
  597. .all()
  598. )
  599. history_msgs.reverse()
  600. history_context = ""
  601. for msg in history_msgs:
  602. role = "用户" if msg.type == "user" else "助手"
  603. history_context += f"{role}: {msg.content}\n\n"
  604. # 7. 构建完整 prompt
  605. # 构建上下文JSON
  606. context_parts = []
  607. if rag_context:
  608. context_parts.append(f"知识库内容:\n{rag_context}")
  609. if data.online_search_content:
  610. context_parts.append(f"联网搜索结果:\n{data.online_search_content}")
  611. context_json = "\n\n".join(context_parts) if context_parts else "暂无相关知识库内容"
  612. # 使用prompt加载器加载最终回答prompt
  613. system_content = load_prompt(
  614. "final_answer",
  615. userMessage=message,
  616. contextJSON=context_json,
  617. historyContext=history_context if history_context else ""
  618. )
  619. messages = [
  620. {"role": "user", "content": system_content},
  621. ]
  622. # 8. 流式输出并收集完整回复
  623. full_response = ""
  624. try:
  625. async for chunk in qwen_service.stream_chat(messages):
  626. escaped_chunk = chunk.replace("\n", "\\n")
  627. full_response += chunk
  628. yield f"data: {escaped_chunk}\n\n"
  629. except Exception as e:
  630. logger.error(f"[stream/chat-with-db] 流式输出异常: {e}")
  631. yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
  632. # 9. 更新 AI 消息内容
  633. if full_response:
  634. now_ts = int(time.time())
  635. db.query(AIMessage).filter(AIMessage.id == ai_msg.id).update(
  636. {"content": full_response, "updated_at": now_ts}
  637. )
  638. db.query(AIConversation).filter(
  639. AIConversation.id == conv_id,
  640. AIConversation.user_id == user.user_id,
  641. ).update(
  642. {
  643. "content": _build_conversation_preview(message, limit=100),
  644. "business_type": data.business_type,
  645. "exam_name": data.exam_name if data.business_type == 3 else "",
  646. "updated_at": now_ts,
  647. }
  648. )
  649. db.commit()
  650. # 10. 结束标记
  651. yield "data: [DONE]\n\n"
  652. except Exception as e:
  653. logger.error(f"[stream/chat-with-db] 处理异常: {e}")
  654. yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
  655. finally:
  656. db.close()
  657. return StreamingResponse(event_generator(), media_type="text/event-stream")
  658. # ─────────────────────────────────────────────────────────────────────────
  659. # 猜你想问
  660. # ─────────────────────────────────────────────────────────────────────────
  661. class GuessYouWantRequest(BaseModel):
  662. ai_message_id: int
  663. @router.post("/guess_you_want")
  664. async def guess_you_want(
  665. request: Request,
  666. data: GuessYouWantRequest,
  667. db: Session = Depends(get_db),
  668. ):
  669. """生成"猜你想问"的3个关联问题,保存到 AIMessage.guess_you_want"""
  670. user = request.state.user
  671. if not user:
  672. return {"statusCode": 401, "msg": "未授权"}
  673. try:
  674. ai_msg = (
  675. db.query(AIMessage)
  676. .filter(AIMessage.id == data.ai_message_id, AIMessage.is_deleted == 0)
  677. .first()
  678. )
  679. if not ai_msg:
  680. return {"statusCode": 404, "msg": "消息不存在"}
  681. # 使用prompt加载器加载猜你想问prompt
  682. system_content = load_prompt(
  683. "guess_questions",
  684. currentContent=ai_msg.content[:500]
  685. )
  686. messages = [
  687. {"role": "user", "content": system_content},
  688. ]
  689. response = await qwen_service.chat(messages)
  690. try:
  691. # 尝试从响应中提取 JSON
  692. import re
  693. json_match = re.search(r'\{[^{}]*"questions"[^{}]*\}', response, re.DOTALL)
  694. if json_match:
  695. response_json = json.loads(json_match.group())
  696. else:
  697. response_json = json.loads(response)
  698. questions = response_json.get("questions", [])
  699. except Exception:
  700. lines = [l.strip() for l in response.split("\n") if l.strip()]
  701. questions = []
  702. for line in lines:
  703. clean = line.lstrip("0123456789.-、 ").strip()
  704. if clean and len(clean) > 5:
  705. questions.append(clean)
  706. if not questions:
  707. questions = ["该话题的具体应用场景?", "有哪些注意事项?", "相关案例分析?"]
  708. questions = questions[:3]
  709. while len(questions) < 3:
  710. questions.append("更多相关问题")
  711. guess_json = json.dumps({"questions": questions}, ensure_ascii=False)
  712. db.query(AIMessage).filter(AIMessage.id == data.ai_message_id).update(
  713. {"guess_you_want": guess_json, "updated_at": int(time.time())}
  714. )
  715. db.commit()
  716. return {
  717. "statusCode": 200,
  718. "msg": "success",
  719. "data": {"ai_message_id": data.ai_message_id, "questions": questions},
  720. }
  721. except Exception as e:
  722. logger.error(f"[guess_you_want] 处理异常: {e}")
  723. return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}
  724. # ─────────────────────────────────────────────────────────────────────────
  725. # 在线搜索(Dify 工作流集成)
  726. # ─────────────────────────────────────────────────────────────────────────
  727. @router.get("/online_search")
  728. async def online_search(question: str, request: Request, db: Session = Depends(get_db)):
  729. """
  730. 在线搜索
  731. 流程:Qwen 提炼关键词 → Dify 工作流 → 返回摘要
  732. """
  733. user = request.state.user
  734. if not user:
  735. return {"statusCode": 401, "msg": "未授权"}
  736. try:
  737. keywords = await qwen_service.extract_keywords(question)
  738. dify_config = getattr(settings, "dify", None)
  739. if not dify_config or not getattr(dify_config, "workflow_url", None):
  740. return {"statusCode": 500, "msg": "Dify 配置未设置"}
  741. headers = {
  742. "Authorization": f"Bearer {dify_config.auth_token}",
  743. "Content-Type": "application/json",
  744. }
  745. payload = {
  746. "workflow_id": dify_config.workflow_id,
  747. "inputs": {
  748. "keywords": keywords,
  749. "num": 5, # 搜索结果数量
  750. "max_text_len": 4000 # 最大文本长度
  751. },
  752. "response_mode": "blocking",
  753. "user": getattr(user, "account", str(user.user_id)),
  754. }
  755. async with httpx.AsyncClient(timeout=30.0) as client:
  756. resp = await client.post(dify_config.workflow_url, headers=headers, json=payload)
  757. if resp.status_code != 200:
  758. logger.error(f"[online_search] Dify 调用失败: {resp.status_code}, 响应: {resp.text}")
  759. return {"statusCode": 500, "msg": f"搜索服务异常: {resp.status_code}"}
  760. result = resp.json()
  761. search_text = result.get("data", {}).get("outputs", {}).get("text", "")
  762. return {
  763. "statusCode": 200,
  764. "msg": "success",
  765. "data": {"keywords": keywords, "result": search_text},
  766. }
  767. except Exception as e:
  768. logger.error(f"[online_search] 处理异常: {e}")
  769. return {"statusCode": 500, "msg": f"搜索失败: {str(e)}"}
  770. class SaveOnlineSearchResultRequest(BaseModel):
  771. ai_message_id: int
  772. search_result: str
  773. @router.post("/save_online_search_result")
  774. async def save_online_search_result(
  775. request: Request,
  776. data: SaveOnlineSearchResultRequest,
  777. db: Session = Depends(get_db),
  778. ):
  779. """保存联网搜索结果到 AIMessage.search_source"""
  780. user = request.state.user
  781. if not user:
  782. return {"statusCode": 401, "msg": "未授权"}
  783. try:
  784. db.query(AIMessage).filter(AIMessage.id == data.ai_message_id).update(
  785. {"search_source": data.search_result, "updated_at": int(time.time())}
  786. )
  787. db.commit()
  788. return {"statusCode": 200, "msg": "保存成功"}
  789. except Exception as e:
  790. logger.error(f"[save_online_search_result] 处理异常: {e}")
  791. return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}
  792. # ─────────────────────────────────────────────────────────────────────────
  793. # 意图识别独立接口
  794. # ─────────────────────────────────────────────────────────────────────────
class IntentRecognitionRequest(BaseModel):
    # Request body for /intent_recognition.
    message: str
    # When True and the recognised intent is a greeting/FAQ, the endpoint
    # writes the user message and the AI response to the DB.
    save_to_db: bool = False
    # 0 creates a new conversation when saving; otherwise messages are added
    # to this conversation id.
    ai_conversation_id: int = 0
  def _save_quick_reply(db: Session, user, data: IntentRecognitionRequest, response_text: str):
      """Persist a user/AI message pair for a directly answered intent.

      Creates a new AIConversation when ``data.ai_conversation_id`` is 0.
      Otherwise the referenced conversation must exist, belong to ``user`` and
      not be soft-deleted. All rows are written in ONE transaction, so a
      failure cannot leave a half-saved conversation behind.

      Returns:
          ``(conversation_id, ai_message_id)``, or ``None`` when the referenced
          conversation is missing or not owned by ``user``. In the ``None``
          case nothing has been written.
      """
      now = int(time.time())
      if data.ai_conversation_id == 0:
          conversation = AIConversation(
              user_id=user.user_id,
              content=data.message[:100],
              business_type=0,
              created_at=now,
              updated_at=now,
              is_deleted=0,
          )
          db.add(conversation)
          db.flush()  # assigns conversation.id without committing
          conv_id = conversation.id
      else:
          owned = (
              db.query(AIConversation.id)
              .filter(
                  AIConversation.id == data.ai_conversation_id,
                  AIConversation.user_id == user.user_id,
                  AIConversation.is_deleted == 0,
              )
              .first()
          )
          if owned is None:
              return None
          conv_id = data.ai_conversation_id

      user_msg = AIMessage(
          ai_conversation_id=conv_id,
          user_id=user.user_id,
          type="user",
          content=data.message,
          created_at=now,
          updated_at=now,
          is_deleted=0,
      )
      db.add(user_msg)
      db.flush()  # the AI reply links back via prev_user_id=user_msg.id

      ai_msg = AIMessage(
          ai_conversation_id=conv_id,
          user_id=user.user_id,
          type="ai",
          content=response_text,
          prev_user_id=user_msg.id,
          created_at=now,
          updated_at=now,
          is_deleted=0,
      )
      db.add(ai_msg)
      db.flush()
      ai_msg_id = ai_msg.id
      db.commit()
      return conv_id, ai_msg_id


  @router.post("/intent_recognition")
  async def intent_recognition(
      request: Request,
      data: IntentRecognitionRequest,
      db: Session = Depends(get_db),
  ):
      """Standalone intent-recognition endpoint.

      Classifies ``data.message`` with ``qwen_service.intent_recognition``.
      If the intent is a greeting or FAQ and ``data.save_to_db`` is True, the
      user message and the model's reply are stored directly in the database.

      Returns:
          A ``{"statusCode", "msg", "data"}`` dict:
          - 401 when the request carries no user.
          - 404 when ``ai_conversation_id`` does not refer to a conversation
            owned by the caller.
          - 500 on any failure; DB work is rolled back.
          - 200 otherwise. ``data.saved_to_db`` tells whether rows were written.
      """
      user = getattr(request.state, "user", None)
      if not user:
          return {"statusCode": 401, "msg": "未授权"}
      try:
          intent_result = await qwen_service.intent_recognition(data.message)
          intent_type = ""
          response_text = ""
          # NOTE(review): assumes the service returns a dict carrying
          # "intent_type"/"intent" and "response" keys; other shapes are ignored.
          if isinstance(intent_result, dict):
              # str() so an unexpected non-string value cannot crash .lower().
              intent_type = str(
                  intent_result.get("intent_type") or intent_result.get("intent") or ""
              ).lower()
              # `or ""`: an explicit None must not be stored as message content.
              response_text = intent_result.get("response") or ""
          if data.save_to_db and intent_type in ("greeting", "问候", "faq", "常见问题"):
              saved = _save_quick_reply(db, user, data, response_text)
              if saved is None:
                  return {"statusCode": 404, "msg": "会话不存在"}
              conv_id, ai_msg_id = saved
              return {
                  "statusCode": 200,
                  "msg": "success",
                  "data": {
                      "intent_type": intent_type,
                      "response": response_text,
                      "ai_conversation_id": conv_id,
                      "ai_message_id": ai_msg_id,
                      "saved_to_db": True,
                  },
              }
          return {
              "statusCode": 200,
              "msg": "success",
              "data": {
                  "intent_type": intent_type,
                  "response": response_text,
                  "saved_to_db": False,
              },
          }
      except Exception as e:
          # Discard any partially flushed rows and keep the session usable.
          db.rollback()
          logger.error(f"[intent_recognition] 处理异常: {e}")
          return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}
  881. # ─────────────────────────────────────────────────────────────────────────
  882. # 获取用户推荐问题(模糊查询 QA / RecommendQuestion 表)
  883. # ─────────────────────────────────────────────────────────────────────────
  @router.get("/get_user_recommend_question")
  async def get_user_recommend_question(
      keyword: str = "",
      limit: int = 10,
      db: Session = Depends(get_db),
  ):
      """Return non-deleted recommended questions, newest (highest id) first.

      Args:
          keyword: Optional substring filter on ``RecommendQuestion.question``.
              Surrounding whitespace is ignored, and the LIKE wildcards ``%``
              and ``_`` inside it are matched literally.
          limit: Maximum number of rows, clamped to ``0..100``. A negative
              value would otherwise render an invalid SQL ``LIMIT``.
          db: SQLAlchemy session provided by ``get_db``.

      Returns:
          ``{"statusCode": 200, "msg": "success", "data": [...]}``, where each
          item has ``id``, ``question`` and ``created_at``. On error the
          result is ``{"statusCode": 500, "msg": ...}``.
      """
      max_rows = 100
      limit = max(0, min(limit, max_rows))
      keyword = keyword.strip()
      try:
          query = db.query(RecommendQuestion).filter(RecommendQuestion.is_deleted == 0)
          if keyword:
              # autoescape=True escapes %, _ and the escape char itself, so
              # user input is treated as a plain substring, not a pattern.
              query = query.filter(RecommendQuestion.question.contains(keyword, autoescape=True))
          questions = query.order_by(RecommendQuestion.id.desc()).limit(limit).all()
          return {
              "statusCode": 200,
              "msg": "success",
              "data": [
                  {"id": q.id, "question": q.question, "created_at": q.created_at}
                  for q in questions
              ],
          }
      except Exception as e:
          db.rollback()
          logger.error(f"[get_user_recommend_question] 处理异常: {e}")
          return {"statusCode": 500, "msg": f"查询失败: {str(e)}"}
  907. # ─────────────────────────────────────────────────────────────────────────
  908. # PPT 大纲 / 文档编辑保存
  909. # ─────────────────────────────────────────────────────────────────────────
  class SavePPTOutlineRequest(BaseModel):
      """Request body for ``POST /save_ppt_outline``."""

      # Primary key of the AIMessage whose content is replaced.
      ai_message_id: int
      # New outline text, written to AIMessage.content as-is.
      content: str
  @router.post("/save_ppt_outline")
  async def save_ppt_outline(
      request: Request,
      data: SavePPTOutlineRequest,
      db: Session = Depends(get_db),
  ):
      """Overwrite ``AIMessage.content`` with an edited PPT outline.

      Only a message owned by the calling user is updated.

      Returns:
          A ``{"statusCode", "msg"}`` dict:
          - 200 on success.
          - 401 without a user.
          - 404 when no owned message has that id.
          - 500 on a database error; the transaction is rolled back.
      """
      user = getattr(request.state, "user", None)
      if not user:
          return {"statusCode": 401, "msg": "未授权"}
      try:
          matched = (
              db.query(AIMessage)
              .filter(
                  AIMessage.id == data.ai_message_id,
                  # Ownership check: never modify another user's message.
                  AIMessage.user_id == user.user_id,
              )
              .update({"content": data.content, "updated_at": int(time.time())})
          )
          if not matched:
              db.rollback()
              return {"statusCode": 404, "msg": "消息不存在"}
          db.commit()
          return {"statusCode": 200, "msg": "保存成功"}
      except Exception as e:
          db.rollback()
          logger.error(f"[save_ppt_outline] 处理异常: {e}")
          return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}
  class SaveEditDocumentRequest(BaseModel):
      """Request body for ``POST /save_edit_document``."""

      # Primary key of the "ai"-type AIMessage being edited.
      ai_message_id: int
      # Edited document text, written to AIMessage.content as-is.
      content: str
  @router.post("/save_edit_document")
  async def save_edit_document(
      request: Request,
      data: SaveEditDocumentRequest,
      db: Session = Depends(get_db),
  ):
      """Overwrite the content of an ``ai``-type AIMessage (AI-writing edit save).

      Only an ``ai`` message owned by the calling user is updated.

      Returns:
          A ``{"statusCode", "msg"}`` dict:
          - 200 on success.
          - 401 without a user.
          - 404 when no owned ``ai`` message has that id.
          - 500 on a database error; the transaction is rolled back.
      """
      user = getattr(request.state, "user", None)
      if not user:
          return {"statusCode": 401, "msg": "未授权"}
      try:
          matched = (
              db.query(AIMessage)
              .filter(
                  AIMessage.id == data.ai_message_id,
                  AIMessage.type == "ai",
                  # Ownership check: never modify another user's message.
                  AIMessage.user_id == user.user_id,
              )
              .update({"content": data.content, "updated_at": int(time.time())})
          )
          if not matched:
              db.rollback()
              return {"statusCode": 404, "msg": "消息不存在"}
          db.commit()
          return {"statusCode": 200, "msg": "保存成功"}
      except Exception as e:
          db.rollback()
          logger.error(f"[save_edit_document] 处理异常: {e}")
          return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}