chat.py 81 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208
  1. from fastapi import APIRouter, Depends, Request
  2. from fastapi.responses import StreamingResponse, JSONResponse
  3. from sqlalchemy.orm import Session
  4. from pydantic import BaseModel
  5. from typing import Optional
  6. from database import get_db, SessionLocal
  7. from models.chat import AIConversation, AIMessage
  8. from models.total import RecommendQuestion
  9. from utils.config import settings
  10. from utils.logger import logger
  11. from services.qwen_service import qwen_service
  12. from services.deepseek_service import deepseek_service
  13. from utils.prompt_loader import load_prompt
  14. from utils.thinking_summary import split_thinking_and_answer, summarize_thinking_content
  15. import time
  16. import json
  17. import httpx
  18. import re
# Module-level FastAPI router. Route handlers are presumably registered on it
# further down this file (not visible here) — confirm.
router = APIRouter()
  20. def _build_conversation_preview(content: str, limit: int = 50) -> str:
  21. content = (content or "").strip()
  22. if len(content) <= limit:
  23. return content
  24. return content[:limit] + "..."
  25. def _to_frontend_timestamp(timestamp: Optional[int]) -> Optional[int]:
  26. if not timestamp:
  27. return None
  28. return timestamp if timestamp >= 10**12 else timestamp * 1000
  29. def _build_conversation_title(conversation: AIConversation) -> str:
  30. if conversation.business_type == 3 and (conversation.exam_name or "").strip():
  31. return conversation.exam_name.strip()
  32. return _build_conversation_preview(conversation.content or "", limit=30)
  33. def _extract_json_object_from_index(source: str, start_idx: int) -> str:
  34. if start_idx < 0 or start_idx >= len(source) or source[start_idx] != "{":
  35. return ""
  36. depth = 0
  37. in_string = False
  38. escaped = False
  39. for idx in range(start_idx, len(source)):
  40. ch = source[idx]
  41. if escaped:
  42. escaped = False
  43. continue
  44. if in_string:
  45. if ch == "\\":
  46. escaped = True
  47. elif ch == '"':
  48. in_string = False
  49. continue
  50. if ch == '"':
  51. in_string = True
  52. continue
  53. if ch == "{":
  54. depth += 1
  55. elif ch == "}":
  56. depth -= 1
  57. if depth == 0:
  58. return source[start_idx: idx + 1]
  59. return ""
  60. def _extract_balanced_json_objects(text: str) -> list[str]:
  61. source = (text or "").strip()
  62. if not source:
  63. return []
  64. objects = []
  65. seen = set()
  66. for idx, ch in enumerate(source):
  67. if ch != "{":
  68. continue
  69. candidate = _extract_json_object_from_index(source, idx)
  70. if candidate and candidate not in seen:
  71. objects.append(candidate)
  72. seen.add(candidate)
  73. return objects
  74. def _extract_trailing_json_candidates(text: str) -> list[str]:
  75. source = (text or "").strip()
  76. if not source:
  77. return []
  78. candidates = []
  79. seen = set()
  80. line_start_indexes = [
  81. match.start()
  82. for match in re.finditer(r"(?m)^[ \t]*\{", source)
  83. ]
  84. for start_idx in reversed(line_start_indexes):
  85. candidate = source[start_idx:].strip()
  86. if candidate and candidate not in seen:
  87. candidates.append(candidate)
  88. seen.add(candidate)
  89. return candidates
  90. def _extract_explicit_answer_segment(text: str) -> str:
  91. source = (text or "").strip()
  92. if not source:
  93. return ""
  94. markers = (
  95. "final answer:",
  96. "final output:",
  97. "answer:",
  98. "output:",
  99. "json:",
  100. )
  101. lowered = source.lower()
  102. for marker in markers:
  103. idx = lowered.rfind(marker)
  104. if idx >= 0:
  105. candidate = source[idx + len(marker):].strip()
  106. if candidate:
  107. return candidate
  108. return ""
  109. def _extract_brace_sliced_candidates(text: str) -> list[str]:
  110. source = (text or "").strip()
  111. if not source:
  112. return []
  113. candidates = []
  114. seen = set()
  115. first_brace = source.find("{")
  116. last_brace = source.rfind("}")
  117. if first_brace >= 0 and last_brace > first_brace:
  118. candidate = source[first_brace:last_brace + 1].strip()
  119. if candidate and candidate not in seen:
  120. candidates.append(candidate)
  121. seen.add(candidate)
  122. return candidates
  123. def _looks_like_exam_payload(payload: object) -> bool:
  124. if not isinstance(payload, dict):
  125. return False
  126. questions = payload.get("questions")
  127. return any(
  128. key in payload
  129. for key in (
  130. "singleChoice",
  131. "single_choice",
  132. "单选题",
  133. "judge",
  134. "判断题",
  135. "multiple",
  136. "multiple_choice",
  137. "multipleChoice",
  138. "多选题",
  139. "short",
  140. "short_answer",
  141. "shortAnswer",
  142. "简答题",
  143. )
  144. ) or (
  145. isinstance(questions, dict)
  146. and any(
  147. key in questions
  148. for key in (
  149. "singleChoice",
  150. "single_choice",
  151. "单选题",
  152. "judge",
  153. "判断题",
  154. "multiple",
  155. "multiple_choice",
  156. "multipleChoice",
  157. "多选题",
  158. "short",
  159. "short_answer",
  160. "shortAnswer",
  161. "简答题",
  162. )
  163. )
  164. )
  165. def _score_exam_payload_candidate(payload: object) -> int:
  166. if not isinstance(payload, dict):
  167. return 0
  168. score = 0
  169. questions = payload.get("questions") if isinstance(
  170. payload.get("questions"), dict) else {}
  171. strong_keys = (
  172. "singleChoice",
  173. "single_choice",
  174. "单选题",
  175. "judge",
  176. "判断题",
  177. "multiple",
  178. "multiple_choice",
  179. "multipleChoice",
  180. "多选题",
  181. "short",
  182. "short_answer",
  183. "shortAnswer",
  184. "简答题",
  185. )
  186. weak_keys = (
  187. "title",
  188. "exam_name",
  189. "examTitle",
  190. "试卷标题",
  191. "总分",
  192. "totalScore",
  193. "totalQuestions",
  194. )
  195. score += sum(10 for key in strong_keys if key in payload)
  196. score += sum(8 for key in strong_keys if key in questions)
  197. score += sum(2 for key in weak_keys if key in payload)
  198. section_candidates = []
  199. for _, value in payload.items():
  200. if isinstance(value, dict):
  201. section_candidates.append(value)
  202. section_candidates.extend(
  203. value for value in questions.values() if isinstance(value, dict))
  204. for section in section_candidates:
  205. if "questions" in section and isinstance(section.get("questions"), list):
  206. score += 6
  207. question_list = section.get("questions") or []
  208. if question_list and isinstance(question_list[0], dict):
  209. first_question = question_list[0]
  210. if any(k in first_question for k in ("text", "question_text", "question", "title", "content", "题干", "题目")):
  211. score += 4
  212. if "options" in first_question:
  213. score += 3
  214. if any(k in first_question for k in ("answer", "answers", "correct_answer", "correct_answers", "答案", "正确答案")):
  215. score += 3
  216. if any(k in first_question for k in ("analysis", "explanation", "解析")):
  217. score += 2
  218. if any(k in section for k in ("count", "question_count", "数量")):
  219. score += 2
  220. if any(k in section for k in ("scorePerQuestion", "score_per_question", "每题分值")):
  221. score += 1
  222. return score
  223. def _escape_inner_quotes_in_json(text: str) -> str:
  224. chars = []
  225. in_string = False
  226. escaped = False
  227. for idx, ch in enumerate(text):
  228. if not in_string:
  229. chars.append(ch)
  230. if ch == '"':
  231. in_string = True
  232. escaped = False
  233. continue
  234. if escaped:
  235. chars.append(ch)
  236. escaped = False
  237. continue
  238. if ch == "\\":
  239. chars.append(ch)
  240. escaped = True
  241. continue
  242. if ch == '"':
  243. next_non_space = ""
  244. for next_idx in range(idx + 1, len(text)):
  245. if not text[next_idx].isspace():
  246. next_non_space = text[next_idx]
  247. break
  248. if next_non_space in {",", "}", "]", ":"}:
  249. chars.append(ch)
  250. in_string = False
  251. else:
  252. chars.append('\\"')
  253. continue
  254. chars.append(ch)
  255. return "".join(chars)
  256. def _try_parse_exam_json(candidate: str) -> Optional[dict]:
  257. text = (candidate or "").strip()
  258. if not text:
  259. return None
  260. text = (
  261. text.replace("\ufeff", "")
  262. .replace("```json", "")
  263. .replace("```JSON", "")
  264. .replace("```", "")
  265. .replace("“", '"')
  266. .replace("”", '"')
  267. ).strip()
  268. try:
  269. parsed = json.loads(text)
  270. except Exception:
  271. repaired_text = _escape_inner_quotes_in_json(text)
  272. repaired_text = re.sub(r",\s*([}\]])", r"\1", repaired_text)
  273. try:
  274. parsed = json.loads(repaired_text)
  275. except Exception:
  276. return None
  277. return parsed if _looks_like_exam_payload(parsed) else None
  278. def _sanitize_exam_response(raw_response: str) -> str:
  279. """考试工坊只向前端/数据库透传可 JSON.parse 的试卷 JSON。"""
  280. raw_text = (raw_response or "").strip()
  281. if not raw_text:
  282. return ""
  283. _, answer = split_thinking_and_answer(raw_text)
  284. explicit_answer = _extract_explicit_answer_segment(raw_text)
  285. for candidate in (answer, explicit_answer, raw_text):
  286. parsed = _try_parse_exam_json(candidate)
  287. if parsed:
  288. return json.dumps(parsed, ensure_ascii=False)
  289. parsed_candidates = []
  290. for candidate in _extract_balanced_json_objects(raw_text):
  291. parsed = _try_parse_exam_json(candidate)
  292. if parsed:
  293. parsed_candidates.append((parsed, candidate))
  294. for candidate in _extract_trailing_json_candidates(raw_text):
  295. parsed = _try_parse_exam_json(candidate)
  296. if parsed:
  297. parsed_candidates.append((parsed, candidate))
  298. for candidate in _extract_brace_sliced_candidates(raw_text):
  299. parsed = _try_parse_exam_json(candidate)
  300. if parsed:
  301. parsed_candidates.append((parsed, candidate))
  302. if parsed_candidates:
  303. parsed_candidates.sort(
  304. key=lambda item: (
  305. _score_exam_payload_candidate(item[0]),
  306. len(json.dumps(item[0], ensure_ascii=False)),
  307. ),
  308. reverse=True,
  309. )
  310. best_payload, best_raw_candidate = parsed_candidates[0]
  311. if _score_exam_payload_candidate(best_payload) > 0:
  312. return json.dumps(best_payload, ensure_ascii=False)
  313. logger.warning(
  314. "[exam] 已提取到JSON对象但试卷特征较弱,选择最大候选兜底: score=%s snippet=%s",
  315. _score_exam_payload_candidate(best_payload),
  316. (best_raw_candidate or "")[:200],
  317. )
  318. return json.dumps(best_payload, ensure_ascii=False)
  319. logger.warning("[exam] 未能从模型响应中提取试卷 JSON,保留原始响应供前端兜底解析")
  320. return raw_text
  321. def _normalize_related_question(question: str) -> str:
  322. if not isinstance(question, str):
  323. return ""
  324. text = question.strip().strip('"').strip("'")
  325. text = re.sub(r"^[0-9]+[\.\)\]、]\s*", "", text)
  326. text = re.sub(r"^[-*]\s*", "", text)
  327. return text.strip()
  328. def _is_placeholder_related_question(question: str) -> bool:
  329. normalized = _normalize_related_question(question).lower()
  330. if not normalized:
  331. return True
  332. placeholder_patterns = (
  333. r"^q\s*\d+$",
  334. r"^question\s*\d+$",
  335. r"^questions?\s*\d+$",
  336. r"^问题\s*\d+$",
  337. r"^相关问题\s*\d+$",
  338. r"^推荐问题\s*\d+$",
  339. r"^更多相关问题$",
  340. r"^更多问题$",
  341. )
  342. return any(re.fullmatch(pattern, normalized) for pattern in placeholder_patterns)
  343. def _contains_chinese(text: str) -> bool:
  344. return any("\u4e00" <= char <= "\u9fff" for char in text or "")
  345. def _is_invalid_related_question(question: str) -> bool:
  346. normalized = _normalize_related_question(question)
  347. if (
  348. not normalized
  349. or len(normalized) < 4
  350. or _is_placeholder_related_question(normalized)
  351. or not _contains_chinese(normalized)
  352. ):
  353. return True
  354. lowered = normalized.lower()
  355. blocked_keywords = (
  356. "thinking process",
  357. "analyze the request",
  358. "role:",
  359. "**role",
  360. "professional question recommendation",
  361. "infrastructure construction technology",
  362. "output format",
  363. "json",
  364. "prompt",
  365. "system",
  366. "assistant",
  367. "角色定义",
  368. "任务目标",
  369. "输入内容",
  370. "生成要求",
  371. "输出格式",
  372. "开始生成",
  373. )
  374. return any(keyword in lowered for keyword in blocked_keywords)
  375. def _extract_related_question_topic(content: str) -> str:
  376. if not content:
  377. return "当前话题"
  378. text = re.sub(r"<[^>]+>", " ", str(content))
  379. text = re.sub(r"\s+", " ", text).strip()
  380. text = re.sub(
  381. r"^(好的[!!,, ]*|我理解您提出的问题[,, ]*|这个问题[,, ]*|总的来说[::,, ]*)+",
  382. "",
  383. text,
  384. )
  385. pattern = re.search(
  386. r"(?:主要围绕|围绕|关于|针对|聚焦)([^。!?\n,,;;]{4,32})",
  387. text,
  388. )
  389. if pattern:
  390. topic = pattern.group(1).strip("“”\"' ::,,")
  391. if topic:
  392. return topic
  393. sentence = re.split(r"[。!?\n]", text, maxsplit=1)[0].strip("“”\"' ::,,")
  394. if sentence:
  395. return sentence[:24]
  396. return "当前话题"
  397. def _build_related_question_fallbacks(content: str) -> list[str]:
  398. topic = _extract_related_question_topic(content)
  399. return [
  400. f"{topic}在现场实施时需要重点关注哪些风险点?",
  401. f"{topic}相关的方案编制、审批和验收要求有哪些?",
  402. f"针对{topic},日常检查和监测应抓住哪些关键指标?",
  403. ]
  404. def _finalize_related_questions(questions: list, content: str, limit: int = 3) -> list[str]:
  405. cleaned_questions = []
  406. seen = set()
  407. for question in questions or []:
  408. normalized = _normalize_related_question(question)
  409. lowered = normalized.lower()
  410. if (
  411. _is_invalid_related_question(normalized)
  412. or lowered in seen
  413. ):
  414. continue
  415. cleaned_questions.append(normalized)
  416. seen.add(lowered)
  417. if len(cleaned_questions) == limit:
  418. return cleaned_questions
  419. for fallback in _build_related_question_fallbacks(content):
  420. lowered = fallback.lower()
  421. if lowered in seen:
  422. continue
  423. cleaned_questions.append(fallback)
  424. seen.add(lowered)
  425. if len(cleaned_questions) == limit:
  426. break
  427. return cleaned_questions[:limit]
  428. def _refresh_conversation_snapshot(db: Session, conversation_id: int, user_id: int) -> None:
  429. latest_message = (
  430. db.query(AIMessage)
  431. .filter(
  432. AIMessage.ai_conversation_id == conversation_id,
  433. AIMessage.user_id == user_id,
  434. AIMessage.is_deleted == 0,
  435. )
  436. .order_by(AIMessage.id.desc())
  437. .first()
  438. )
  439. if not latest_message:
  440. db.query(AIConversation).filter(
  441. AIConversation.id == conversation_id,
  442. AIConversation.user_id == user_id,
  443. ).update({"is_deleted": 1, "updated_at": int(time.time())})
  444. return
  445. latest_user_message = (
  446. db.query(AIMessage)
  447. .filter(
  448. AIMessage.ai_conversation_id == conversation_id,
  449. AIMessage.user_id == user_id,
  450. AIMessage.type == "user",
  451. AIMessage.is_deleted == 0,
  452. )
  453. .order_by(AIMessage.id.desc())
  454. .first()
  455. )
  456. preview_source = (
  457. latest_user_message.content
  458. if latest_user_message and latest_user_message.content
  459. else latest_message.content
  460. )
  461. preview_content = _build_conversation_preview(
  462. preview_source or "", limit=100)
  463. db.query(AIConversation).filter(
  464. AIConversation.id == conversation_id,
  465. AIConversation.user_id == user_id,
  466. ).update(
  467. {
  468. "content": preview_content or " ",
  469. "updated_at": int(time.time()),
  470. }
  471. )
  472. # ─────────────────────────────────────────────────────────────────────────
  473. # 辅助函数
  474. # ─────────────────────────────────────────────────────────────────────────
  475. async def _rag_search(message: str, top_k: int = 5) -> str:
  476. """调用 search API 做 RAG 检索,返回上下文文本"""
  477. try:
  478. search_cfg = getattr(settings, 'search', None)
  479. if not search_cfg or not hasattr(search_cfg, 'api_url'):
  480. return ""
  481. search_url = search_cfg.api_url
  482. if not search_url:
  483. return ""
  484. async with httpx.AsyncClient(timeout=10.0) as client:
  485. resp = await client.post(
  486. search_url,
  487. json={"query": message, "n_results": top_k},
  488. )
  489. if resp.status_code == 200:
  490. data = resp.json()
  491. docs = data.get("results") or data.get("documents") or []
  492. return "\n\n".join(
  493. d.get("content") or d.get("text") or str(d)
  494. for d in docs[:top_k]
  495. if d.get("content") or d.get("text")
  496. )
  497. except Exception as e:
  498. logger.warning(f"[RAG] 检索失败(可忽略): {e}")
  499. return ""
# System prompt for the planning call in _infer_safety_training_plan.
# It asks the model to restate a free-form request as ONE JSON object that
# describes a safety-training PPT-outline task.
# The JSON field names listed in the prompt (topic, template, content_focus,
# audience, time, location, goal, notes, normalized_request) are the keys that
# _normalize_safety_training_plan reads back. Keep the prompt text and those
# keys in sync — the text is runtime data, not documentation.
SAFETY_TRAINING_PLAN_SYSTEM_PROMPT = """
你是安全培训需求整理助手。请把用户的自然语言输入整理成安全培训PPT大纲生成任务。
规则:
1. 只输出一个 JSON 对象,不要输出 Markdown、解释或额外文字。
2. 即使用户说“通知”“材料”“文档”,也必须理解为安全培训模块中的 PPT 大纲需求,不要切换到其他文档生成任务。
3. 如果字段缺失,请根据安全培训场景合理补全,但不要编造具体制度编号、人员姓名或不存在的事实。
4. template 字段用于选择大纲模板,默认填“标准安全培训PPT大纲”。
5. content_focus 至少给出 3 个要点。
JSON 字段:
{
"topic": "培训主题",
"template": "模板名称",
"content_focus": ["内容要点1", "内容要点2", "内容要点3"],
"audience": "参训对象",
"time": "培训时间",
"location": "培训地点",
"goal": "培训目标",
"notes": "其他要求",
"normalized_request": "归一化后的安全培训PPT大纲生成需求"
}
"""
  521. def _extract_tag_value(message: str, tag: str) -> str:
  522. match = re.search(fr"<{tag}>(.*?)</{tag}>", message or "", re.DOTALL)
  523. return match.group(1).strip() if match else ""
  524. def _strip_document_tags(message: str) -> str:
  525. text = message or ""
  526. for tag in ("word", "filename", "filesize"):
  527. text = re.sub(fr"<{tag}>.*?</{tag}>", " ", text, flags=re.DOTALL)
  528. return re.sub(r"\s+", " ", text).strip()
  529. def _extract_safety_training_request_payload(message: str) -> dict:
  530. return {
  531. "document_content": _extract_tag_value(message, "word"),
  532. "filename": _extract_tag_value(message, "filename"),
  533. "filesize": _extract_tag_value(message, "filesize"),
  534. "request": _strip_document_tags(message),
  535. }
  536. def _clean_safety_training_topic(message: str) -> str:
  537. request_text = _extract_safety_training_request_payload(message)["request"]
  538. first_clause = re.split(r"[,。;;,\n]", request_text, maxsplit=1)[0].strip()
  539. topic = first_clause or request_text or "安全培训"
  540. for token in ("请", "帮我", "帮忙", "生成", "制作", "输出", "一份", "一个", "一下", "PPT大纲", "ppt大纲", "大纲", "通知", "文档", "材料"):
  541. topic = topic.replace(token, "")
  542. topic = re.sub(r"\s+", "", topic).strip(" ::,,。;;")
  543. if not topic:
  544. topic = "安全培训"
  545. if "培训" not in topic:
  546. topic = f"{topic}安全培训"
  547. return topic
  548. def _parse_json_object(text: str) -> dict:
  549. if not text:
  550. return {}
  551. cleaned = re.sub(r"```(?:json)?\s*", "", str(text)
  552. ).replace("```", "").strip()
  553. match = re.search(r"\{.*\}", cleaned, re.DOTALL)
  554. if not match:
  555. return {}
  556. try:
  557. parsed = json.loads(match.group(0))
  558. return parsed if isinstance(parsed, dict) else {}
  559. except json.JSONDecodeError:
  560. return {}
  561. def _build_fallback_safety_training_plan(message: str) -> dict:
  562. topic = _clean_safety_training_topic(message)
  563. payload = _extract_safety_training_request_payload(message)
  564. return {
  565. "topic": topic,
  566. "template": "标准安全培训PPT大纲",
  567. "content_focus": ["安全生产责任", "现场风险识别", "安全意识提升", "培训纪律与行为规范"],
  568. "audience": "参训员工",
  569. "time": "",
  570. "location": "",
  571. "goal": "提升参训人员安全意识和施工现场风险防控能力",
  572. "notes": payload["request"],
  573. "normalized_request": f"围绕{topic}生成安全培训PPT大纲",
  574. }
  575. def _normalize_safety_training_plan(message: str, raw_plan: dict) -> dict:
  576. plan = _build_fallback_safety_training_plan(message)
  577. if not isinstance(raw_plan, dict):
  578. return plan
  579. for key in ("topic", "template", "audience", "time", "location", "goal", "notes", "normalized_request"):
  580. value = raw_plan.get(key)
  581. if isinstance(value, str) and value.strip():
  582. plan[key] = value.strip()
  583. focus = raw_plan.get("content_focus")
  584. if isinstance(focus, list):
  585. normalized_focus = [str(item).strip()
  586. for item in focus if str(item).strip()]
  587. if normalized_focus:
  588. plan["content_focus"] = normalized_focus
  589. elif isinstance(focus, str) and focus.strip():
  590. plan["content_focus"] = [item.strip()
  591. for item in re.split(r"[、,,;\n]", focus) if item.strip()]
  592. if "培训" not in plan["topic"]:
  593. plan["topic"] = f"{plan['topic']}安全培训"
  594. if "PPT大纲" not in plan["template"]:
  595. plan["template"] = f"{plan['template']}PPT大纲"
  596. return plan
  597. def _build_safety_training_generation_message(message: str, plan: dict) -> str:
  598. payload = _extract_safety_training_request_payload(message)
  599. focus_text = "、".join(plan.get("content_focus") or [])
  600. lines = [
  601. "输出类型:安全培训PPT大纲",
  602. "请基于以下结构化需求生成安全培训PPT大纲,不要生成通知正文,不要切换到其他文档生成任务。",
  603. f"主题:{plan.get('topic') or '安全培训'}",
  604. f"模板:{plan.get('template') or '标准安全培训PPT大纲'}",
  605. f"内容要点:{focus_text or '安全生产责任、风险识别、应急处置、安全意识提升'}",
  606. f"参训对象:{plan.get('audience') or '参训员工'}",
  607. f"培训时间:{plan.get('time') or '未指定'}",
  608. f"培训地点:{plan.get('location') or '未指定'}",
  609. f"培训目标:{plan.get('goal') or '提升参训人员安全意识和风险防控能力'}",
  610. f"其他要求:{plan.get('notes') or '无'}",
  611. f"归一化需求:{plan.get('normalized_request') or ''}",
  612. f"原始需求:{payload['request'] or message}",
  613. ]
  614. if payload["filename"] or payload["document_content"]:
  615. lines.extend([
  616. f"上传文档名称:{payload['filename'] or '未命名文档'}",
  617. f"上传文档大小:{payload['filesize'] or '未知'}",
  618. "上传文档内容:",
  619. payload["document_content"] or "无",
  620. ])
  621. return "\n".join(lines)
  622. async def _infer_safety_training_plan(message: str) -> dict:
  623. payload = _extract_safety_training_request_payload(message)
  624. planning_input = payload["request"] or message
  625. if payload["document_content"]:
  626. planning_input = (
  627. f"{planning_input}\n\n"
  628. f"上传文档名称:{payload['filename'] or '未命名文档'}\n"
  629. f"上传文档内容摘要:{payload['document_content'][:3000]}"
  630. )
  631. try:
  632. response = await qwen_service.chat([
  633. {"role": "system", "content": SAFETY_TRAINING_PLAN_SYSTEM_PROMPT},
  634. {"role": "user", "content": planning_input},
  635. ])
  636. return _normalize_safety_training_plan(message, _parse_json_object(response))
  637. except Exception as e:
  638. logger.warning(
  639. f"[safety_training] 需求整理失败,使用兜底结构: {type(e).__name__}: {e}")
  640. return _build_fallback_safety_training_plan(message)
  641. def _clean_ai_writing_response(content: str) -> str:
  642. text = str(content or "").strip()
  643. if not text:
  644. return ""
  645. text = re.sub(r"```(?:html)?\s*", "", text,
  646. flags=re.IGNORECASE).replace("```", "").strip()
  647. body_match = re.search(
  648. r"<body[^>]*>(.*?)</body>", text, re.IGNORECASE | re.DOTALL)
  649. if body_match:
  650. text = body_match.group(1).strip()
  651. first_content_tag = re.search(
  652. r"<(?:article|section|main|div|h[1-6]|p|table|ul|ol)\b",
  653. text,
  654. re.IGNORECASE,
  655. )
  656. if first_content_tag and text[:first_content_tag.start()].strip():
  657. text = text[first_content_tag.start():]
  658. cleanup_patterns = (
  659. r"<!DOCTYPE[^>]*>",
  660. r"<html[^>]*>",
  661. r"</html>",
  662. r"<head[^>]*>.*?</head>",
  663. r"<body[^>]*>",
  664. r"</body>",
  665. r"<style[^>]*>.*?</style>",
  666. r"<script[^>]*>.*?</script>",
  667. r"<meta[^>]*>",
  668. r"<title[^>]*>.*?</title>",
  669. )
  670. for pattern in cleanup_patterns:
  671. text = re.sub(pattern, "", text, flags=re.IGNORECASE | re.DOTALL)
  672. return text.strip()
  673. async def _generate_ai_writing_response(message: str) -> str:
  674. rag_context = await _rag_search(message, top_k=10)
  675. system_content = load_prompt(
  676. "document_writing",
  677. userMessage=message,
  678. contextJSON=rag_context if rag_context else "暂无相关知识库内容",
  679. )
  680. messages = [
  681. {"role": "system", "content": system_content},
  682. {
  683. "role": "user",
  684. "content": (
  685. "请根据上面的写作规范和我的原始需求,直接生成可放入富文本编辑器的公文正文 HTML 片段。"
  686. "不要输出道歉、解释、DOCTYPE、html、head、body、style 或 script 标签。\n\n"
  687. f"原始需求:\n{message}"
  688. ),
  689. },
  690. ]
  691. raw_response = await deepseek_service.chat(messages)
  692. raw_thinking, raw_answer = split_thinking_and_answer(raw_response)
  693. answer_text = _clean_ai_writing_response(raw_answer or raw_response)
  694. # AI写作输出纯HTML文档内容,不附加思考过程(避免混入纯文本破坏HTML结构)
  695. return answer_text
  696. async def _generate_ppt_outline_response(message: str) -> str:
  697. training_plan = await _infer_safety_training_plan(message)
  698. generation_message = _build_safety_training_generation_message(
  699. message, training_plan)
  700. rag_context = await _rag_search(generation_message, top_k=10)
  701. system_content = load_prompt(
  702. "ppt_outline",
  703. userMessage=generation_message,
  704. contextJSON=rag_context if rag_context else "暂无相关知识库内容",
  705. )
  706. messages = [
  707. {"role": "system", "content": system_content},
  708. {"role": "user", "content": "请直接输出安全培训PPT大纲正文,从标题开始,不要解释提示词或安全规则。"},
  709. ]
  710. raw_response = await qwen_service.chat(messages)
  711. raw_thinking, raw_answer = split_thinking_and_answer(raw_response)
  712. answer_text = raw_answer or raw_response
  713. if raw_thinking:
  714. thinking_summary = await summarize_thinking_content(
  715. user_question=message,
  716. raw_thinking=raw_thinking,
  717. final_answer=answer_text,
  718. chat_service=qwen_service,
  719. context="ppt_outline",
  720. )
  721. return (
  722. f"思考过程:\n{thinking_summary}\n\n回答:\n{answer_text}"
  723. if thinking_summary
  724. else answer_text
  725. )
  726. return answer_text
  727. def _persist_message_pair(db: Session, conv_id: int, user, user_content: str, ai_content: str):
  728. now_ts = int(time.time())
  729. user_message = AIMessage(
  730. ai_conversation_id=conv_id,
  731. user_id=user.user_id,
  732. type="user",
  733. content=user_content,
  734. created_at=now_ts,
  735. updated_at=now_ts,
  736. is_deleted=0,
  737. )
  738. db.add(user_message)
  739. db.commit()
  740. db.refresh(user_message)
  741. ai_message = AIMessage(
  742. ai_conversation_id=conv_id,
  743. user_id=user.user_id,
  744. type="ai",
  745. content=ai_content,
  746. prev_user_id=user_message.id,
  747. created_at=now_ts,
  748. updated_at=now_ts,
  749. is_deleted=0,
  750. )
  751. db.add(ai_message)
  752. db.commit()
  753. db.refresh(ai_message)
  754. return user_message, ai_message
  755. def _build_history_messages(conv_id: int, limit: int = 10) -> list:
  756. """从数据库读取最近对话历史,构建 messages 列表"""
  757. db = SessionLocal()
  758. try:
  759. msgs = (
  760. db.query(AIMessage)
  761. .filter(AIMessage.ai_conversation_id == conv_id, AIMessage.is_deleted == 0)
  762. .order_by(AIMessage.id.desc())
  763. .limit(limit)
  764. .all()
  765. )
  766. msgs.reverse()
  767. result = []
  768. for m in msgs:
  769. role = "user" if m.type == "user" else "assistant"
  770. if m.content:
  771. result.append({"role": role, "content": m.content})
  772. return result
  773. finally:
  774. db.close()
  775. # ─────────────────────────────────────────────────────────────────────────
  776. # 非流式接口
  777. # ─────────────────────────────────────────────────────────────────────────
class SendMessageRequest(BaseModel):
    # Request body for POST /send_deepseek_message.
    message: str
    # Either id may be sent; send_deepseek_message uses
    # `conversation_id or ai_conversation_id`. A falsy value starts a new
    # conversation.
    conversation_id: Optional[int] = None
    ai_conversation_id: Optional[int] = None
    business_type: int = 0  # 0=AI Q&A, 1=PPT outline, 2=AI writing, 3=exam workshop
    # Exam title; written to the conversation only when business_type == 3.
    exam_name: str = ""
    # NOTE(review): not read by send_deepseek_message in the visible code.
    ai_message_id: int = 0
  785. @router.post("/send_deepseek_message")
  786. async def send_deepseek_message(
  787. request: Request,
  788. data: SendMessageRequest,
  789. db: Session = Depends(get_db),
  790. ):
  791. """
  792. 发送消息(非流式)
  793. 支持多种业务类型:
  794. - 0: AI问答(意图识别 + RAG)
  795. - 1: PPT大纲生成
  796. - 2: AI写作
  797. - 3: 考试工坊
  798. """
  799. user = request.state.user
  800. if not user:
  801. return {"statusCode": 401, "msg": "未授权"}
  802. try:
  803. message = data.message.strip()
  804. if not message:
  805. return {"statusCode": 400, "msg": "消息不能为空"}
  806. conversation_id = data.conversation_id or data.ai_conversation_id
  807. # 创建或获取对话
  808. if not conversation_id:
  809. conversation = AIConversation(
  810. user_id=user.user_id,
  811. content=message[:100],
  812. business_type=data.business_type,
  813. exam_name=data.exam_name if data.business_type == 3 else "",
  814. created_at=int(time.time()),
  815. updated_at=int(time.time()),
  816. is_deleted=0,
  817. )
  818. db.add(conversation)
  819. db.commit()
  820. db.refresh(conversation)
  821. conv_id = conversation.id
  822. else:
  823. conv_id = conversation_id
  824. db.query(AIConversation).filter(
  825. AIConversation.id == conv_id,
  826. AIConversation.user_id == user.user_id,
  827. AIConversation.is_deleted == 0,
  828. ).update({
  829. "content": message[:100],
  830. "business_type": data.business_type,
  831. "exam_name": data.exam_name if data.business_type == 3 else "",
  832. "updated_at": int(time.time()),
  833. })
  834. db.commit()
  835. response_text = ""
  836. ai_message_id = 0
  837. if data.business_type == 0:
  838. # AI问答:意图识别 + RAG
  839. try:
  840. intent_result = await qwen_service.intent_recognition(message)
  841. intent_type = ""
  842. if isinstance(intent_result, dict):
  843. intent_type = (
  844. intent_result.get("intent_type") or intent_result.get(
  845. "intent") or ""
  846. ).lower()
  847. rag_context = ""
  848. if intent_type in ("query_knowledge_base", "知识库查询", "技术咨询"):
  849. rag_context = await _rag_search(message, top_k=10)
  850. # 使用prompt加载器加载最终回答prompt
  851. system_content = load_prompt(
  852. "final_answer",
  853. userMessage=message,
  854. contextJSON=rag_context if rag_context else "暂无相关知识库内容"
  855. )
  856. messages = [
  857. {"role": "user", "content": system_content},
  858. ]
  859. qwen_response = await qwen_service.chat(messages)
  860. raw_thinking, raw_answer = split_thinking_and_answer(
  861. qwen_response)
  862. answer_source = raw_answer or qwen_response
  863. # 兼容模型直接返回 JSON 的场景
  864. answer_text = answer_source
  865. try:
  866. if isinstance(answer_source, str) and answer_source.strip().startswith("{"):
  867. response_json = json.loads(answer_source)
  868. answer_text = response_json.get(
  869. "natural_language_answer", answer_source
  870. )
  871. except Exception:
  872. answer_text = answer_source
  873. if raw_thinking:
  874. thinking_summary = await summarize_thinking_content(
  875. user_question=message,
  876. raw_thinking=raw_thinking,
  877. final_answer=answer_text,
  878. chat_service=qwen_service,
  879. context="send_message",
  880. )
  881. response_text = (
  882. f"思考过程:\n{thinking_summary}\n\n回答:\n{answer_text}"
  883. if thinking_summary
  884. else answer_text
  885. )
  886. else:
  887. response_text = answer_text
  888. except Exception as e:
  889. error_detail = str(e).strip() if str(
  890. e).strip() else f"未知错误({type(e).__name__})"
  891. logger.error(
  892. f"[send_deepseek_message] AI问答异常: {type(e).__name__}: {error_detail}")
  893. response_text = f"处理失败: {error_detail}"
  894. elif data.business_type == 1:
  895. # PPT大纲生成
  896. try:
  897. response_text = await _generate_ppt_outline_response(message)
  898. _, ai_message = _persist_message_pair(
  899. db=db,
  900. conv_id=conv_id,
  901. user=user,
  902. user_content=message,
  903. ai_content=response_text,
  904. )
  905. ai_message_id = ai_message.id
  906. _refresh_conversation_snapshot(db, conv_id, user.user_id)
  907. db.commit()
  908. return {
  909. "statusCode": 200,
  910. "msg": "success",
  911. "data": {
  912. "conversation_id": conv_id,
  913. "ai_conversation_id": conv_id,
  914. "response": response_text,
  915. "reply": response_text,
  916. "content": response_text,
  917. "message": response_text,
  918. "ai_message_id": ai_message_id,
  919. "user_id": user.user_id,
  920. "business_type": data.business_type,
  921. },
  922. }
  923. except Exception as e:
  924. error_detail = str(e).strip() if str(
  925. e).strip() else f"未知错误({type(e).__name__})"
  926. logger.error(
  927. f"[send_deepseek_message] PPT大纲生成异常: {type(e).__name__}: {error_detail}")
  928. response_text = f"处理失败: {error_detail}"
  929. elif data.business_type == 2:
  930. # AI写作
  931. try:
  932. response_text = await _generate_ai_writing_response(message)
  933. _, ai_message = _persist_message_pair(
  934. db=db,
  935. conv_id=conv_id,
  936. user=user,
  937. user_content=message,
  938. ai_content=response_text,
  939. )
  940. ai_message_id = ai_message.id
  941. _refresh_conversation_snapshot(db, conv_id, user.user_id)
  942. db.commit()
  943. return {
  944. "statusCode": 200,
  945. "msg": "success",
  946. "data": {
  947. "conversation_id": conv_id,
  948. "ai_conversation_id": conv_id,
  949. "response": response_text,
  950. "reply": response_text,
  951. "content": response_text,
  952. "message": response_text,
  953. "ai_message_id": ai_message_id,
  954. "user_id": user.user_id,
  955. "business_type": data.business_type,
  956. },
  957. }
  958. except Exception as e:
  959. error_detail = str(e).strip() if str(
  960. e).strip() else f"未知错误({type(e).__name__})"
  961. logger.error(
  962. f"[send_deepseek_message] AI写作异常: {type(e).__name__}: {error_detail}")
  963. response_text = f"处理失败: {error_detail}"
  964. elif data.business_type == 3:
  965. # 考试工坊:生成题目
  966. try:
  967. system_content = (
  968. "你是一个专业的考试题目生成助手,专注于路桥隧轨施工安全领域。\n"
  969. "请根据用户需求生成专业的考试题目,包括单选题、多选题、判断题、简答题等。\n"
  970. "用户消息中已经包含考试标题、题型要求和出题依据内容,必须以其中的出题依据内容为核心生成题目,不能脱离依据内容自由发挥。\n"
  971. "题干、选项、答案和解析都要与出题依据内容中的知识点、专业术语、操作流程、规范要求或培训主题直接相关。\n"
  972. "输出必须是可直接 JSON.parse 的纯 JSON,不要包含 markdown 代码块、解释文字或额外前后缀。\n"
  973. "JSON 顶层结构必须包含 singleChoice、judge、multiple、short 四个字段。\n"
  974. "singleChoice.questions 和 multiple.questions 中每道题必须包含 text、options、answer、analysis。\n"
  975. "options 必须是数组,元素格式为 {\"key\":\"A\",\"text\":\"具体选项内容\"},禁止输出“选项A”这类占位文本。\n"
  976. "judge.questions 中每道题必须包含 text、answer、analysis。\n"
  977. "short.questions 中每道题必须包含 text、outline,其中 outline 至少包含 keyFactors。\n"
  978. "所有题目内容、选项内容、答案和解析都要结合用户给出的工程类型、题型数量、分值和课件内容具体生成。"
  979. )
  980. messages = [
  981. {"role": "system", "content": system_content},
  982. {"role": "user", "content": message},
  983. ]
  984. raw_response_text = await qwen_service.chat(messages)
  985. response_text = _sanitize_exam_response(raw_response_text)
  986. now_ts = int(time.time())
  987. user_message = AIMessage(
  988. ai_conversation_id=conv_id,
  989. user_id=user.user_id,
  990. type="user",
  991. content=message,
  992. created_at=now_ts,
  993. updated_at=now_ts,
  994. is_deleted=0,
  995. )
  996. db.add(user_message)
  997. db.commit()
  998. db.refresh(user_message)
  999. ai_message = AIMessage(
  1000. ai_conversation_id=conv_id,
  1001. user_id=user.user_id,
  1002. type="ai",
  1003. content=response_text,
  1004. prev_user_id=user_message.id,
  1005. created_at=now_ts,
  1006. updated_at=now_ts,
  1007. is_deleted=0,
  1008. )
  1009. db.add(ai_message)
  1010. db.commit()
  1011. _refresh_conversation_snapshot(db, conv_id, user.user_id)
  1012. db.commit()
  1013. generated_title = data.exam_name
  1014. if not generated_title:
  1015. try:
  1016. import json
  1017. exam_data = json.loads(response_text)
  1018. generated_title = exam_data.get("title") or exam_data.get("exam_name") or exam_data.get(
  1019. "examTitle") or exam_data.get("试卷标题") or exam_data.get("标题") or ""
  1020. except Exception:
  1021. pass
  1022. if generated_title:
  1023. db.query(AIConversation).filter(AIConversation.id == conv_id).update(
  1024. {"exam_name": generated_title,
  1025. "updated_at": int(time.time())}
  1026. )
  1027. db.commit()
  1028. except Exception as e:
  1029. error_detail = str(e).strip() if str(
  1030. e).strip() else f"未知错误({type(e).__name__})"
  1031. logger.error(
  1032. f"[send_deepseek_message] 考试工坊异常: {type(e).__name__}: {error_detail}")
  1033. response_text = f"处理失败: {error_detail}"
  1034. else:
  1035. return {"statusCode": 400, "msg": f"不支持的业务类型: {data.business_type}"}
  1036. return {
  1037. "statusCode": 200,
  1038. "msg": "success",
  1039. "data": {
  1040. "conversation_id": conv_id,
  1041. "ai_conversation_id": conv_id,
  1042. "response": response_text,
  1043. "reply": response_text,
  1044. "content": response_text,
  1045. "message": response_text,
  1046. "user_id": user.user_id,
  1047. "business_type": data.business_type,
  1048. },
  1049. }
  1050. except Exception as e:
  1051. logger.error(f"[send_deepseek_message] 异常: {e}")
  1052. return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}
  1053. @router.get("/get_history_record")
  1054. async def get_history_record(
  1055. request: Request,
  1056. ai_conversation_id: int = 0,
  1057. business_type: Optional[int] = None,
  1058. db: Session = Depends(get_db),
  1059. ):
  1060. """兼容前端的历史记录查询:ai_conversation_id=0 返回对话列表,否则返回消息详情。"""
  1061. user = request.state.user
  1062. if not user:
  1063. return {"statusCode": 401, "msg": "未授权"}
  1064. if ai_conversation_id > 0:
  1065. messages = (
  1066. db.query(AIMessage)
  1067. .filter(
  1068. AIMessage.ai_conversation_id == ai_conversation_id,
  1069. AIMessage.user_id == user.user_id,
  1070. AIMessage.is_deleted == 0,
  1071. )
  1072. .order_by(AIMessage.id.asc())
  1073. .all()
  1074. )
  1075. return {
  1076. "statusCode": 200,
  1077. "msg": "success",
  1078. "total": len(messages),
  1079. "data": [
  1080. {
  1081. "id": message.id,
  1082. "ai_conversation_id": message.ai_conversation_id,
  1083. "user_id": message.user_id,
  1084. "type": message.type,
  1085. "content": message.content,
  1086. "user_feedback": message.user_feedback,
  1087. "prev_user_id": message.prev_user_id,
  1088. "search_source": message.search_source or "",
  1089. "guess_you_want": message.guess_you_want or "",
  1090. "created_at": _to_frontend_timestamp(message.created_at),
  1091. "updated_at": _to_frontend_timestamp(message.updated_at),
  1092. }
  1093. for message in messages
  1094. ],
  1095. }
  1096. conversations_query = db.query(AIConversation).filter(
  1097. AIConversation.user_id == user.user_id,
  1098. AIConversation.is_deleted == 0,
  1099. )
  1100. if business_type is not None:
  1101. conversations_query = conversations_query.filter(
  1102. AIConversation.business_type == business_type
  1103. )
  1104. total = conversations_query.count()
  1105. conversations = (
  1106. conversations_query
  1107. .order_by(AIConversation.updated_at.desc(), AIConversation.id.desc())
  1108. .limit(50)
  1109. .all()
  1110. )
  1111. return {
  1112. "statusCode": 200,
  1113. "msg": "success",
  1114. "total": total,
  1115. "data": [
  1116. {
  1117. "id": conv.id,
  1118. "title": _build_conversation_title(conv),
  1119. "content": conv.content or "",
  1120. "business_type": conv.business_type,
  1121. "exam_name": conv.exam_name or "",
  1122. "created_at": _to_frontend_timestamp(conv.created_at),
  1123. "updated_at": _to_frontend_timestamp(conv.updated_at),
  1124. }
  1125. for conv in conversations
  1126. ],
  1127. }
class DeleteConversationRequest(BaseModel):
    # Request body for POST /delete_conversation.
    # Non-zero ai_message_id: only that AI message and the user message it
    # answers are soft-deleted, and ai_conversation_id is ignored.
    # Otherwise ai_conversation_id selects the whole conversation to delete.
    ai_conversation_id: int = 0
    ai_message_id: int = 0
  1131. @router.post("/delete_conversation")
  1132. async def delete_conversation(
  1133. request: Request, data: DeleteConversationRequest, db: Session = Depends(get_db)
  1134. ):
  1135. """
  1136. 删除对话(软删除)
  1137. 同时软删除对话记录和所有关联的消息
  1138. """
  1139. user = request.state.user
  1140. if not user:
  1141. return {"statusCode": 401, "msg": "未授权"}
  1142. now_ts = int(time.time())
  1143. if data.ai_message_id:
  1144. ai_message = (
  1145. db.query(AIMessage)
  1146. .filter(
  1147. AIMessage.id == data.ai_message_id,
  1148. AIMessage.user_id == user.user_id,
  1149. AIMessage.type == "ai",
  1150. AIMessage.is_deleted == 0,
  1151. )
  1152. .first()
  1153. )
  1154. if not ai_message:
  1155. return {"statusCode": 404, "msg": "消息不存在"}
  1156. db.query(AIMessage).filter(
  1157. AIMessage.id == ai_message.id,
  1158. AIMessage.user_id == user.user_id,
  1159. ).update({"is_deleted": 1, "updated_at": now_ts})
  1160. if ai_message.prev_user_id:
  1161. db.query(AIMessage).filter(
  1162. AIMessage.id == ai_message.prev_user_id,
  1163. AIMessage.user_id == user.user_id,
  1164. AIMessage.ai_conversation_id == ai_message.ai_conversation_id,
  1165. ).update({"is_deleted": 1, "updated_at": now_ts})
  1166. _refresh_conversation_snapshot(
  1167. db, ai_message.ai_conversation_id, user.user_id)
  1168. db.commit()
  1169. return {"statusCode": 200, "msg": "删除成功"}
  1170. if not data.ai_conversation_id:
  1171. return {"statusCode": 400, "msg": "缺少删除参数"}
  1172. db.query(AIConversation).filter(
  1173. AIConversation.id == data.ai_conversation_id,
  1174. AIConversation.user_id == user.user_id,
  1175. ).update({"is_deleted": 1, "updated_at": now_ts})
  1176. db.query(AIMessage).filter(
  1177. AIMessage.ai_conversation_id == data.ai_conversation_id,
  1178. AIMessage.user_id == user.user_id,
  1179. ).update({"is_deleted": 1, "updated_at": now_ts})
  1180. db.commit()
  1181. return {"statusCode": 200, "msg": "删除成功"}
class DeleteHistoryRequest(BaseModel):
    # Request body for POST /delete_history_record: id of the caller's
    # conversation to soft-delete.
    ai_conversation_id: int
  1184. @router.post("/delete_history_record")
  1185. async def delete_history_record(
  1186. request: Request, data: DeleteHistoryRequest, db: Session = Depends(get_db)
  1187. ):
  1188. """删除历史记录(软删除)"""
  1189. user = request.state.user
  1190. if not user:
  1191. return {"statusCode": 401, "msg": "未授权"}
  1192. db.query(AIConversation).filter(
  1193. AIConversation.id == data.ai_conversation_id,
  1194. AIConversation.user_id == user.user_id,
  1195. ).update({"is_deleted": 1, "updated_at": int(time.time())})
  1196. db.commit()
  1197. return {"statusCode": 200, "msg": "删除成功"}
  1198. # ─────────────────────────────────────────────────────────────────────────
  1199. # 流式接口 /stream/chat(无 DB,意图识别 + RAG)
  1200. # ─────────────────────────────────────────────────────────────────────────
class StreamChatRequest(BaseModel):
    # Request body for POST /stream/chat.
    message: str
    # NOTE(review): not read by stream_chat in the visible code.
    model: str = ""
  1204. @router.post("/stream/chat")
  1205. async def stream_chat(request: Request, data: StreamChatRequest):
  1206. """流式聊天(SSE,不写 DB)"""
  1207. message = data.message.strip()
  1208. if not message:
  1209. return JSONResponse(content={"statusCode": 400, "msg": "消息不能为空"})
  1210. async def event_generator():
  1211. intent_type = ""
  1212. try:
  1213. intent_result = await qwen_service.intent_recognition(message)
  1214. if isinstance(intent_result, dict):
  1215. intent_type = (
  1216. intent_result.get("intent_type") or intent_result.get(
  1217. "intent") or ""
  1218. ).lower()
  1219. except Exception as ie:
  1220. logger.warning(f"[stream/chat] 意图识别异常: {ie}")
  1221. rag_context = ""
  1222. if intent_type in ("query_knowledge_base", "知识库查询", "技术咨询"):
  1223. rag_context = await _rag_search(message)
  1224. # 使用prompt加载器加载最终回答prompt
  1225. system_content = load_prompt(
  1226. "final_answer",
  1227. userMessage=message,
  1228. contextJSON=rag_context if rag_context else "暂无相关知识库内容"
  1229. )
  1230. messages = [
  1231. {"role": "user", "content": system_content},
  1232. ]
  1233. try:
  1234. buffer = ""
  1235. pre_answer = ""
  1236. thinking_buf = ""
  1237. in_think = False
  1238. thinking_done = False
  1239. max_input_chars = getattr(
  1240. settings.thinking_summary, "max_input_chars", 1500)
  1241. async for chunk in qwen_service.stream_chat(messages):
  1242. buffer += chunk
  1243. while buffer:
  1244. lower = buffer.lower()
  1245. if not thinking_done:
  1246. if not in_think:
  1247. start_idx = lower.find("<think>")
  1248. if start_idx == -1:
  1249. yield f"data: {json.dumps({'content': buffer}, ensure_ascii=False)}\n\n"
  1250. buffer = ""
  1251. break
  1252. pre_answer += buffer[:start_idx]
  1253. buffer = buffer[start_idx + len("<think>"):]
  1254. in_think = True
  1255. continue
  1256. end_idx = lower.find("</think>")
  1257. if end_idx == -1:
  1258. if max_input_chars and len(thinking_buf) < max_input_chars:
  1259. thinking_buf += buffer[: max_input_chars -
  1260. len(thinking_buf)]
  1261. buffer = ""
  1262. break
  1263. if max_input_chars and len(thinking_buf) < max_input_chars:
  1264. thinking_part = buffer[:end_idx]
  1265. thinking_buf += thinking_part[: max_input_chars - len(
  1266. thinking_buf)]
  1267. buffer = buffer[end_idx + len("</think>"):]
  1268. in_think = False
  1269. thinking_done = True
  1270. thinking_summary = await summarize_thinking_content(
  1271. user_question=message,
  1272. raw_thinking=thinking_buf,
  1273. final_answer="",
  1274. chat_service=qwen_service,
  1275. context="stream_chat",
  1276. )
  1277. if thinking_summary:
  1278. prefix = f"思考过程:\n{thinking_summary}\n\n回答:\n"
  1279. yield f"data: {json.dumps({'content': prefix}, ensure_ascii=False)}\n\n"
  1280. answer_chunk = (pre_answer + buffer).lstrip()
  1281. if answer_chunk:
  1282. yield f"data: {json.dumps({'content': answer_chunk}, ensure_ascii=False)}\n\n"
  1283. pre_answer = ""
  1284. buffer = ""
  1285. break
  1286. yield f"data: {json.dumps({'content': buffer}, ensure_ascii=False)}\n\n"
  1287. buffer = ""
  1288. # 流结束但未遇到 </think>:仅尝试生成要点(不回退输出 raw thinking)
  1289. if in_think and not thinking_done and thinking_buf:
  1290. thinking_summary = await summarize_thinking_content(
  1291. user_question=message,
  1292. raw_thinking=thinking_buf,
  1293. final_answer="",
  1294. chat_service=qwen_service,
  1295. context="stream_chat_eof",
  1296. )
  1297. if thinking_summary:
  1298. prefix = f"思考过程:\n{thinking_summary}\n\n回答:\n"
  1299. yield f"data: {json.dumps({'content': prefix}, ensure_ascii=False)}\n\n"
  1300. if pre_answer:
  1301. yield f"data: {json.dumps({'content': pre_answer}, ensure_ascii=False)}\n\n"
  1302. except Exception as e:
  1303. logger.error(f"[stream/chat] 流式输出异常: {e}")
  1304. yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
  1305. finally:
  1306. yield "data: [DONE]\n\n"
  1307. return StreamingResponse(event_generator(), media_type="text/event-stream")
  1308. # ─────────────────────────────────────────────────────────────────────────
  1309. # 流式接口 /stream/chat-with-db(前端主聊天接口)
  1310. # ─────────────────────────────────────────────────────────────────────────
class StreamChatWithDBRequest(BaseModel):
    # Request body for POST /stream/chat-with-db.
    message: str
    # 0 starts a new conversation; an id the caller does not own (or that is
    # soft-deleted) also results in a new conversation.
    ai_conversation_id: int = 0
    # 1 = PPT outline and 2 = AI writing use dedicated prompts; any other
    # value uses the final_answer prompt with recent history.
    business_type: int = 0
    # Exam title; the conversation-update paths keep it only when
    # business_type == 3.
    exam_name: str = ""
    # NOTE(review): not read by stream_chat_with_db in the visible code.
    ai_message_id: int = 0
    # Optional web-search text added to the context of the final_answer
    # prompt (not used for business types 1/2).
    online_search_content: str = ""
  1318. @router.post("/stream/chat-with-db")
  1319. async def stream_chat_with_db(request: Request, data: StreamChatWithDBRequest):
  1320. """
  1321. 带 DB 操作的流式聊天(SSE)
  1322. 流程:
  1323. 1. 创建/获取对话
  1324. 2. 插入用户消息和 AI 占位消息
  1325. 3. 发送 initial 事件
  1326. 4. RAG 检索
  1327. 5. 构建历史上下文
  1328. 6. 流式输出
  1329. 7. 更新 AI 消息内容
  1330. """
  1331. user = request.state.user
  1332. if not user:
  1333. return JSONResponse(content={"statusCode": 401, "msg": "未授权"})
  1334. message = data.message.strip()
  1335. if not message:
  1336. return JSONResponse(content={"statusCode": 400, "msg": "消息不能为空"})
  1337. async def event_generator():
  1338. db = SessionLocal()
  1339. try:
  1340. # 1. 创建或获取对话
  1341. if data.ai_conversation_id == 0:
  1342. conversation = AIConversation(
  1343. user_id=user.user_id,
  1344. content=_build_conversation_preview(message, limit=100),
  1345. business_type=data.business_type,
  1346. exam_name=data.exam_name,
  1347. created_at=int(time.time()),
  1348. updated_at=int(time.time()),
  1349. is_deleted=0,
  1350. )
  1351. db.add(conversation)
  1352. db.commit()
  1353. db.refresh(conversation)
  1354. conv_id = conversation.id
  1355. else:
  1356. existing_conversation = (
  1357. db.query(AIConversation)
  1358. .filter(
  1359. AIConversation.id == data.ai_conversation_id,
  1360. AIConversation.user_id == user.user_id,
  1361. AIConversation.is_deleted == 0,
  1362. )
  1363. .first()
  1364. )
  1365. if existing_conversation:
  1366. conv_id = existing_conversation.id
  1367. db.query(AIConversation).filter(
  1368. AIConversation.id == conv_id,
  1369. AIConversation.user_id == user.user_id,
  1370. ).update(
  1371. {
  1372. "content": _build_conversation_preview(message, limit=100),
  1373. "business_type": data.business_type,
  1374. "exam_name": data.exam_name if data.business_type == 3 else "",
  1375. "updated_at": int(time.time()),
  1376. }
  1377. )
  1378. db.commit()
  1379. else:
  1380. conversation = AIConversation(
  1381. user_id=user.user_id,
  1382. content=_build_conversation_preview(
  1383. message, limit=100),
  1384. business_type=data.business_type,
  1385. exam_name=data.exam_name if data.business_type == 3 else "",
  1386. created_at=int(time.time()),
  1387. updated_at=int(time.time()),
  1388. is_deleted=0,
  1389. )
  1390. db.add(conversation)
  1391. db.commit()
  1392. db.refresh(conversation)
  1393. conv_id = conversation.id
  1394. # 2. 插入用户消息
  1395. user_msg = AIMessage(
  1396. ai_conversation_id=conv_id,
  1397. user_id=user.user_id,
  1398. type="user",
  1399. content=message,
  1400. created_at=int(time.time()),
  1401. updated_at=int(time.time()),
  1402. is_deleted=0,
  1403. )
  1404. db.add(user_msg)
  1405. db.commit()
  1406. db.refresh(user_msg)
  1407. # 3. 插入 AI 占位消息
  1408. ai_msg = AIMessage(
  1409. ai_conversation_id=conv_id,
  1410. user_id=user.user_id,
  1411. type="ai",
  1412. content="",
  1413. prev_user_id=user_msg.id,
  1414. created_at=int(time.time()),
  1415. updated_at=int(time.time()),
  1416. is_deleted=0,
  1417. )
  1418. db.add(ai_msg)
  1419. db.commit()
  1420. db.refresh(ai_msg)
  1421. # 4. 发送 initial 事件
  1422. yield f"data: {json.dumps({'type': 'initial', 'ai_conversation_id': conv_id, 'ai_message_id': ai_msg.id}, ensure_ascii=False)}\n\n"
  1423. # 5. RAG search
  1424. rag_context = await _rag_search(message, top_k=10)
  1425. if data.business_type in (1, 2):
  1426. # PPT outline / AI writing: use dedicated prompt
  1427. prompt_name = "ppt_outline" if data.business_type == 1 else "document_writing"
  1428. system_content = load_prompt(
  1429. prompt_name,
  1430. userMessage=message,
  1431. contextJSON=rag_context if rag_context else "?????????"
  1432. )
  1433. messages = [
  1434. {"role": "user", "content": system_content},
  1435. ]
  1436. else:
  1437. # 6. History context (last 4 items, 2 turns)
  1438. history_msgs = (
  1439. db.query(AIMessage)
  1440. .filter(
  1441. AIMessage.ai_conversation_id == conv_id,
  1442. AIMessage.id < ai_msg.id,
  1443. AIMessage.is_deleted == 0,
  1444. )
  1445. .order_by(AIMessage.updated_at.desc())
  1446. .limit(4)
  1447. .all()
  1448. )
  1449. history_msgs.reverse()
  1450. history_context = ""
  1451. for msg in history_msgs:
  1452. role = "??" if msg.type == "user" else "??"
  1453. history_context += f"{role}: {msg.content}\n\n"
  1454. # 7. Build final prompt
  1455. context_parts = []
  1456. if rag_context:
  1457. context_parts.append(f"??????\n{rag_context}")
  1458. if data.online_search_content:
  1459. context_parts.append(
  1460. f"???????\n{data.online_search_content}")
  1461. context_json = "\n\n".join(
  1462. context_parts) if context_parts else "?????????"
  1463. system_content = load_prompt(
  1464. "final_answer",
  1465. userMessage=message,
  1466. contextJSON=context_json,
  1467. historyContext=history_context if history_context else ""
  1468. )
  1469. messages = [
  1470. {"role": "user", "content": system_content},
  1471. ]
  1472. # 8. 流式输出并收集完整回复
  1473. full_response = ""
  1474. try:
  1475. summary_enabled = getattr(
  1476. settings.thinking_summary, "enabled", True)
  1477. max_input_chars = getattr(
  1478. settings.thinking_summary, "max_input_chars", 1500)
  1479. buffer = ""
  1480. pre_answer = ""
  1481. thinking_buf = ""
  1482. in_think = False
  1483. thinking_done = False
  1484. async for chunk in qwen_service.stream_chat(messages):
  1485. if not summary_enabled:
  1486. escaped_chunk = chunk.replace("\n", "\\n")
  1487. full_response += chunk
  1488. yield f"data: {escaped_chunk}\n\n"
  1489. continue
  1490. buffer += chunk
  1491. while buffer:
  1492. lower = buffer.lower()
  1493. if not thinking_done:
  1494. if not in_think:
  1495. start_idx = lower.find("<think>")
  1496. if start_idx == -1:
  1497. escaped_text = buffer.replace("\n", "\\n")
  1498. full_response += buffer
  1499. yield f"data: {escaped_text}\n\n"
  1500. buffer = ""
  1501. break
  1502. pre_answer += buffer[:start_idx]
  1503. buffer = buffer[start_idx + len("<think>"):]
  1504. in_think = True
  1505. continue
  1506. end_idx = lower.find("</think>")
  1507. if end_idx == -1:
  1508. if max_input_chars and len(thinking_buf) < max_input_chars:
  1509. thinking_buf += buffer[: max_input_chars -
  1510. len(thinking_buf)]
  1511. buffer = ""
  1512. break
  1513. if max_input_chars and len(thinking_buf) < max_input_chars:
  1514. thinking_part = buffer[:end_idx]
  1515. thinking_buf += thinking_part[: max_input_chars - len(
  1516. thinking_buf)]
  1517. buffer = buffer[end_idx + len("</think>"):]
  1518. in_think = False
  1519. thinking_done = True
  1520. thinking_summary = await summarize_thinking_content(
  1521. user_question=message,
  1522. raw_thinking=thinking_buf,
  1523. final_answer="",
  1524. chat_service=qwen_service,
  1525. context="stream_chat_with_db",
  1526. )
  1527. if thinking_summary:
  1528. prefix = f"思考过程:\n{thinking_summary}\n\n回答:\n"
  1529. full_response += prefix
  1530. escaped_prefix = prefix.replace('\n', '\\n')
  1531. yield f"data: {escaped_prefix}\n\n"
  1532. answer_chunk = (pre_answer + buffer).lstrip()
  1533. if answer_chunk:
  1534. full_response += answer_chunk
  1535. escaped_answer = answer_chunk.replace(
  1536. '\n', '\\n')
  1537. yield f"data: {escaped_answer}\n\n"
  1538. pre_answer = ""
  1539. buffer = ""
  1540. break
  1541. escaped_text = buffer.replace("\n", "\\n")
  1542. full_response += buffer
  1543. yield f"data: {escaped_text}\n\n"
  1544. buffer = ""
  1545. except Exception as e:
  1546. logger.error(f"[stream/chat-with-db] 流式输出异常: {e}")
  1547. yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
  1548. # 流结束但未遇到 </think>:仅尝试生成要点(不回退输出 raw thinking)
  1549. if summary_enabled and in_think and not thinking_done and thinking_buf:
  1550. thinking_summary = await summarize_thinking_content(
  1551. user_question=message,
  1552. raw_thinking=thinking_buf,
  1553. final_answer="",
  1554. chat_service=qwen_service,
  1555. context="stream_chat_with_db_eof",
  1556. )
  1557. if thinking_summary:
  1558. prefix = f"思考过程:\n{thinking_summary}\n\n回答:\n"
  1559. full_response += prefix
  1560. escaped_prefix = prefix.replace('\n', '\\n')
  1561. yield f"data: {escaped_prefix}\n\n"
  1562. if pre_answer:
  1563. full_response += pre_answer
  1564. escaped_pre_answer = pre_answer.replace('\n', '\\n')
  1565. yield f"data: {escaped_pre_answer}\n\n"
  1566. # 9. 更新 AI 消息内容
  1567. if full_response:
  1568. now_ts = int(time.time())
  1569. db.query(AIMessage).filter(AIMessage.id == ai_msg.id).update(
  1570. {"content": full_response, "updated_at": now_ts}
  1571. )
  1572. db.query(AIConversation).filter(
  1573. AIConversation.id == conv_id,
  1574. AIConversation.user_id == user.user_id,
  1575. ).update(
  1576. {
  1577. "content": _build_conversation_preview(message, limit=100),
  1578. "business_type": data.business_type,
  1579. "exam_name": data.exam_name if data.business_type == 3 else "",
  1580. "updated_at": now_ts,
  1581. }
  1582. )
  1583. db.commit()
  1584. # 10. 结束标记
  1585. yield "data: [DONE]\n\n"
  1586. except Exception as e:
  1587. logger.error(f"[stream/chat-with-db] 处理异常: {e}")
  1588. yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
  1589. finally:
  1590. db.close()
  1591. return StreamingResponse(event_generator(), media_type="text/event-stream")
  1592. # ─────────────────────────────────────────────────────────────────────────
  1593. # 猜你想问
  1594. # ─────────────────────────────────────────────────────────────────────────
  1595. class GuessYouWantRequest(BaseModel):
  1596. ai_message_id: int
  1597. @router.post("/guess_you_want")
  1598. async def guess_you_want(
  1599. request: Request,
  1600. data: GuessYouWantRequest,
  1601. db: Session = Depends(get_db),
  1602. ):
  1603. """生成"猜你想问"的3个关联问题,保存到 AIMessage.guess_you_want"""
  1604. user = request.state.user
  1605. if not user:
  1606. return {"statusCode": 401, "msg": "未授权"}
  1607. try:
  1608. ai_msg = (
  1609. db.query(AIMessage)
  1610. .filter(AIMessage.id == data.ai_message_id, AIMessage.is_deleted == 0)
  1611. .first()
  1612. )
  1613. if not ai_msg:
  1614. return {"statusCode": 404, "msg": "消息不存在"}
  1615. # 使用prompt加载器加载猜你想问prompt
  1616. system_content = load_prompt(
  1617. "guess_questions",
  1618. currentContent=ai_msg.content[:500]
  1619. )
  1620. messages = [
  1621. {"role": "user", "content": system_content},
  1622. ]
  1623. response = await qwen_service.chat(messages)
  1624. try:
  1625. # 尝试从响应中提取 JSON
  1626. json_match = re.search(
  1627. r'\{[^{}]*"questions"[^{}]*\}', response, re.DOTALL)
  1628. if json_match:
  1629. response_json = json.loads(json_match.group())
  1630. else:
  1631. response_json = json.loads(response)
  1632. questions = response_json.get("questions", [])
  1633. except Exception:
  1634. lines = [l.strip() for l in response.split("\n") if l.strip()]
  1635. questions = []
  1636. for line in lines:
  1637. clean = line.lstrip("0123456789.-、 ").strip()
  1638. if clean and len(clean) > 5:
  1639. questions.append(clean)
  1640. if not questions:
  1641. questions = ["该话题的具体应用场景?", "有哪些注意事项?", "相关案例分析?"]
  1642. questions = _finalize_related_questions(
  1643. questions, ai_msg.content, limit=3)
  1644. guess_json = json.dumps({"questions": questions}, ensure_ascii=False)
  1645. db.query(AIMessage).filter(AIMessage.id == data.ai_message_id).update(
  1646. {"guess_you_want": guess_json, "updated_at": int(time.time())}
  1647. )
  1648. db.commit()
  1649. return {
  1650. "statusCode": 200,
  1651. "msg": "success",
  1652. "data": {"ai_message_id": data.ai_message_id, "questions": questions},
  1653. }
  1654. except Exception as e:
  1655. logger.error(f"[guess_you_want] 处理异常: {e}")
  1656. return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}
  1657. # ─────────────────────────────────────────────────────────────────────────
  1658. # 在线搜索(Dify 工作流集成)
  1659. # ─────────────────────────────────────────────────────────────────────────
  1660. @router.get("/online_search")
  1661. async def online_search(question: str, request: Request, db: Session = Depends(get_db)):
  1662. """
  1663. 在线搜索
  1664. 流程:Qwen 提炼关键词 → Dify 工作流 → 返回摘要
  1665. """
  1666. user = request.state.user
  1667. if not user:
  1668. return {"statusCode": 401, "msg": "未授权"}
  1669. try:
  1670. keywords = await qwen_service.extract_keywords(question)
  1671. dify_config = getattr(settings, "dify", None)
  1672. if not dify_config or not getattr(dify_config, "workflow_url", None):
  1673. return {"statusCode": 500, "msg": "Dify 配置未设置"}
  1674. headers = {
  1675. "Authorization": f"Bearer {dify_config.auth_token}",
  1676. "Content-Type": "application/json",
  1677. }
  1678. payload = {
  1679. "workflow_id": dify_config.workflow_id,
  1680. "inputs": {
  1681. "keywords": keywords,
  1682. "num": 5, # 搜索结果数量
  1683. "max_text_len": 4000 # 最大文本长度
  1684. },
  1685. "response_mode": "blocking",
  1686. "user": getattr(user, "account", str(user.user_id)),
  1687. }
  1688. async with httpx.AsyncClient(timeout=30.0) as client:
  1689. resp = await client.post(dify_config.workflow_url, headers=headers, json=payload)
  1690. if resp.status_code != 200:
  1691. logger.error(
  1692. f"[online_search] Dify 调用失败: {resp.status_code}, 响应: {resp.text}")
  1693. return {"statusCode": 500, "msg": f"搜索服务异常: {resp.status_code}"}
  1694. result = resp.json()
  1695. search_text = result.get("data", {}).get(
  1696. "outputs", {}).get("text", "")
  1697. return {
  1698. "statusCode": 200,
  1699. "msg": "success",
  1700. "data": {"keywords": keywords, "result": search_text},
  1701. }
  1702. except Exception as e:
  1703. logger.error(f"[online_search] 处理异常: {e}")
  1704. return {"statusCode": 500, "msg": f"搜索失败: {str(e)}"}
  1705. class SaveOnlineSearchResultRequest(BaseModel):
  1706. ai_message_id: int
  1707. search_result: str
  1708. @router.post("/save_online_search_result")
  1709. async def save_online_search_result(
  1710. request: Request,
  1711. data: SaveOnlineSearchResultRequest,
  1712. db: Session = Depends(get_db),
  1713. ):
  1714. """保存联网搜索结果到 AIMessage.search_source"""
  1715. user = request.state.user
  1716. if not user:
  1717. return {"statusCode": 401, "msg": "未授权"}
  1718. try:
  1719. db.query(AIMessage).filter(AIMessage.id == data.ai_message_id).update(
  1720. {"search_source": data.search_result,
  1721. "updated_at": int(time.time())}
  1722. )
  1723. db.commit()
  1724. return {"statusCode": 200, "msg": "保存成功"}
  1725. except Exception as e:
  1726. logger.error(f"[save_online_search_result] 处理异常: {e}")
  1727. return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}
  1728. # ─────────────────────────────────────────────────────────────────────────
  1729. # 意图识别独立接口
  1730. # ─────────────────────────────────────────────────────────────────────────
  1731. class IntentRecognitionRequest(BaseModel):
  1732. message: str
  1733. save_to_db: bool = False
  1734. ai_conversation_id: int = 0
  1735. @router.post("/intent_recognition")
  1736. async def intent_recognition(
  1737. request: Request,
  1738. data: IntentRecognitionRequest,
  1739. db: Session = Depends(get_db),
  1740. ):
  1741. """独立意图识别接口;若为 greeting/faq 且 save_to_db=True 则直接存 DB"""
  1742. user = request.state.user
  1743. if not user:
  1744. return {"statusCode": 401, "msg": "未授权"}
  1745. try:
  1746. intent_result = await qwen_service.intent_recognition(data.message)
  1747. intent_type = ""
  1748. response_text = ""
  1749. if isinstance(intent_result, dict):
  1750. intent_type = (
  1751. intent_result.get("intent_type") or intent_result.get(
  1752. "intent") or ""
  1753. ).lower()
  1754. response_text = intent_result.get("response", "")
  1755. if data.save_to_db and intent_type in ("greeting", "问候", "faq", "常见问题"):
  1756. if data.ai_conversation_id == 0:
  1757. conversation = AIConversation(
  1758. user_id=user.user_id,
  1759. content=data.message[:100],
  1760. business_type=0,
  1761. created_at=int(time.time()),
  1762. updated_at=int(time.time()),
  1763. is_deleted=0,
  1764. )
  1765. db.add(conversation)
  1766. db.commit()
  1767. db.refresh(conversation)
  1768. conv_id = conversation.id
  1769. else:
  1770. conv_id = data.ai_conversation_id
  1771. user_msg = AIMessage(
  1772. ai_conversation_id=conv_id,
  1773. user_id=user.user_id,
  1774. type="user",
  1775. content=data.message,
  1776. created_at=int(time.time()),
  1777. updated_at=int(time.time()),
  1778. is_deleted=0,
  1779. )
  1780. db.add(user_msg)
  1781. db.commit()
  1782. ai_msg = AIMessage(
  1783. ai_conversation_id=conv_id,
  1784. user_id=user.user_id,
  1785. type="ai",
  1786. content=response_text,
  1787. prev_user_id=user_msg.id,
  1788. created_at=int(time.time()),
  1789. updated_at=int(time.time()),
  1790. is_deleted=0,
  1791. )
  1792. db.add(ai_msg)
  1793. db.commit()
  1794. db.refresh(ai_msg)
  1795. return {
  1796. "statusCode": 200,
  1797. "msg": "success",
  1798. "data": {
  1799. "intent_type": intent_type,
  1800. "response": response_text,
  1801. "ai_conversation_id": conv_id,
  1802. "ai_message_id": ai_msg.id,
  1803. "saved_to_db": True,
  1804. },
  1805. }
  1806. return {
  1807. "statusCode": 200,
  1808. "msg": "success",
  1809. "data": {
  1810. "intent_type": intent_type,
  1811. "response": response_text,
  1812. "saved_to_db": False,
  1813. },
  1814. }
  1815. except Exception as e:
  1816. logger.error(f"[intent_recognition] 处理异常: {e}")
  1817. return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}
  1818. # ─────────────────────────────────────────────────────────────────────────
  1819. # 获取用户推荐问题(模糊查询 QA / RecommendQuestion 表)
  1820. # ─────────────────────────────────────────────────────────────────────────
  1821. @router.get("/get_user_recommend_question")
  1822. async def get_user_recommend_question(
  1823. keyword: str = "",
  1824. limit: int = 10,
  1825. db: Session = Depends(get_db),
  1826. ):
  1827. """获取推荐问题(支持模糊查询)"""
  1828. try:
  1829. query = db.query(RecommendQuestion).filter(
  1830. RecommendQuestion.is_deleted == 0)
  1831. if keyword:
  1832. query = query.filter(
  1833. RecommendQuestion.question.like(f"%{keyword}%"))
  1834. questions = query.order_by(
  1835. RecommendQuestion.id.desc()).limit(limit).all()
  1836. return {
  1837. "statusCode": 200,
  1838. "msg": "success",
  1839. "data": [
  1840. {"id": q.id, "question": q.question, "created_at": q.created_at}
  1841. for q in questions
  1842. ],
  1843. }
  1844. except Exception as e:
  1845. logger.error(f"[get_user_recommend_question] 处理异常: {e}")
  1846. return {"statusCode": 500, "msg": f"查询失败: {str(e)}"}
  1847. # ─────────────────────────────────────────────────────────────────────────
  1848. # PPT 大纲 / 文档编辑保存
  1849. # ─────────────────────────────────────────────────────────────────────────
  1850. class SavePPTOutlineRequest(BaseModel):
  1851. ai_message_id: int
  1852. content: str
  1853. @router.post("/save_ppt_outline")
  1854. async def save_ppt_outline(
  1855. request: Request,
  1856. data: SavePPTOutlineRequest,
  1857. db: Session = Depends(get_db),
  1858. ):
  1859. """更新 AIMessage.content 保存 PPT 大纲内容"""
  1860. user = request.state.user
  1861. if not user:
  1862. return {"statusCode": 401, "msg": "未授权"}
  1863. try:
  1864. db.query(AIMessage).filter(AIMessage.id == data.ai_message_id).update(
  1865. {"content": data.content, "updated_at": int(time.time())}
  1866. )
  1867. db.commit()
  1868. return {"statusCode": 200, "msg": "保存成功"}
  1869. except Exception as e:
  1870. logger.error(f"[save_ppt_outline] 处理异常: {e}")
  1871. return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}
  1872. class SaveEditDocumentRequest(BaseModel):
  1873. ai_message_id: int
  1874. content: str
  1875. @router.post("/save_edit_document")
  1876. async def save_edit_document(
  1877. request: Request,
  1878. data: SaveEditDocumentRequest,
  1879. db: Session = Depends(get_db),
  1880. ):
  1881. """更新 ai 类型 AIMessage.content(AI写作编辑保存)"""
  1882. user = request.state.user
  1883. if not user:
  1884. return {"statusCode": 401, "msg": "未授权"}
  1885. try:
  1886. db.query(AIMessage).filter(
  1887. AIMessage.id == data.ai_message_id,
  1888. AIMessage.type == "ai",
  1889. ).update({"content": data.content, "updated_at": int(time.time())})
  1890. db.commit()
  1891. return {"statusCode": 200, "msg": "保存成功"}
  1892. except Exception as e:
  1893. logger.error(f"[save_edit_document] 处理异常: {e}")
  1894. return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}