chat.py 46 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227
  1. from fastapi import APIRouter, Depends, Request
  2. from fastapi.responses import StreamingResponse, JSONResponse
  3. from sqlalchemy import or_
  4. from sqlalchemy.orm import Session
  5. from pydantic import BaseModel
  6. from typing import Optional
  7. from database import get_db, SessionLocal
  8. from models.chat import AIConversation, AIMessage
  9. from models.total import RecommendQuestion
  10. from utils.config import settings
  11. from utils.logger import logger
  12. from services.qwen_service import qwen_service
  13. from utils.prompt_loader import load_prompt
  14. import time
  15. import json
  16. import httpx
# Router holding this module's chat/conversation endpoints; presumably mounted
# by the application via include_router (the prefix is not visible here).
router = APIRouter()
  18. def _build_conversation_preview(content: str, limit: int = 50) -> str:
  19. content = (content or "").strip()
  20. if len(content) <= limit:
  21. return content
  22. return content[:limit] + "..."
  23. def _to_frontend_timestamp(timestamp: Optional[int]) -> Optional[int]:
  24. if not timestamp:
  25. return None
  26. return timestamp if timestamp >= 10**12 else timestamp * 1000
  27. def _build_conversation_title(conversation: AIConversation) -> str:
  28. if conversation.business_type == 3 and (conversation.exam_name or "").strip():
  29. return conversation.exam_name.strip()
  30. return _build_conversation_preview(conversation.content or "", limit=30)
  31. def _is_exam_workshop_conversation(conversation: Optional[AIConversation]) -> bool:
  32. if not conversation:
  33. return False
  34. return conversation.business_type == 3 or bool((conversation.exam_name or "").strip())
  35. def _resolve_conversation_metadata(
  36. conversation: Optional[AIConversation],
  37. requested_business_type: int,
  38. requested_exam_name: str,
  39. ) -> tuple[int, str]:
  40. requested_exam_name = (requested_exam_name or "").strip()
  41. if _is_exam_workshop_conversation(conversation):
  42. return 3, requested_exam_name or (conversation.exam_name or "").strip()
  43. if requested_business_type == 3:
  44. return 3, requested_exam_name
  45. return requested_business_type, ""
  46. def _refresh_conversation_snapshot(db: Session, conversation_id: int, user_id: int) -> None:
  47. latest_message = (
  48. db.query(AIMessage)
  49. .filter(
  50. AIMessage.ai_conversation_id == conversation_id,
  51. AIMessage.user_id == user_id,
  52. AIMessage.is_deleted == 0,
  53. )
  54. .order_by(AIMessage.id.desc())
  55. .first()
  56. )
  57. if not latest_message:
  58. db.query(AIConversation).filter(
  59. AIConversation.id == conversation_id,
  60. AIConversation.user_id == user_id,
  61. ).update({"is_deleted": 1, "updated_at": int(time.time())})
  62. return
  63. latest_user_message = (
  64. db.query(AIMessage)
  65. .filter(
  66. AIMessage.ai_conversation_id == conversation_id,
  67. AIMessage.user_id == user_id,
  68. AIMessage.type == "user",
  69. AIMessage.is_deleted == 0,
  70. )
  71. .order_by(AIMessage.id.desc())
  72. .first()
  73. )
  74. preview_source = (
  75. latest_user_message.content
  76. if latest_user_message and latest_user_message.content
  77. else latest_message.content
  78. )
  79. preview_content = _build_conversation_preview(
  80. preview_source or "", limit=100)
  81. db.query(AIConversation).filter(
  82. AIConversation.id == conversation_id,
  83. AIConversation.user_id == user_id,
  84. ).update(
  85. {
  86. "content": preview_content or " ",
  87. "updated_at": int(time.time()),
  88. }
  89. )
  90. # ─────────────────────────────────────────────────────────────────────────
  91. # 辅助函数
  92. # ─────────────────────────────────────────────────────────────────────────
  93. async def _rag_search(message: str, top_k: int = 5) -> str:
  94. """调用 search API 做 RAG 检索,返回上下文文本"""
  95. try:
  96. search_cfg = getattr(settings, 'search', None)
  97. if not search_cfg or not hasattr(search_cfg, 'api_url'):
  98. return ""
  99. search_url = search_cfg.api_url
  100. if not search_url:
  101. return ""
  102. async with httpx.AsyncClient(timeout=10.0) as client:
  103. resp = await client.post(
  104. search_url,
  105. json={"query": message, "n_results": top_k},
  106. )
  107. if resp.status_code == 200:
  108. data = resp.json()
  109. docs = data.get("results") or data.get("documents") or []
  110. return "\n\n".join(
  111. d.get("content") or d.get("text") or str(d)
  112. for d in docs[:top_k]
  113. if d.get("content") or d.get("text")
  114. )
  115. except Exception as e:
  116. logger.warning(f"[RAG] 检索失败(可忽略): {e}")
  117. return ""
  118. def _build_history_messages(conv_id: int, limit: int = 10) -> list:
  119. """从数据库读取最近对话历史,构建 messages 列表"""
  120. db = SessionLocal()
  121. try:
  122. msgs = (
  123. db.query(AIMessage)
  124. .filter(AIMessage.ai_conversation_id == conv_id, AIMessage.is_deleted == 0)
  125. .order_by(AIMessage.id.desc())
  126. .limit(limit)
  127. .all()
  128. )
  129. msgs.reverse()
  130. result = []
  131. for m in msgs:
  132. role = "user" if m.type == "user" else "assistant"
  133. if m.content:
  134. result.append({"role": role, "content": m.content})
  135. return result
  136. finally:
  137. db.close()
  138. # ─────────────────────────────────────────────────────────────────────────
  139. # 非流式接口
  140. # ─────────────────────────────────────────────────────────────────────────
# Request body for POST /send_deepseek_message.
class SendMessageRequest(BaseModel):
    message: str
    # The handler reads `conversation_id or ai_conversation_id`; both names are
    # accepted (presumably for frontend compatibility).
    conversation_id: Optional[int] = None
    ai_conversation_id: Optional[int] = None
    business_type: int = 0  # 0=AI Q&A, 1=PPT outline, 2=AI writing, 3=exam workshop
    exam_name: str = ""
    # Not read by send_deepseek_message.
    ai_message_id: int = 0
  148. @router.post("/send_deepseek_message")
  149. async def send_deepseek_message(
  150. request: Request,
  151. data: SendMessageRequest,
  152. db: Session = Depends(get_db),
  153. ):
  154. """
  155. 发送消息(非流式)
  156. 支持多种业务类型:
  157. - 0: AI问答(意图识别 + RAG)
  158. - 1: PPT大纲生成
  159. - 2: AI写作
  160. - 3: 考试工坊
  161. """
  162. user = request.state.user
  163. if not user:
  164. return {"statusCode": 401, "msg": "未授权"}
  165. try:
  166. message = data.message.strip()
  167. if not message:
  168. return {"statusCode": 400, "msg": "消息不能为空"}
  169. conversation_id = data.conversation_id or data.ai_conversation_id
  170. existing_conversation = None
  171. if conversation_id:
  172. existing_conversation = (
  173. db.query(AIConversation)
  174. .filter(
  175. AIConversation.id == conversation_id,
  176. AIConversation.user_id == user.user_id,
  177. AIConversation.is_deleted == 0,
  178. )
  179. .first()
  180. )
  181. effective_business_type, effective_exam_name = _resolve_conversation_metadata(
  182. existing_conversation,
  183. data.business_type,
  184. data.exam_name,
  185. )
  186. # 创建或获取对话
  187. if not conversation_id:
  188. conversation = AIConversation(
  189. user_id=user.user_id,
  190. content=message[:100],
  191. business_type=effective_business_type,
  192. exam_name=effective_exam_name,
  193. created_at=int(time.time()),
  194. updated_at=int(time.time()),
  195. is_deleted=0,
  196. )
  197. db.add(conversation)
  198. db.commit()
  199. db.refresh(conversation)
  200. conv_id = conversation.id
  201. else:
  202. conv_id = conversation_id
  203. db.query(AIConversation).filter(
  204. AIConversation.id == conv_id,
  205. AIConversation.user_id == user.user_id,
  206. AIConversation.is_deleted == 0,
  207. ).update({
  208. "content": message[:100],
  209. "business_type": effective_business_type,
  210. "exam_name": effective_exam_name,
  211. "updated_at": int(time.time()),
  212. })
  213. db.commit()
  214. response_text = ""
  215. if data.business_type == 0:
  216. # AI问答:意图识别 + RAG
  217. try:
  218. intent_result = await qwen_service.intent_recognition(message)
  219. intent_type = ""
  220. if isinstance(intent_result, dict):
  221. intent_type = (
  222. intent_result.get("intent_type") or intent_result.get(
  223. "intent") or ""
  224. ).lower()
  225. rag_context = ""
  226. if intent_type in ("query_knowledge_base", "知识库查询", "技术咨询"):
  227. rag_context = await _rag_search(message, top_k=10)
  228. # 使用prompt加载器加载最终回答prompt
  229. system_content = load_prompt(
  230. "final_answer",
  231. userMessage=message,
  232. contextJSON=rag_context if rag_context else "暂无相关知识库内容"
  233. )
  234. messages = [
  235. {"role": "user", "content": system_content},
  236. ]
  237. qwen_response = await qwen_service.chat(messages)
  238. try:
  239. if isinstance(qwen_response, str) and qwen_response.strip().startswith("{"):
  240. response_json = json.loads(qwen_response)
  241. response_text = response_json.get(
  242. "natural_language_answer", qwen_response)
  243. else:
  244. response_text = qwen_response
  245. except Exception:
  246. response_text = qwen_response
  247. except Exception as e:
  248. logger.error(f"[send_deepseek_message] AI问答异常: {e}")
  249. response_text = f"处理失败: {str(e)}"
  250. elif data.business_type == 1:
  251. # PPT大纲生成
  252. try:
  253. rag_context = await _rag_search(message, top_k=10)
  254. # 使用prompt加载器加载PPT大纲生成prompt
  255. system_content = load_prompt(
  256. "ppt_outline",
  257. userMessage=message,
  258. contextJSON=rag_context if rag_context else "暂无相关知识库内容"
  259. )
  260. messages = [
  261. {"role": "user", "content": system_content},
  262. ]
  263. response_text = await qwen_service.chat(messages)
  264. except Exception as e:
  265. logger.error(f"[send_deepseek_message] PPT大纲生成异常: {e}")
  266. response_text = f"处理失败: {str(e)}"
  267. elif data.business_type == 2:
  268. # AI写作
  269. try:
  270. rag_context = await _rag_search(message, top_k=10)
  271. # 使用prompt加载器加载公文写作prompt
  272. system_content = load_prompt(
  273. "document_writing",
  274. userMessage=message,
  275. contextJSON=rag_context if rag_context else "暂无相关知识库内容"
  276. )
  277. messages = [
  278. {"role": "user", "content": system_content},
  279. ]
  280. response_text = await qwen_service.chat(messages)
  281. except Exception as e:
  282. logger.error(f"[send_deepseek_message] AI写作异常: {e}")
  283. response_text = f"处理失败: {str(e)}"
  284. elif data.business_type == 3:
  285. # 考试工坊:生成题目
  286. try:
  287. system_content = (
  288. "你是一个专业的考试题目生成助手,专注于路桥隧轨施工安全领域。\n"
  289. "请根据用户需求生成专业的考试题目,包括单选题、多选题、判断题、简答题等。\n"
  290. "用户消息中已经包含考试标题、题型要求和出题依据内容,必须以其中的出题依据内容为核心生成题目,不能脱离依据内容自由发挥。\n"
  291. "题干、选项、答案和解析都要与出题依据内容中的知识点、专业术语、操作流程、规范要求或培训主题直接相关。\n"
  292. "输出必须是可直接 JSON.parse 的纯 JSON,不要包含 markdown 代码块、解释文字或额外前后缀。\n"
  293. "JSON 顶层结构必须包含 singleChoice、judge、multiple、short 四个字段。\n"
  294. "singleChoice.questions 和 multiple.questions 中每道题必须包含 text、options、answer、analysis。\n"
  295. "options 必须是数组,元素格式为 {\"key\":\"A\",\"text\":\"具体选项内容\"},禁止输出“选项A”这类占位文本。\n"
  296. "judge.questions 中每道题必须包含 text、answer、analysis。\n"
  297. "short.questions 中每道题必须包含 text、outline,其中 outline 至少包含 keyFactors。\n"
  298. "所有题目内容、选项内容、答案和解析都要结合用户给出的工程类型、题型数量、分值和课件内容具体生成。"
  299. )
  300. messages = [
  301. {"role": "system", "content": system_content},
  302. {"role": "user", "content": message},
  303. ]
  304. response_text = await qwen_service.chat(messages)
  305. now_ts = int(time.time())
  306. user_message = AIMessage(
  307. ai_conversation_id=conv_id,
  308. user_id=user.user_id,
  309. type="user",
  310. content=message,
  311. created_at=now_ts,
  312. updated_at=now_ts,
  313. is_deleted=0,
  314. )
  315. db.add(user_message)
  316. db.commit()
  317. db.refresh(user_message)
  318. ai_message = AIMessage(
  319. ai_conversation_id=conv_id,
  320. user_id=user.user_id,
  321. type="ai",
  322. content=response_text,
  323. prev_user_id=user_message.id,
  324. created_at=now_ts,
  325. updated_at=now_ts,
  326. is_deleted=0,
  327. )
  328. db.add(ai_message)
  329. db.commit()
  330. _refresh_conversation_snapshot(db, conv_id, user.user_id)
  331. db.commit()
  332. if effective_exam_name:
  333. db.query(AIConversation).filter(AIConversation.id == conv_id).update(
  334. {"business_type": 3,
  335. "exam_name": effective_exam_name,
  336. "updated_at": int(time.time())}
  337. )
  338. db.commit()
  339. except Exception as e:
  340. logger.error(f"[send_deepseek_message] 考试工坊异常: {e}")
  341. response_text = f"处理失败: {str(e)}"
  342. else:
  343. return {"statusCode": 400, "msg": f"不支持的业务类型: {data.business_type}"}
  344. return {
  345. "statusCode": 200,
  346. "msg": "success",
  347. "data": {
  348. "conversation_id": conv_id,
  349. "ai_conversation_id": conv_id,
  350. "response": response_text,
  351. "reply": response_text,
  352. "content": response_text,
  353. "message": response_text,
  354. "user_id": user.user_id,
  355. "business_type": effective_business_type,
  356. },
  357. }
  358. except Exception as e:
  359. logger.error(f"[send_deepseek_message] 异常: {e}")
  360. return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}
  361. @router.get("/get_history_record")
  362. async def get_history_record(
  363. request: Request,
  364. ai_conversation_id: int = 0,
  365. business_type: Optional[int] = None,
  366. db: Session = Depends(get_db),
  367. ):
  368. """兼容前端的历史记录查询:ai_conversation_id=0 返回对话列表,否则返回消息详情。"""
  369. user = request.state.user
  370. if not user:
  371. return {"statusCode": 401, "msg": "未授权"}
  372. if ai_conversation_id > 0:
  373. messages = (
  374. db.query(AIMessage)
  375. .filter(
  376. AIMessage.ai_conversation_id == ai_conversation_id,
  377. AIMessage.user_id == user.user_id,
  378. AIMessage.is_deleted == 0,
  379. )
  380. .order_by(AIMessage.id.asc())
  381. .all()
  382. )
  383. return {
  384. "statusCode": 200,
  385. "msg": "success",
  386. "total": len(messages),
  387. "data": [
  388. {
  389. "id": message.id,
  390. "ai_conversation_id": message.ai_conversation_id,
  391. "user_id": message.user_id,
  392. "type": message.type,
  393. "content": message.content,
  394. "user_feedback": message.user_feedback,
  395. "prev_user_id": message.prev_user_id,
  396. "search_source": message.search_source or "",
  397. "guess_you_want": message.guess_you_want or "",
  398. "created_at": _to_frontend_timestamp(message.created_at),
  399. "updated_at": _to_frontend_timestamp(message.updated_at),
  400. }
  401. for message in messages
  402. ],
  403. }
  404. conversations_query = db.query(AIConversation).filter(
  405. AIConversation.user_id == user.user_id,
  406. AIConversation.is_deleted == 0,
  407. )
  408. if business_type is not None:
  409. if business_type == 3:
  410. conversations_query = conversations_query.filter(
  411. or_(
  412. AIConversation.business_type == 3,
  413. AIConversation.exam_name.isnot(None),
  414. AIConversation.exam_name != "",
  415. )
  416. )
  417. else:
  418. conversations_query = conversations_query.filter(
  419. AIConversation.business_type == business_type
  420. )
  421. total = conversations_query.count()
  422. conversations = (
  423. conversations_query
  424. .order_by(AIConversation.updated_at.desc(), AIConversation.id.desc())
  425. .limit(50)
  426. .all()
  427. )
  428. return {
  429. "statusCode": 200,
  430. "msg": "success",
  431. "total": total,
  432. "data": [
  433. {
  434. "id": conv.id,
  435. "title": _build_conversation_title(conv),
  436. "content": conv.content or "",
  437. "business_type": conv.business_type,
  438. "exam_name": conv.exam_name or "",
  439. "created_at": _to_frontend_timestamp(conv.created_at),
  440. "updated_at": _to_frontend_timestamp(conv.updated_at),
  441. }
  442. for conv in conversations
  443. ],
  444. }
# Request body for POST /delete_conversation. A non-zero ai_message_id deletes
# that AI message plus the user message it answered; otherwise the whole
# conversation identified by ai_conversation_id is deleted.
class DeleteConversationRequest(BaseModel):
    ai_conversation_id: int = 0
    ai_message_id: int = 0
  448. @router.post("/delete_conversation")
  449. async def delete_conversation(
  450. request: Request, data: DeleteConversationRequest, db: Session = Depends(get_db)
  451. ):
  452. """
  453. 删除对话(软删除)
  454. 同时软删除对话记录和所有关联的消息
  455. """
  456. user = request.state.user
  457. if not user:
  458. return {"statusCode": 401, "msg": "未授权"}
  459. now_ts = int(time.time())
  460. if data.ai_message_id:
  461. ai_message = (
  462. db.query(AIMessage)
  463. .filter(
  464. AIMessage.id == data.ai_message_id,
  465. AIMessage.user_id == user.user_id,
  466. AIMessage.type == "ai",
  467. AIMessage.is_deleted == 0,
  468. )
  469. .first()
  470. )
  471. if not ai_message:
  472. return {"statusCode": 404, "msg": "消息不存在"}
  473. db.query(AIMessage).filter(
  474. AIMessage.id == ai_message.id,
  475. AIMessage.user_id == user.user_id,
  476. ).update({"is_deleted": 1, "updated_at": now_ts})
  477. if ai_message.prev_user_id:
  478. db.query(AIMessage).filter(
  479. AIMessage.id == ai_message.prev_user_id,
  480. AIMessage.user_id == user.user_id,
  481. AIMessage.ai_conversation_id == ai_message.ai_conversation_id,
  482. ).update({"is_deleted": 1, "updated_at": now_ts})
  483. _refresh_conversation_snapshot(
  484. db, ai_message.ai_conversation_id, user.user_id)
  485. db.commit()
  486. return {"statusCode": 200, "msg": "删除成功"}
  487. if not data.ai_conversation_id:
  488. return {"statusCode": 400, "msg": "缺少删除参数"}
  489. db.query(AIConversation).filter(
  490. AIConversation.id == data.ai_conversation_id,
  491. AIConversation.user_id == user.user_id,
  492. ).update({"is_deleted": 1, "updated_at": now_ts})
  493. db.query(AIMessage).filter(
  494. AIMessage.ai_conversation_id == data.ai_conversation_id,
  495. AIMessage.user_id == user.user_id,
  496. ).update({"is_deleted": 1, "updated_at": now_ts})
  497. db.commit()
  498. return {"statusCode": 200, "msg": "删除成功"}
# Request body for POST /delete_history_record.
class DeleteHistoryRequest(BaseModel):
    ai_conversation_id: int  # required; id of the conversation to soft-delete
  501. @router.post("/delete_history_record")
  502. async def delete_history_record(
  503. request: Request, data: DeleteHistoryRequest, db: Session = Depends(get_db)
  504. ):
  505. """删除历史记录(软删除)"""
  506. user = request.state.user
  507. if not user:
  508. return {"statusCode": 401, "msg": "未授权"}
  509. db.query(AIConversation).filter(
  510. AIConversation.id == data.ai_conversation_id,
  511. AIConversation.user_id == user.user_id,
  512. ).update({"is_deleted": 1, "updated_at": int(time.time())})
  513. db.commit()
  514. return {"statusCode": 200, "msg": "删除成功"}
  515. # ─────────────────────────────────────────────────────────────────────────
  516. # 流式接口 /stream/chat(无 DB,意图识别 + RAG)
  517. # ─────────────────────────────────────────────────────────────────────────
# Request body for POST /stream/chat.
class StreamChatRequest(BaseModel):
    message: str
    # Accepted but not read by stream_chat.
    model: str = ""
  521. @router.post("/stream/chat")
  522. async def stream_chat(request: Request, data: StreamChatRequest):
  523. """流式聊天(SSE,不写 DB)"""
  524. message = data.message.strip()
  525. if not message:
  526. return JSONResponse(content={"statusCode": 400, "msg": "消息不能为空"})
  527. async def event_generator():
  528. intent_type = ""
  529. try:
  530. intent_result = await qwen_service.intent_recognition(message)
  531. if isinstance(intent_result, dict):
  532. intent_type = (
  533. intent_result.get("intent_type") or intent_result.get(
  534. "intent") or ""
  535. ).lower()
  536. except Exception as ie:
  537. logger.warning(f"[stream/chat] 意图识别异常: {ie}")
  538. rag_context = ""
  539. if intent_type in ("query_knowledge_base", "知识库查询", "技术咨询"):
  540. rag_context = await _rag_search(message)
  541. # 使用prompt加载器加载最终回答prompt
  542. system_content = load_prompt(
  543. "final_answer",
  544. userMessage=message,
  545. contextJSON=rag_context if rag_context else "暂无相关知识库内容"
  546. )
  547. messages = [
  548. {"role": "user", "content": system_content},
  549. ]
  550. try:
  551. async for chunk in qwen_service.stream_chat(messages):
  552. yield f"data: {json.dumps({'content': chunk}, ensure_ascii=False)}\n\n"
  553. except Exception as e:
  554. logger.error(f"[stream/chat] 流式输出异常: {e}")
  555. yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
  556. finally:
  557. yield "data: [DONE]\n\n"
  558. return StreamingResponse(event_generator(), media_type="text/event-stream")
  559. # ─────────────────────────────────────────────────────────────────────────
  560. # 流式接口 /stream/chat-with-db(前端主聊天接口)
  561. # ─────────────────────────────────────────────────────────────────────────
# Request body for POST /stream/chat-with-db.
class StreamChatWithDBRequest(BaseModel):
    message: str
    ai_conversation_id: int = 0  # 0 starts a new conversation
    business_type: int = 0  # same codes as SendMessageRequest.business_type (3 = exam workshop)
    exam_name: str = ""
    ai_message_id: int = 0  # not read by stream_chat_with_db
    online_search_content: str = ""  # optional web-search text added to the prompt context
  569. @router.post("/stream/chat-with-db")
  570. async def stream_chat_with_db(request: Request, data: StreamChatWithDBRequest):
  571. """
  572. 带 DB 操作的流式聊天(SSE)
  573. 流程:
  574. 1. 创建/获取对话
  575. 2. 插入用户消息和 AI 占位消息
  576. 3. 发送 initial 事件
  577. 4. RAG 检索
  578. 5. 构建历史上下文
  579. 6. 流式输出
  580. 7. 更新 AI 消息内容
  581. """
  582. user = request.state.user
  583. if not user:
  584. return JSONResponse(content={"statusCode": 401, "msg": "未授权"})
  585. message = data.message.strip()
  586. if not message:
  587. return JSONResponse(content={"statusCode": 400, "msg": "消息不能为空"})
  588. async def event_generator():
  589. db = SessionLocal()
  590. try:
  591. # 1. 创建或获取对话
  592. if data.ai_conversation_id == 0:
  593. conversation = AIConversation(
  594. user_id=user.user_id,
  595. content=_build_conversation_preview(message, limit=100),
  596. business_type=data.business_type,
  597. exam_name=data.exam_name,
  598. created_at=int(time.time()),
  599. updated_at=int(time.time()),
  600. is_deleted=0,
  601. )
  602. db.add(conversation)
  603. db.commit()
  604. db.refresh(conversation)
  605. conv_id = conversation.id
  606. else:
  607. existing_conversation = (
  608. db.query(AIConversation)
  609. .filter(
  610. AIConversation.id == data.ai_conversation_id,
  611. AIConversation.user_id == user.user_id,
  612. AIConversation.is_deleted == 0,
  613. )
  614. .first()
  615. )
  616. if existing_conversation:
  617. conv_id = existing_conversation.id
  618. db.query(AIConversation).filter(
  619. AIConversation.id == conv_id,
  620. AIConversation.user_id == user.user_id,
  621. ).update(
  622. {
  623. "content": _build_conversation_preview(message, limit=100),
  624. "business_type": data.business_type,
  625. "exam_name": data.exam_name if data.business_type == 3 else "",
  626. "updated_at": int(time.time()),
  627. }
  628. )
  629. db.commit()
  630. else:
  631. conversation = AIConversation(
  632. user_id=user.user_id,
  633. content=_build_conversation_preview(
  634. message, limit=100),
  635. business_type=data.business_type,
  636. exam_name=data.exam_name if data.business_type == 3 else "",
  637. created_at=int(time.time()),
  638. updated_at=int(time.time()),
  639. is_deleted=0,
  640. )
  641. db.add(conversation)
  642. db.commit()
  643. db.refresh(conversation)
  644. conv_id = conversation.id
  645. # 2. 插入用户消息
  646. user_msg = AIMessage(
  647. ai_conversation_id=conv_id,
  648. user_id=user.user_id,
  649. type="user",
  650. content=message,
  651. created_at=int(time.time()),
  652. updated_at=int(time.time()),
  653. is_deleted=0,
  654. )
  655. db.add(user_msg)
  656. db.commit()
  657. db.refresh(user_msg)
  658. # 3. 插入 AI 占位消息
  659. ai_msg = AIMessage(
  660. ai_conversation_id=conv_id,
  661. user_id=user.user_id,
  662. type="ai",
  663. content="",
  664. prev_user_id=user_msg.id,
  665. created_at=int(time.time()),
  666. updated_at=int(time.time()),
  667. is_deleted=0,
  668. )
  669. db.add(ai_msg)
  670. db.commit()
  671. db.refresh(ai_msg)
  672. # 4. 发送 initial 事件
  673. yield f"data: {json.dumps({'type': 'initial', 'ai_conversation_id': conv_id, 'ai_message_id': ai_msg.id}, ensure_ascii=False)}\n\n"
  674. # 5. RAG 检索
  675. rag_context = await _rag_search(message, top_k=10)
  676. # 6. 获取历史上下文(最近 4 条,2 轮对话)
  677. history_msgs = (
  678. db.query(AIMessage)
  679. .filter(
  680. AIMessage.ai_conversation_id == conv_id,
  681. AIMessage.id < ai_msg.id,
  682. AIMessage.is_deleted == 0,
  683. )
  684. .order_by(AIMessage.updated_at.desc())
  685. .limit(4)
  686. .all()
  687. )
  688. history_msgs.reverse()
  689. history_context = ""
  690. for msg in history_msgs:
  691. role = "用户" if msg.type == "user" else "助手"
  692. history_context += f"{role}: {msg.content}\n\n"
  693. # 7. 构建完整 prompt
  694. # 构建上下文JSON
  695. context_parts = []
  696. if rag_context:
  697. context_parts.append(f"知识库内容:\n{rag_context}")
  698. if data.online_search_content:
  699. context_parts.append(f"联网搜索结果:\n{data.online_search_content}")
  700. context_json = "\n\n".join(
  701. context_parts) if context_parts else "暂无相关知识库内容"
  702. # 使用prompt加载器加载最终回答prompt
  703. system_content = load_prompt(
  704. "final_answer",
  705. userMessage=message,
  706. contextJSON=context_json,
  707. historyContext=history_context if history_context else ""
  708. )
  709. messages = [
  710. {"role": "user", "content": system_content},
  711. ]
  712. # 8. 流式输出并收集完整回复
  713. full_response = ""
  714. try:
  715. async for chunk in qwen_service.stream_chat(messages):
  716. escaped_chunk = chunk.replace("\n", "\\n")
  717. full_response += chunk
  718. yield f"data: {escaped_chunk}\n\n"
  719. except Exception as e:
  720. logger.error(f"[stream/chat-with-db] 流式输出异常: {e}")
  721. yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
  722. # 9. 更新 AI 消息内容
  723. if full_response:
  724. now_ts = int(time.time())
  725. db.query(AIMessage).filter(AIMessage.id == ai_msg.id).update(
  726. {"content": full_response, "updated_at": now_ts}
  727. )
  728. db.query(AIConversation).filter(
  729. AIConversation.id == conv_id,
  730. AIConversation.user_id == user.user_id,
  731. ).update(
  732. {
  733. "content": _build_conversation_preview(message, limit=100),
  734. "business_type": data.business_type,
  735. "exam_name": data.exam_name if data.business_type == 3 else "",
  736. "updated_at": now_ts,
  737. }
  738. )
  739. db.commit()
  740. # 10. 结束标记
  741. yield "data: [DONE]\n\n"
  742. except Exception as e:
  743. logger.error(f"[stream/chat-with-db] 处理异常: {e}")
  744. yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
  745. finally:
  746. db.close()
  747. return StreamingResponse(event_generator(), media_type="text/event-stream")
  748. # ─────────────────────────────────────────────────────────────────────────
  749. # 猜你想问
  750. # ─────────────────────────────────────────────────────────────────────────
# Request body for POST /guess_you_want.
class GuessYouWantRequest(BaseModel):
    ai_message_id: int  # id of the AI message to generate follow-up questions for
  753. @router.post("/guess_you_want")
  754. async def guess_you_want(
  755. request: Request,
  756. data: GuessYouWantRequest,
  757. db: Session = Depends(get_db),
  758. ):
  759. """生成"猜你想问"的3个关联问题,保存到 AIMessage.guess_you_want"""
  760. user = request.state.user
  761. if not user:
  762. return {"statusCode": 401, "msg": "未授权"}
  763. try:
  764. ai_msg = (
  765. db.query(AIMessage)
  766. .filter(AIMessage.id == data.ai_message_id, AIMessage.is_deleted == 0)
  767. .first()
  768. )
  769. if not ai_msg:
  770. return {"statusCode": 404, "msg": "消息不存在"}
  771. # 使用prompt加载器加载猜你想问prompt
  772. system_content = load_prompt(
  773. "guess_questions",
  774. currentContent=ai_msg.content[:500]
  775. )
  776. messages = [
  777. {"role": "user", "content": system_content},
  778. ]
  779. response = await qwen_service.chat(messages)
  780. try:
  781. # 尝试从响应中提取 JSON
  782. import re
  783. json_match = re.search(
  784. r'\{[^{}]*"questions"[^{}]*\}', response, re.DOTALL)
  785. if json_match:
  786. response_json = json.loads(json_match.group())
  787. else:
  788. response_json = json.loads(response)
  789. questions = response_json.get("questions", [])
  790. except Exception:
  791. lines = [l.strip() for l in response.split("\n") if l.strip()]
  792. questions = []
  793. for line in lines:
  794. clean = line.lstrip("0123456789.-、 ").strip()
  795. if clean and len(clean) > 5:
  796. questions.append(clean)
  797. if not questions:
  798. questions = ["该话题的具体应用场景?", "有哪些注意事项?", "相关案例分析?"]
  799. questions = questions[:3]
  800. while len(questions) < 3:
  801. questions.append("更多相关问题")
  802. guess_json = json.dumps({"questions": questions}, ensure_ascii=False)
  803. db.query(AIMessage).filter(AIMessage.id == data.ai_message_id).update(
  804. {"guess_you_want": guess_json, "updated_at": int(time.time())}
  805. )
  806. db.commit()
  807. return {
  808. "statusCode": 200,
  809. "msg": "success",
  810. "data": {"ai_message_id": data.ai_message_id, "questions": questions},
  811. }
  812. except Exception as e:
  813. logger.error(f"[guess_you_want] 处理异常: {e}")
  814. return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}
  815. # ─────────────────────────────────────────────────────────────────────────
  816. # 在线搜索(Dify 工作流集成)
  817. # ─────────────────────────────────────────────────────────────────────────
  818. @router.get("/online_search")
  819. async def online_search(question: str, request: Request, db: Session = Depends(get_db)):
  820. """
  821. 在线搜索
  822. 流程:Qwen 提炼关键词 → Dify 工作流 → 返回摘要
  823. """
  824. user = request.state.user
  825. if not user:
  826. return {"statusCode": 401, "msg": "未授权"}
  827. try:
  828. keywords = await qwen_service.extract_keywords(question)
  829. dify_config = getattr(settings, "dify", None)
  830. if not dify_config or not getattr(dify_config, "workflow_url", None):
  831. return {"statusCode": 500, "msg": "Dify 配置未设置"}
  832. headers = {
  833. "Authorization": f"Bearer {dify_config.auth_token}",
  834. "Content-Type": "application/json",
  835. }
  836. payload = {
  837. "workflow_id": dify_config.workflow_id,
  838. "inputs": {
  839. "keywords": keywords,
  840. "num": 5, # 搜索结果数量
  841. "max_text_len": 4000 # 最大文本长度
  842. },
  843. "response_mode": "blocking",
  844. "user": getattr(user, "account", str(user.user_id)),
  845. }
  846. async with httpx.AsyncClient(timeout=30.0) as client:
  847. resp = await client.post(dify_config.workflow_url, headers=headers, json=payload)
  848. if resp.status_code != 200:
  849. logger.error(
  850. f"[online_search] Dify 调用失败: {resp.status_code}, 响应: {resp.text}")
  851. return {"statusCode": 500, "msg": f"搜索服务异常: {resp.status_code}"}
  852. result = resp.json()
  853. search_text = result.get("data", {}).get(
  854. "outputs", {}).get("text", "")
  855. return {
  856. "statusCode": 200,
  857. "msg": "success",
  858. "data": {"keywords": keywords, "result": search_text},
  859. }
  860. except Exception as e:
  861. logger.error(f"[online_search] 处理异常: {e}")
  862. return {"statusCode": 500, "msg": f"搜索失败: {str(e)}"}
  863. class SaveOnlineSearchResultRequest(BaseModel):
  864. ai_message_id: int
  865. search_result: str
  866. @router.post("/save_online_search_result")
  867. async def save_online_search_result(
  868. request: Request,
  869. data: SaveOnlineSearchResultRequest,
  870. db: Session = Depends(get_db),
  871. ):
  872. """保存联网搜索结果到 AIMessage.search_source"""
  873. user = request.state.user
  874. if not user:
  875. return {"statusCode": 401, "msg": "未授权"}
  876. try:
  877. db.query(AIMessage).filter(AIMessage.id == data.ai_message_id).update(
  878. {"search_source": data.search_result,
  879. "updated_at": int(time.time())}
  880. )
  881. db.commit()
  882. return {"statusCode": 200, "msg": "保存成功"}
  883. except Exception as e:
  884. logger.error(f"[save_online_search_result] 处理异常: {e}")
  885. return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}
  886. # ─────────────────────────────────────────────────────────────────────────
  887. # 意图识别独立接口
  888. # ─────────────────────────────────────────────────────────────────────────
  889. class IntentRecognitionRequest(BaseModel):
  890. message: str
  891. save_to_db: bool = False
  892. ai_conversation_id: int = 0
  893. @router.post("/intent_recognition")
  894. async def intent_recognition(
  895. request: Request,
  896. data: IntentRecognitionRequest,
  897. db: Session = Depends(get_db),
  898. ):
  899. """独立意图识别接口;若为 greeting/faq 且 save_to_db=True 则直接存 DB"""
  900. user = request.state.user
  901. if not user:
  902. return {"statusCode": 401, "msg": "未授权"}
  903. try:
  904. intent_result = await qwen_service.intent_recognition(data.message)
  905. intent_type = ""
  906. response_text = ""
  907. if isinstance(intent_result, dict):
  908. intent_type = (
  909. intent_result.get("intent_type") or intent_result.get(
  910. "intent") or ""
  911. ).lower()
  912. response_text = intent_result.get("response", "")
  913. if data.save_to_db and intent_type in ("greeting", "问候", "faq", "常见问题"):
  914. if data.ai_conversation_id == 0:
  915. conversation = AIConversation(
  916. user_id=user.user_id,
  917. content=data.message[:100],
  918. business_type=0,
  919. created_at=int(time.time()),
  920. updated_at=int(time.time()),
  921. is_deleted=0,
  922. )
  923. db.add(conversation)
  924. db.commit()
  925. db.refresh(conversation)
  926. conv_id = conversation.id
  927. else:
  928. conv_id = data.ai_conversation_id
  929. user_msg = AIMessage(
  930. ai_conversation_id=conv_id,
  931. user_id=user.user_id,
  932. type="user",
  933. content=data.message,
  934. created_at=int(time.time()),
  935. updated_at=int(time.time()),
  936. is_deleted=0,
  937. )
  938. db.add(user_msg)
  939. db.commit()
  940. ai_msg = AIMessage(
  941. ai_conversation_id=conv_id,
  942. user_id=user.user_id,
  943. type="ai",
  944. content=response_text,
  945. prev_user_id=user_msg.id,
  946. created_at=int(time.time()),
  947. updated_at=int(time.time()),
  948. is_deleted=0,
  949. )
  950. db.add(ai_msg)
  951. db.commit()
  952. db.refresh(ai_msg)
  953. return {
  954. "statusCode": 200,
  955. "msg": "success",
  956. "data": {
  957. "intent_type": intent_type,
  958. "response": response_text,
  959. "ai_conversation_id": conv_id,
  960. "ai_message_id": ai_msg.id,
  961. "saved_to_db": True,
  962. },
  963. }
  964. return {
  965. "statusCode": 200,
  966. "msg": "success",
  967. "data": {
  968. "intent_type": intent_type,
  969. "response": response_text,
  970. "saved_to_db": False,
  971. },
  972. }
  973. except Exception as e:
  974. logger.error(f"[intent_recognition] 处理异常: {e}")
  975. return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}
  976. # ─────────────────────────────────────────────────────────────────────────
  977. # 获取用户推荐问题(模糊查询 QA / RecommendQuestion 表)
  978. # ─────────────────────────────────────────────────────────────────────────
  979. @router.get("/get_user_recommend_question")
  980. async def get_user_recommend_question(
  981. keyword: str = "",
  982. limit: int = 10,
  983. db: Session = Depends(get_db),
  984. ):
  985. """获取推荐问题(支持模糊查询)"""
  986. try:
  987. query = db.query(RecommendQuestion).filter(
  988. RecommendQuestion.is_deleted == 0)
  989. if keyword:
  990. query = query.filter(
  991. RecommendQuestion.question.like(f"%{keyword}%"))
  992. questions = query.order_by(
  993. RecommendQuestion.id.desc()).limit(limit).all()
  994. return {
  995. "statusCode": 200,
  996. "msg": "success",
  997. "data": [
  998. {"id": q.id, "question": q.question, "created_at": q.created_at}
  999. for q in questions
  1000. ],
  1001. }
  1002. except Exception as e:
  1003. logger.error(f"[get_user_recommend_question] 处理异常: {e}")
  1004. return {"statusCode": 500, "msg": f"查询失败: {str(e)}"}
  1005. # ─────────────────────────────────────────────────────────────────────────
  1006. # PPT 大纲 / 文档编辑保存
  1007. # ─────────────────────────────────────────────────────────────────────────
  1008. class SavePPTOutlineRequest(BaseModel):
  1009. ai_message_id: int
  1010. content: str
  1011. @router.post("/save_ppt_outline")
  1012. async def save_ppt_outline(
  1013. request: Request,
  1014. data: SavePPTOutlineRequest,
  1015. db: Session = Depends(get_db),
  1016. ):
  1017. """更新 AIMessage.content 保存 PPT 大纲内容"""
  1018. user = request.state.user
  1019. if not user:
  1020. return {"statusCode": 401, "msg": "未授权"}
  1021. try:
  1022. db.query(AIMessage).filter(AIMessage.id == data.ai_message_id).update(
  1023. {"content": data.content, "updated_at": int(time.time())}
  1024. )
  1025. db.commit()
  1026. return {"statusCode": 200, "msg": "保存成功"}
  1027. except Exception as e:
  1028. logger.error(f"[save_ppt_outline] 处理异常: {e}")
  1029. return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}
  1030. class SaveEditDocumentRequest(BaseModel):
  1031. ai_message_id: int
  1032. content: str
  1033. @router.post("/save_edit_document")
  1034. async def save_edit_document(
  1035. request: Request,
  1036. data: SaveEditDocumentRequest,
  1037. db: Session = Depends(get_db),
  1038. ):
  1039. """更新 ai 类型 AIMessage.content(AI写作编辑保存)"""
  1040. user = request.state.user
  1041. if not user:
  1042. return {"statusCode": 401, "msg": "未授权"}
  1043. try:
  1044. db.query(AIMessage).filter(
  1045. AIMessage.id == data.ai_message_id,
  1046. AIMessage.type == "ai",
  1047. ).update({"content": data.content, "updated_at": int(time.time())})
  1048. db.commit()
  1049. return {"statusCode": 200, "msg": "保存成功"}
  1050. except Exception as e:
  1051. logger.error(f"[save_edit_document] 处理异常: {e}")
  1052. return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}