| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208 |
- from fastapi import APIRouter, Depends, Request
- from fastapi.responses import StreamingResponse, JSONResponse
- from sqlalchemy.orm import Session
- from pydantic import BaseModel
- from typing import Optional
- from database import get_db, SessionLocal
- from models.chat import AIConversation, AIMessage
- from models.total import RecommendQuestion
- from utils.config import settings
- from utils.logger import logger
- from services.qwen_service import qwen_service
- from services.deepseek_service import deepseek_service
- from utils.prompt_loader import load_prompt
- from utils.thinking_summary import split_thinking_and_answer, summarize_thinking_content
- import time
- import json
- import httpx
- import re
# Router that collects this module's endpoints (e.g. POST /send_deepseek_message);
# presumably included into the FastAPI application elsewhere — verify.
router = APIRouter()
- def _build_conversation_preview(content: str, limit: int = 50) -> str:
- content = (content or "").strip()
- if len(content) <= limit:
- return content
- return content[:limit] + "..."
- def _to_frontend_timestamp(timestamp: Optional[int]) -> Optional[int]:
- if not timestamp:
- return None
- return timestamp if timestamp >= 10**12 else timestamp * 1000
def _build_conversation_title(conversation: AIConversation) -> str:
    """Pick a display title for a conversation.

    Exam-workshop conversations (``business_type == 3``) use their non-blank,
    stripped ``exam_name``; everything else falls back to a 30-character
    preview of ``conversation.content``.
    """
    if conversation.business_type == 3:
        exam_title = (conversation.exam_name or "").strip()
        if exam_title:
            return exam_title
    return _build_conversation_preview(conversation.content or "", limit=30)
- def _extract_json_object_from_index(source: str, start_idx: int) -> str:
- if start_idx < 0 or start_idx >= len(source) or source[start_idx] != "{":
- return ""
- depth = 0
- in_string = False
- escaped = False
- for idx in range(start_idx, len(source)):
- ch = source[idx]
- if escaped:
- escaped = False
- continue
- if in_string:
- if ch == "\\":
- escaped = True
- elif ch == '"':
- in_string = False
- continue
- if ch == '"':
- in_string = True
- continue
- if ch == "{":
- depth += 1
- elif ch == "}":
- depth -= 1
- if depth == 0:
- return source[start_idx: idx + 1]
- return ""
def _extract_balanced_json_objects(text: str) -> list[str]:
    """Collect every distinct brace-balanced ``{...}`` span found in ``text``.

    A scan starts at each ``{`` in order, so nested objects are returned as
    well as their parents. Unclosed spans and duplicates are dropped. Each scan
    runs to the end of the text in the worst case, so this is quadratic for
    brace-heavy input.
    """
    body = (text or "").strip()
    if not body:
        return []
    spans = (
        _extract_json_object_from_index(body, position)
        for position, char in enumerate(body)
        if char == "{"
    )
    # dict.fromkeys de-duplicates while keeping first-seen order.
    return [span for span in dict.fromkeys(spans) if span]
- def _extract_trailing_json_candidates(text: str) -> list[str]:
- source = (text or "").strip()
- if not source:
- return []
- candidates = []
- seen = set()
- line_start_indexes = [
- match.start()
- for match in re.finditer(r"(?m)^[ \t]*\{", source)
- ]
- for start_idx in reversed(line_start_indexes):
- candidate = source[start_idx:].strip()
- if candidate and candidate not in seen:
- candidates.append(candidate)
- seen.add(candidate)
- return candidates
- def _extract_explicit_answer_segment(text: str) -> str:
- source = (text or "").strip()
- if not source:
- return ""
- markers = (
- "final answer:",
- "final output:",
- "answer:",
- "output:",
- "json:",
- )
- lowered = source.lower()
- for marker in markers:
- idx = lowered.rfind(marker)
- if idx >= 0:
- candidate = source[idx + len(marker):].strip()
- if candidate:
- return candidate
- return ""
- def _extract_brace_sliced_candidates(text: str) -> list[str]:
- source = (text or "").strip()
- if not source:
- return []
- candidates = []
- seen = set()
- first_brace = source.find("{")
- last_brace = source.rfind("}")
- if first_brace >= 0 and last_brace > first_brace:
- candidate = source[first_brace:last_brace + 1].strip()
- if candidate and candidate not in seen:
- candidates.append(candidate)
- seen.add(candidate)
- return candidates
- def _looks_like_exam_payload(payload: object) -> bool:
- if not isinstance(payload, dict):
- return False
- questions = payload.get("questions")
- return any(
- key in payload
- for key in (
- "singleChoice",
- "single_choice",
- "单选题",
- "judge",
- "判断题",
- "multiple",
- "multiple_choice",
- "multipleChoice",
- "多选题",
- "short",
- "short_answer",
- "shortAnswer",
- "简答题",
- )
- ) or (
- isinstance(questions, dict)
- and any(
- key in questions
- for key in (
- "singleChoice",
- "single_choice",
- "单选题",
- "judge",
- "判断题",
- "multiple",
- "multiple_choice",
- "multipleChoice",
- "多选题",
- "short",
- "short_answer",
- "shortAnswer",
- "简答题",
- )
- )
- )
- def _score_exam_payload_candidate(payload: object) -> int:
- if not isinstance(payload, dict):
- return 0
- score = 0
- questions = payload.get("questions") if isinstance(
- payload.get("questions"), dict) else {}
- strong_keys = (
- "singleChoice",
- "single_choice",
- "单选题",
- "judge",
- "判断题",
- "multiple",
- "multiple_choice",
- "multipleChoice",
- "多选题",
- "short",
- "short_answer",
- "shortAnswer",
- "简答题",
- )
- weak_keys = (
- "title",
- "exam_name",
- "examTitle",
- "试卷标题",
- "总分",
- "totalScore",
- "totalQuestions",
- )
- score += sum(10 for key in strong_keys if key in payload)
- score += sum(8 for key in strong_keys if key in questions)
- score += sum(2 for key in weak_keys if key in payload)
- section_candidates = []
- for _, value in payload.items():
- if isinstance(value, dict):
- section_candidates.append(value)
- section_candidates.extend(
- value for value in questions.values() if isinstance(value, dict))
- for section in section_candidates:
- if "questions" in section and isinstance(section.get("questions"), list):
- score += 6
- question_list = section.get("questions") or []
- if question_list and isinstance(question_list[0], dict):
- first_question = question_list[0]
- if any(k in first_question for k in ("text", "question_text", "question", "title", "content", "题干", "题目")):
- score += 4
- if "options" in first_question:
- score += 3
- if any(k in first_question for k in ("answer", "answers", "correct_answer", "correct_answers", "答案", "正确答案")):
- score += 3
- if any(k in first_question for k in ("analysis", "explanation", "解析")):
- score += 2
- if any(k in section for k in ("count", "question_count", "数量")):
- score += 2
- if any(k in section for k in ("scorePerQuestion", "score_per_question", "每题分值")):
- score += 1
- return score
- def _escape_inner_quotes_in_json(text: str) -> str:
- chars = []
- in_string = False
- escaped = False
- for idx, ch in enumerate(text):
- if not in_string:
- chars.append(ch)
- if ch == '"':
- in_string = True
- escaped = False
- continue
- if escaped:
- chars.append(ch)
- escaped = False
- continue
- if ch == "\\":
- chars.append(ch)
- escaped = True
- continue
- if ch == '"':
- next_non_space = ""
- for next_idx in range(idx + 1, len(text)):
- if not text[next_idx].isspace():
- next_non_space = text[next_idx]
- break
- if next_non_space in {",", "}", "]", ":"}:
- chars.append(ch)
- in_string = False
- else:
- chars.append('\\"')
- continue
- chars.append(ch)
- return "".join(chars)
def _try_parse_exam_json(candidate: str) -> Optional[dict]:
    """Best-effort parse of one candidate string into an exam-paper payload.

    BOMs and Markdown code fences are removed first. Parsing is then attempted on:

    1. the text as-is, so curly quotes (U+201C/U+201D) used *inside* string
       values, which is common in Chinese text, are preserved and do not break
       otherwise-valid JSON;
    2. the text with curly quotes normalized to ASCII ``"``, for replies that
       use them as JSON delimiters;
    3. the normalized text after escaping stray inner quotes and dropping
       trailing commas before ``}`` / ``]``.

    Returns:
        The first successfully parsed value if it passes
        ``_looks_like_exam_payload``; ``None`` if it does not, or if no attempt
        parses.
    """
    text = (candidate or "").strip()
    if not text:
        return None
    text = (
        text.replace("\ufeff", "")
        .replace("```json", "")
        .replace("```JSON", "")
        .replace("```", "")
    ).strip()
    # Normalizing curly quotes up front would turn e.g. "a “b” c" into broken
    # JSON, so it is only tried after a verbatim parse has failed.
    normalized = text.replace("“", '"').replace("”", '"')

    attempts = [text] if normalized == text else [text, normalized]
    for attempt in attempts:
        try:
            parsed = json.loads(attempt)
        except (ValueError, RecursionError):
            continue
        return parsed if _looks_like_exam_payload(parsed) else None

    repaired_text = _escape_inner_quotes_in_json(normalized)
    repaired_text = re.sub(r",\s*([}\]])", r"\1", repaired_text)
    try:
        parsed = json.loads(repaired_text)
    except (ValueError, RecursionError):
        return None
    return parsed if _looks_like_exam_payload(parsed) else None
def _sanitize_exam_response(raw_response: str) -> str:
    """Reduce a raw exam-workshop model reply to JSON the frontend can ``JSON.parse``.

    The exam workshop only passes parseable exam-paper JSON on to the frontend
    and the database. Strategy:

    1. Try the answer part (thinking split off), an explicit "answer:"-style
       segment, and the whole reply; return the first that parses as an exam
       payload.
    2. Otherwise parse every balanced ``{...}`` span, every line-anchored
       trailing suffix and the first-``{``-to-last-``}`` slice. Return the
       candidate with the highest exam score, breaking ties by serialized
       length and then by discovery order.
    3. If nothing parses, log a warning and return the stripped raw reply so
       the frontend can attempt its own fallback parsing.

    Returns ``""`` for empty input.
    """
    raw_text = (raw_response or "").strip()
    if not raw_text:
        return ""

    _, answer_part = split_thinking_and_answer(raw_text)
    for direct_text in (answer_part, _extract_explicit_answer_segment(raw_text), raw_text):
        direct_payload = _try_parse_exam_json(direct_text)
        if direct_payload:
            return json.dumps(direct_payload, ensure_ascii=False)

    extractors = (
        _extract_balanced_json_objects,
        _extract_trailing_json_candidates,
        _extract_brace_sliced_candidates,
    )
    harvested = []
    for extractor in extractors:
        for snippet in extractor(raw_text):
            payload = _try_parse_exam_json(snippet)
            if payload:
                harvested.append((payload, snippet))

    if not harvested:
        logger.warning("[exam] 未能从模型响应中提取试卷 JSON,保留原始响应供前端兜底解析")
        return raw_text

    def _rank(entry):
        # Higher exam score first, then the larger serialized payload.
        entry_payload = entry[0]
        return (
            _score_exam_payload_candidate(entry_payload),
            len(json.dumps(entry_payload, ensure_ascii=False)),
        )

    # max() keeps the first maximal entry, matching a stable descending sort.
    best_payload, best_snippet = max(harvested, key=_rank)
    best_score = _score_exam_payload_candidate(best_payload)
    if best_score <= 0:
        logger.warning(
            "[exam] 已提取到JSON对象但试卷特征较弱,选择最大候选兜底: score=%s snippet=%s",
            best_score,
            (best_snippet or "")[:200],
        )
    return json.dumps(best_payload, ensure_ascii=False)
- def _normalize_related_question(question: str) -> str:
- if not isinstance(question, str):
- return ""
- text = question.strip().strip('"').strip("'")
- text = re.sub(r"^[0-9]+[\.\)\]、]\s*", "", text)
- text = re.sub(r"^[-*]\s*", "", text)
- return text.strip()
def _is_placeholder_related_question(question: str) -> bool:
    """Tell whether ``question`` is empty or a template placeholder such as "Q1" / "question 2".

    Matching is done on the normalized, lowercased text, against the whole string.
    """
    lowered = _normalize_related_question(question).lower()
    if not lowered:
        return True
    placeholders = (
        r"^q\s*\d+$",
        r"^question\s*\d+$",
        r"^questions?\s*\d+$",
        r"^问题\s*\d+$",
        r"^相关问题\s*\d+$",
        r"^推荐问题\s*\d+$",
        r"^更多相关问题$",
        r"^更多问题$",
    )
    for placeholder in placeholders:
        if re.fullmatch(placeholder, lowered):
            return True
    return False
- def _contains_chinese(text: str) -> bool:
- return any("\u4e00" <= char <= "\u9fff" for char in text or "")
def _is_invalid_related_question(question: str) -> bool:
    """Tell whether a suggested follow-up question should be discarded.

    A question is invalid when, after normalization, it is empty, shorter than
    4 characters, a placeholder, contains no Chinese, or mentions any
    prompt/meta keyword (e.g. "thinking process", "json", "输出格式").
    """
    cleaned = _normalize_related_question(question)
    too_weak = (
        not cleaned
        or len(cleaned) < 4
        or _is_placeholder_related_question(cleaned)
        or not _contains_chinese(cleaned)
    )
    if too_weak:
        return True
    meta_keywords = (
        "thinking process",
        "analyze the request",
        "role:",
        "**role",
        "professional question recommendation",
        "infrastructure construction technology",
        "output format",
        "json",
        "prompt",
        "system",
        "assistant",
        "角色定义",
        "任务目标",
        "输入内容",
        "生成要求",
        "输出格式",
        "开始生成",
    )
    haystack = cleaned.lower()
    return any(keyword in haystack for keyword in meta_keywords)
- def _extract_related_question_topic(content: str) -> str:
- if not content:
- return "当前话题"
- text = re.sub(r"<[^>]+>", " ", str(content))
- text = re.sub(r"\s+", " ", text).strip()
- text = re.sub(
- r"^(好的[!!,, ]*|我理解您提出的问题[,, ]*|这个问题[,, ]*|总的来说[::,, ]*)+",
- "",
- text,
- )
- pattern = re.search(
- r"(?:主要围绕|围绕|关于|针对|聚焦)([^。!?\n,,;;]{4,32})",
- text,
- )
- if pattern:
- topic = pattern.group(1).strip("“”\"' ::,,")
- if topic:
- return topic
- sentence = re.split(r"[。!?\n]", text, maxsplit=1)[0].strip("“”\"' ::,,")
- if sentence:
- return sentence[:24]
- return "当前话题"
def _build_related_question_fallbacks(content: str) -> list[str]:
    """Return three generic follow-up questions (risks, approvals, monitoring) about the topic of ``content``."""
    subject = _extract_related_question_topic(content)
    risk_question = f"{subject}在现场实施时需要重点关注哪些风险点?"
    approval_question = f"{subject}相关的方案编制、审批和验收要求有哪些?"
    monitoring_question = f"针对{subject},日常检查和监测应抓住哪些关键指标?"
    return [risk_question, approval_question, monitoring_question]
def _finalize_related_questions(questions: list, content: str, limit: int = 3) -> list[str]:
    """Clean model-suggested follow-up questions and pad them to ``limit`` with fallbacks.

    Each suggestion is normalized and kept unless it is invalid or a
    case-insensitive duplicate. If fewer than ``limit`` survive, topic-based
    fallback questions derived from ``content`` fill the gap, again skipping
    duplicates. At most ``limit`` questions are returned.
    """
    kept: list[str] = []
    seen_keys: set[str] = set()

    def _keep(text: str) -> bool:
        # Record ``text``; report whether the quota is now exactly filled.
        kept.append(text)
        seen_keys.add(text.lower())
        return len(kept) == limit

    for suggestion in questions or []:
        candidate = _normalize_related_question(suggestion)
        if _is_invalid_related_question(candidate) or candidate.lower() in seen_keys:
            continue
        if _keep(candidate):
            return kept
    for fallback in _build_related_question_fallbacks(content):
        if fallback.lower() in seen_keys:
            continue
        if _keep(fallback):
            break
    return kept[:limit]
def _refresh_conversation_snapshot(db: Session, conversation_id: int, user_id: int) -> None:
    """Re-sync a conversation row with its remaining (non-deleted) messages.

    If the conversation has no live messages left, it is soft-deleted
    (``is_deleted = 1``). Otherwise its ``content`` preview is rebuilt from the
    newest user message, falling back to the newest message of any type. The
    preview is truncated to 100 characters and ``updated_at`` is bumped.

    Does not commit; the caller is expected to.
    """
    live_messages = db.query(AIMessage).filter(
        AIMessage.ai_conversation_id == conversation_id,
        AIMessage.user_id == user_id,
        AIMessage.is_deleted == 0,
    )
    newest = live_messages.order_by(AIMessage.id.desc()).first()
    owned_conversation = db.query(AIConversation).filter(
        AIConversation.id == conversation_id,
        AIConversation.user_id == user_id,
    )
    if not newest:
        owned_conversation.update({"is_deleted": 1, "updated_at": int(time.time())})
        return

    newest_user = (
        live_messages.filter(AIMessage.type == "user")
        .order_by(AIMessage.id.desc())
        .first()
    )
    if newest_user and newest_user.content:
        source_text = newest_user.content
    else:
        source_text = newest.content
    preview = _build_conversation_preview(source_text or "", limit=100)
    owned_conversation.update(
        {
            "content": preview or " ",
            "updated_at": int(time.time()),
        }
    )
- # ─────────────────────────────────────────────────────────────────────────
- # 辅助函数
- # ─────────────────────────────────────────────────────────────────────────
async def _rag_search(message: str, top_k: int = 5) -> str:
    """Retrieve knowledge-base context for ``message`` via the configured search API.

    Posts ``{"query": message, "n_results": top_k}`` to ``settings.search.api_url``
    with a 10 s timeout. Takes the ``results`` (or ``documents``) list from
    the JSON reply and joins, with blank lines, the text of the first ``top_k``
    entries. Each entry contributes its ``content`` or ``text`` field, or the
    entry itself when it is a plain string.

    Retrieval is optional context, so failures are logged and swallowed.
    Returns ``""`` when search is not configured, the call fails, the API
    answers with a non-200 status, or no entry carries text.
    """
    try:
        search_cfg = getattr(settings, "search", None)
        search_url = getattr(search_cfg, "api_url", None) if search_cfg else None
        if not search_url:
            return ""
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(
                search_url,
                json={"query": message, "n_results": top_k},
            )
        if resp.status_code != 200:
            # Make a misbehaving search backend visible instead of returning nothing silently.
            logger.warning(f"[RAG] 检索接口返回异常状态码(可忽略): {resp.status_code}")
            return ""
        data = resp.json()
        docs = (data.get("results") or data.get("documents")) if isinstance(data, dict) else None
        if not isinstance(docs, list):
            return ""
        texts = []
        for doc in docs[:top_k]:
            # Handle each entry shape explicitly so one unexpected item cannot
            # abort the whole result set.
            if isinstance(doc, dict):
                text = doc.get("content") or doc.get("text")
            elif isinstance(doc, str):
                text = doc
            else:
                text = None
            if text:
                texts.append(str(text))
        return "\n\n".join(texts)
    except Exception as e:
        logger.warning(f"[RAG] 检索失败(可忽略): {e}")
        return ""
# System prompt for _infer_safety_training_plan. It asks the planning model to
# turn a free-form safety-training request into a single JSON object with the
# fields listed below. The reply is then parsed by _parse_json_object and merged
# by _normalize_safety_training_plan.
# NOTE: this text is sent to the model verbatim; editing it changes model behavior.
SAFETY_TRAINING_PLAN_SYSTEM_PROMPT = """
你是安全培训需求整理助手。请把用户的自然语言输入整理成安全培训PPT大纲生成任务。
规则:
1. 只输出一个 JSON 对象,不要输出 Markdown、解释或额外文字。
2. 即使用户说“通知”“材料”“文档”,也必须理解为安全培训模块中的 PPT 大纲需求,不要切换到其他文档生成任务。
3. 如果字段缺失,请根据安全培训场景合理补全,但不要编造具体制度编号、人员姓名或不存在的事实。
4. template 字段用于选择大纲模板,默认填“标准安全培训PPT大纲”。
5. content_focus 至少给出 3 个要点。
JSON 字段:
{
"topic": "培训主题",
"template": "模板名称",
"content_focus": ["内容要点1", "内容要点2", "内容要点3"],
"audience": "参训对象",
"time": "培训时间",
"location": "培训地点",
"goal": "培训目标",
"notes": "其他要求",
"normalized_request": "归一化后的安全培训PPT大纲生成需求"
}
"""
- def _extract_tag_value(message: str, tag: str) -> str:
- match = re.search(fr"<{tag}>(.*?)</{tag}>", message or "", re.DOTALL)
- return match.group(1).strip() if match else ""
- def _strip_document_tags(message: str) -> str:
- text = message or ""
- for tag in ("word", "filename", "filesize"):
- text = re.sub(fr"<{tag}>.*?</{tag}>", " ", text, flags=re.DOTALL)
- return re.sub(r"\s+", " ", text).strip()
def _extract_safety_training_request_payload(message: str) -> dict:
    """Split a chat message into uploaded-document fields and the remaining request text.

    Returns a dict with ``document_content`` (from ``<word>``), ``filename``
    (from ``<filename>``), ``filesize`` (from ``<filesize>``) and ``request``
    (the message with those tag blocks removed).
    """
    tag_fields = (
        ("document_content", "word"),
        ("filename", "filename"),
        ("filesize", "filesize"),
    )
    extracted = {field: _extract_tag_value(message, tag) for field, tag in tag_fields}
    extracted["request"] = _strip_document_tags(message)
    return extracted
def _clean_safety_training_topic(message: str) -> str:
    """Derive a short training topic from the free-text part of ``message``.

    Steps:
      * take the first clause of the request (split on punctuation/newlines);
      * delete filler words (request verbs, quantity words, document-type words);
      * remove all whitespace and edge punctuation;
      * default to the generic safety-training topic;
      * append the safety-training suffix when the topic does not mention training.
    """
    request_text = _extract_safety_training_request_payload(message)["request"]
    first_clause = re.split(r"[,。;;,\n]", request_text, maxsplit=1)[0].strip()
    candidate = first_clause or request_text or "安全培训"
    filler_words = ("请", "帮我", "帮忙", "生成", "制作", "输出", "一份", "一个", "一下", "PPT大纲", "ppt大纲", "大纲", "通知", "文档", "材料")
    for filler in filler_words:
        # Order matters: the longer outline tokens must be removed before the bare one.
        candidate = candidate.replace(filler, "")
    candidate = re.sub(r"\s+", "", candidate).strip(" ::,,。;;") or "安全培训"
    return candidate if "培训" in candidate else f"{candidate}安全培训"
- def _parse_json_object(text: str) -> dict:
- if not text:
- return {}
- cleaned = re.sub(r"```(?:json)?\s*", "", str(text)
- ).replace("```", "").strip()
- match = re.search(r"\{.*\}", cleaned, re.DOTALL)
- if not match:
- return {}
- try:
- parsed = json.loads(match.group(0))
- return parsed if isinstance(parsed, dict) else {}
- except json.JSONDecodeError:
- return {}
def _build_fallback_safety_training_plan(message: str) -> dict:
    """Build a deterministic safety-training plan straight from the user's message.

    Returned as-is when the planning model fails, and used as the base that
    model output is merged onto. ``notes`` carries the request text with
    document tags removed.
    """
    topic = _clean_safety_training_topic(message)
    request_text = _extract_safety_training_request_payload(message)["request"]
    default_focus = ["安全生产责任", "现场风险识别", "安全意识提升", "培训纪律与行为规范"]
    plan = {
        "topic": topic,
        "template": "标准安全培训PPT大纲",
        "content_focus": default_focus,
        "audience": "参训员工",
        "time": "",
        "location": "",
        "goal": "提升参训人员安全意识和施工现场风险防控能力",
        "notes": request_text,
        "normalized_request": f"围绕{topic}生成安全培训PPT大纲",
    }
    return plan
def _normalize_safety_training_plan(message: str, raw_plan: dict) -> dict:
    """Merge a model-produced training plan onto the deterministic fallback plan.

    Args:
        message: original user message; used to build the fallback plan.
        raw_plan: parsed model output. Anything that is not a dict is ignored.

    Returns:
        The fallback plan, updated as follows:
          * every non-blank string field from ``raw_plan`` overrides it (stripped);
          * ``content_focus`` is taken from ``raw_plan`` when it is a list, or a
            delimiter-separated string, that yields at least one non-blank item;
          * the training / "PPT大纲" suffixes are appended to ``topic`` /
            ``template`` if missing.
    """
    plan = _build_fallback_safety_training_plan(message)
    if not isinstance(raw_plan, dict):
        return plan
    for key in ("topic", "template", "audience", "time", "location", "goal", "notes", "normalized_request"):
        value = raw_plan.get(key)
        if isinstance(value, str) and value.strip():
            plan[key] = value.strip()

    focus = raw_plan.get("content_focus")
    if isinstance(focus, list):
        focus_items = [str(item).strip() for item in focus]
    elif isinstance(focus, str):
        focus_items = [item.strip() for item in re.split(r"[、,,;\n]", focus)]
    else:
        focus_items = []
    focus_items = [item for item in focus_items if item]
    # Only override the fallback points when the model actually supplied some;
    # a string made only of separators must not wipe them out.
    if focus_items:
        plan["content_focus"] = focus_items

    if "培训" not in plan["topic"]:
        plan["topic"] = f"{plan['topic']}安全培训"
    if "PPT大纲" not in plan["template"]:
        plan["template"] = f"{plan['template']}PPT大纲"
    return plan
def _build_safety_training_generation_message(message: str, plan: dict) -> str:
    """Render the structured plan as the user-facing prompt for PPT outline generation.

    Empty plan fields fall back to fixed defaults. When the message carried an
    uploaded document (filename or content), its name, size and full content
    are appended.
    """
    payload = _extract_safety_training_request_payload(message)
    focus_text = "、".join(plan.get("content_focus") or [])
    header = [
        "输出类型:安全培训PPT大纲",
        "请基于以下结构化需求生成安全培训PPT大纲,不要生成通知正文,不要切换到其他文档生成任务。",
    ]
    details = [
        f"主题:{plan.get('topic') or '安全培训'}",
        f"模板:{plan.get('template') or '标准安全培训PPT大纲'}",
        f"内容要点:{focus_text or '安全生产责任、风险识别、应急处置、安全意识提升'}",
        f"参训对象:{plan.get('audience') or '参训员工'}",
        f"培训时间:{plan.get('time') or '未指定'}",
        f"培训地点:{plan.get('location') or '未指定'}",
        f"培训目标:{plan.get('goal') or '提升参训人员安全意识和风险防控能力'}",
        f"其他要求:{plan.get('notes') or '无'}",
        f"归一化需求:{plan.get('normalized_request') or ''}",
        f"原始需求:{payload['request'] or message}",
    ]
    upload = []
    if payload["filename"] or payload["document_content"]:
        upload = [
            f"上传文档名称:{payload['filename'] or '未命名文档'}",
            f"上传文档大小:{payload['filesize'] or '未知'}",
            "上传文档内容:",
            payload["document_content"] or "无",
        ]
    return "\n".join(header + details + upload)
async def _infer_safety_training_plan(message: str) -> dict:
    """Ask ``qwen_service`` to structure the request into a safety-training plan.

    The planning input is the request text plus, when a document was uploaded,
    its name and first 3000 characters. Any failure (call, parse or
    normalization) is logged and answered with the deterministic fallback plan.
    """
    payload = _extract_safety_training_request_payload(message)
    base_request = payload["request"] or message
    document_text = payload["document_content"]
    if document_text:
        planning_input = (
            f"{base_request}\n\n"
            f"上传文档名称:{payload['filename'] or '未命名文档'}\n"
            f"上传文档内容摘要:{document_text[:3000]}"
        )
    else:
        planning_input = base_request
    planning_messages = [
        {"role": "system", "content": SAFETY_TRAINING_PLAN_SYSTEM_PROMPT},
        {"role": "user", "content": planning_input},
    ]
    try:
        reply = await qwen_service.chat(planning_messages)
        return _normalize_safety_training_plan(message, _parse_json_object(reply))
    except Exception as e:
        logger.warning(
            f"[safety_training] 需求整理失败,使用兜底结构: {type(e).__name__}: {e}")
        return _build_fallback_safety_training_plan(message)
- def _clean_ai_writing_response(content: str) -> str:
- text = str(content or "").strip()
- if not text:
- return ""
- text = re.sub(r"```(?:html)?\s*", "", text,
- flags=re.IGNORECASE).replace("```", "").strip()
- body_match = re.search(
- r"<body[^>]*>(.*?)</body>", text, re.IGNORECASE | re.DOTALL)
- if body_match:
- text = body_match.group(1).strip()
- first_content_tag = re.search(
- r"<(?:article|section|main|div|h[1-6]|p|table|ul|ol)\b",
- text,
- re.IGNORECASE,
- )
- if first_content_tag and text[:first_content_tag.start()].strip():
- text = text[first_content_tag.start():]
- cleanup_patterns = (
- r"<!DOCTYPE[^>]*>",
- r"<html[^>]*>",
- r"</html>",
- r"<head[^>]*>.*?</head>",
- r"<body[^>]*>",
- r"</body>",
- r"<style[^>]*>.*?</style>",
- r"<script[^>]*>.*?</script>",
- r"<meta[^>]*>",
- r"<title[^>]*>.*?</title>",
- )
- for pattern in cleanup_patterns:
- text = re.sub(pattern, "", text, flags=re.IGNORECASE | re.DOTALL)
- return text.strip()
async def _generate_ai_writing_response(message: str) -> str:
    """Generate an official-document HTML fragment for ``message`` with ``deepseek_service``.

    RAG context (top 10) is injected into the "document_writing" prompt. Only
    the answer part of the reply is kept: appending thinking text would mix
    plain prose into the HTML.
    """
    knowledge = await _rag_search(message, top_k=10)
    system_prompt = load_prompt(
        "document_writing",
        userMessage=message,
        contextJSON=knowledge or "暂无相关知识库内容",
    )
    user_instruction = (
        "请根据上面的写作规范和我的原始需求,直接生成可放入富文本编辑器的公文正文 HTML 片段。"
        "不要输出道歉、解释、DOCTYPE、html、head、body、style 或 script 标签。\n\n"
        f"原始需求:\n{message}"
    )
    raw_response = await deepseek_service.chat([
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_instruction},
    ])
    _, answer_part = split_thinking_and_answer(raw_response)
    return _clean_ai_writing_response(answer_part or raw_response)
async def _generate_ppt_outline_response(message: str) -> str:
    """Generate a safety-training PPT outline for ``message`` with ``qwen_service``.

    Flow:
      1. structure the request into a plan and render it as the generation prompt;
      2. add RAG context (top 10) and call the model;
      3. if the reply contains thinking text whose summary is non-empty, prefix
         that summary to the answer.
    """
    training_plan = await _infer_safety_training_plan(message)
    prompt_text = _build_safety_training_generation_message(message, training_plan)
    knowledge = await _rag_search(prompt_text, top_k=10)
    system_prompt = load_prompt(
        "ppt_outline",
        userMessage=prompt_text,
        contextJSON=knowledge or "暂无相关知识库内容",
    )
    raw_response = await qwen_service.chat([
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": "请直接输出安全培训PPT大纲正文,从标题开始,不要解释提示词或安全规则。"},
    ])
    thinking_part, answer_part = split_thinking_and_answer(raw_response)
    outline = answer_part or raw_response
    if not thinking_part:
        return outline
    summary = await summarize_thinking_content(
        user_question=message,
        raw_thinking=thinking_part,
        final_answer=outline,
        chat_service=qwen_service,
        context="ppt_outline",
    )
    if not summary:
        return outline
    return f"思考过程:\n{summary}\n\n回答:\n{outline}"
def _persist_message_pair(db: Session, conv_id: int, user, user_content: str, ai_content: str):
    """Store a user message and the AI reply to it in a single transaction.

    Args:
        db: active SQLAlchemy session.
        conv_id: id of the ``AIConversation`` both messages belong to.
        user: authenticated user object; only ``user.user_id`` is read.
        user_content: text of the user's message.
        ai_content: text of the AI reply. It is linked to the user message
            through ``prev_user_id``.

    Returns:
        ``(user_message, ai_message)``, both refreshed after the commit.

    Raises:
        Whatever the session raises on flush/commit. The session is rolled back
        first so it stays usable.
    """
    now_ts = int(time.time())
    try:
        user_message = AIMessage(
            ai_conversation_id=conv_id,
            user_id=user.user_id,
            type="user",
            content=user_content,
            created_at=now_ts,
            updated_at=now_ts,
            is_deleted=0,
        )
        db.add(user_message)
        # flush() (not commit()) makes the database assign user_message.id so
        # the reply can reference it, while keeping both inserts in one
        # transaction. A failure can then no longer leave an orphan user message.
        db.flush()
        ai_message = AIMessage(
            ai_conversation_id=conv_id,
            user_id=user.user_id,
            type="ai",
            content=ai_content,
            prev_user_id=user_message.id,
            created_at=now_ts,
            updated_at=now_ts,
            is_deleted=0,
        )
        db.add(ai_message)
        db.commit()
    except Exception:
        db.rollback()
        raise
    db.refresh(user_message)
    db.refresh(ai_message)
    return user_message, ai_message
def _build_history_messages(conv_id: int, limit: int = 10) -> list:
    """Load a conversation's recent messages as chat-API ``{"role", "content"}`` dicts.

    Reads up to ``limit`` newest non-deleted rows and returns them oldest-first.
    ``type == "user"`` maps to role "user", anything else to "assistant".
    Rows with empty content are skipped after the limit is applied. Uses its
    own ``SessionLocal`` session and always closes it.
    """
    session = SessionLocal()
    try:
        newest_first = (
            session.query(AIMessage)
            .filter(AIMessage.ai_conversation_id == conv_id, AIMessage.is_deleted == 0)
            .order_by(AIMessage.id.desc())
            .limit(limit)
            .all()
        )
        return [
            {"role": "user" if row.type == "user" else "assistant", "content": row.content}
            for row in reversed(newest_first)
            if row.content
        ]
    finally:
        session.close()
- # ─────────────────────────────────────────────────────────────────────────
- # 非流式接口
- # ─────────────────────────────────────────────────────────────────────────
class SendMessageRequest(BaseModel):
    # Request body for POST /send_deepseek_message.
    message: str  # user input; the handler strips it and rejects an empty result
    conversation_id: Optional[int] = None  # existing conversation to continue; None starts a new one
    ai_conversation_id: Optional[int] = None  # alternative id field, used when conversation_id is not set
    business_type: int = 0  # 0=AI Q&A, 1=PPT outline, 2=AI writing, 3=exam workshop
    exam_name: str = ""  # only stored on the conversation when business_type == 3
    ai_message_id: int = 0  # NOTE(review): not read in the visible part of the handler — confirm usage
# Intent labels (lower-cased) for which the knowledge base is searched before answering.
_SEND_MSG_KB_INTENTS = ("query_knowledge_base", "知识库查询", "技术咨询")

# System prompt for business_type == 3 (exam workshop question generation).
_SEND_MSG_EXAM_SYSTEM_PROMPT = (
    "你是一个专业的考试题目生成助手,专注于路桥隧轨施工安全领域。\n"
    "请根据用户需求生成专业的考试题目,包括单选题、多选题、判断题、简答题等。\n"
    "用户消息中已经包含考试标题、题型要求和出题依据内容,必须以其中的出题依据内容为核心生成题目,不能脱离依据内容自由发挥。\n"
    "题干、选项、答案和解析都要与出题依据内容中的知识点、专业术语、操作流程、规范要求或培训主题直接相关。\n"
    "输出必须是可直接 JSON.parse 的纯 JSON,不要包含 markdown 代码块、解释文字或额外前后缀。\n"
    "JSON 顶层结构必须包含 singleChoice、judge、multiple、short 四个字段。\n"
    "singleChoice.questions 和 multiple.questions 中每道题必须包含 text、options、answer、analysis。\n"
    "options 必须是数组,元素格式为 {\"key\":\"A\",\"text\":\"具体选项内容\"},禁止输出“选项A”这类占位文本。\n"
    "judge.questions 中每道题必须包含 text、answer、analysis。\n"
    "short.questions 中每道题必须包含 text、outline,其中 outline 至少包含 keyFactors。\n"
    "所有题目内容、选项内容、答案和解析都要结合用户给出的工程类型、题型数量、分值和课件内容具体生成。"
)


def _send_msg_error_detail(exc: Exception) -> str:
    """Return ``str(exc)`` stripped, or a label naming the exception type when that is empty."""
    detail = str(exc).strip()
    return detail if detail else f"未知错误({type(exc).__name__})"


def _send_msg_resolve_conversation(db: Session, user, data: SendMessageRequest, message: str) -> int:
    """Return the conversation id for this message, creating a conversation if needed.

    A supplied id is reused only when it names a non-deleted conversation owned by
    ``user``; its preview, business type and exam name are then refreshed. Otherwise a
    new conversation is created (the same fallback ``/stream/chat-with-db`` uses), so
    messages are never attached to a conversation the caller does not own.
    """
    now_ts = int(time.time())
    exam_name = data.exam_name if data.business_type == 3 else ""
    requested_id = data.conversation_id or data.ai_conversation_id
    if requested_id:
        owned = (
            db.query(AIConversation.id)
            .filter(
                AIConversation.id == requested_id,
                AIConversation.user_id == user.user_id,
                AIConversation.is_deleted == 0,
            )
            .first()
        )
        if owned:
            db.query(AIConversation).filter(
                AIConversation.id == requested_id,
                AIConversation.user_id == user.user_id,
            ).update({
                "content": message[:100],
                "business_type": data.business_type,
                "exam_name": exam_name,
                "updated_at": now_ts,
            })
            db.commit()
            return requested_id
    conversation = AIConversation(
        user_id=user.user_id,
        content=message[:100],
        business_type=data.business_type,
        exam_name=exam_name,
        created_at=now_ts,
        updated_at=now_ts,
        is_deleted=0,
    )
    db.add(conversation)
    db.commit()
    db.refresh(conversation)
    return conversation.id


async def _send_msg_qa_answer(message: str) -> str:
    """Answer a business_type == 0 question: intent recognition, optional RAG, then chat.

    If the model output contains a thinking section, its summary (when non-empty) is
    prefixed to the answer. A reply that is a bare JSON object is unwrapped to its
    ``natural_language_answer`` field when present.
    """
    intent_result = await qwen_service.intent_recognition(message)
    intent_type = ""
    if isinstance(intent_result, dict):
        intent_type = (
            intent_result.get("intent_type") or intent_result.get("intent") or ""
        ).lower()
    rag_context = ""
    if intent_type in _SEND_MSG_KB_INTENTS:
        rag_context = await _rag_search(message, top_k=10)
    system_content = load_prompt(
        "final_answer",
        userMessage=message,
        contextJSON=rag_context if rag_context else "暂无相关知识库内容",
    )
    qwen_response = await qwen_service.chat([{"role": "user", "content": system_content}])
    raw_thinking, raw_answer = split_thinking_and_answer(qwen_response)
    answer_source = raw_answer or qwen_response
    answer_text = answer_source
    # Some replies are a bare JSON object; unwrap its natural-language answer.
    if isinstance(answer_source, str) and answer_source.strip().startswith("{"):
        try:
            parsed = json.loads(answer_source)
        except ValueError:
            parsed = None
        if isinstance(parsed, dict):
            answer_text = parsed.get("natural_language_answer", answer_source)
    if not raw_thinking:
        return answer_text
    thinking_summary = await summarize_thinking_content(
        user_question=message,
        raw_thinking=raw_thinking,
        final_answer=answer_text,
        chat_service=qwen_service,
        context="send_message",
    )
    if thinking_summary:
        return f"思考过程:\n{thinking_summary}\n\n回答:\n{answer_text}"
    return answer_text


def _send_msg_extract_exam_title(response_text: str) -> str:
    """Return the first non-empty string title field of a JSON exam reply, else ``""``."""
    try:
        exam_data = json.loads(response_text)
    except (TypeError, ValueError):
        return ""
    if not isinstance(exam_data, dict):
        return ""
    for key in ("title", "exam_name", "examTitle", "试卷标题", "标题"):
        value = exam_data.get(key)
        if isinstance(value, str) and value:
            return value
    return ""


async def _send_msg_exam_answer(db: Session, user, data: SendMessageRequest, conv_id: int, message: str):
    """Generate exam questions, persist the exchange and record the exam title.

    Returns:
        ``(response_text, ai_message_id)``.
    """
    raw_response_text = await qwen_service.chat([
        {"role": "system", "content": _SEND_MSG_EXAM_SYSTEM_PROMPT},
        {"role": "user", "content": message},
    ])
    response_text = _sanitize_exam_response(raw_response_text)
    _, ai_message = _persist_message_pair(
        db=db,
        conv_id=conv_id,
        user=user,
        user_content=message,
        ai_content=response_text,
    )
    ai_message_id = ai_message.id
    _refresh_conversation_snapshot(db, conv_id, user.user_id)
    db.commit()
    # Prefer the caller's title; otherwise try to read one from the generated JSON.
    generated_title = data.exam_name or _send_msg_extract_exam_title(response_text)
    if generated_title:
        db.query(AIConversation).filter(
            AIConversation.id == conv_id,
            AIConversation.user_id == user.user_id,
        ).update({"exam_name": generated_title, "updated_at": int(time.time())})
        db.commit()
    return response_text, ai_message_id


def _send_msg_payload(conv_id: int, user_id, business_type: int, response_text: str, ai_message_id=None) -> dict:
    """Build the success envelope for /send_deepseek_message.

    The reply text is repeated under several alias keys (presumably for different
    frontend callers). ``ai_message_id`` is included only when a message was persisted.
    """
    payload = {
        "conversation_id": conv_id,
        "ai_conversation_id": conv_id,
        "response": response_text,
        "reply": response_text,
        "content": response_text,
        "message": response_text,
    }
    if ai_message_id is not None:
        payload["ai_message_id"] = ai_message_id
    payload["user_id"] = user_id
    payload["business_type"] = business_type
    return {"statusCode": 200, "msg": "success", "data": payload}


@router.post("/send_deepseek_message")
async def send_deepseek_message(
    request: Request,
    data: SendMessageRequest,
    db: Session = Depends(get_db),
):
    """Send a message and return the complete (non-streamed) reply.

    Supported ``business_type`` values:
      - 0: AI Q&A (intent recognition + RAG); the exchange is not stored as messages.
      - 1: PPT outline generation
      - 2: AI writing
      - 3: exam workshop (question generation)

    Generation failures for a supported type are reported with statusCode 200 and a
    ``"处理失败: ..."`` reply text. Other unexpected errors return statusCode 500.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}
    try:
        message = data.message.strip()
        if not message:
            return {"statusCode": 400, "msg": "消息不能为空"}
        # Validate before touching the DB so an unsupported type creates no conversation.
        if data.business_type not in (0, 1, 2, 3):
            return {"statusCode": 400, "msg": f"不支持的业务类型: {data.business_type}"}

        conv_id = _send_msg_resolve_conversation(db, user, data, message)

        if data.business_type == 0:
            # NOTE(review): unlike types 1-3, the Q&A exchange is not persisted as
            # AIMessage rows here; confirm this is intended.
            try:
                response_text = await _send_msg_qa_answer(message)
            except Exception as e:
                error_detail = _send_msg_error_detail(e)
                logger.error(
                    f"[send_deepseek_message] AI问答异常: {type(e).__name__}: {error_detail}")
                response_text = f"处理失败: {error_detail}"
            return _send_msg_payload(conv_id, user.user_id, data.business_type, response_text)

        if data.business_type in (1, 2):
            if data.business_type == 1:
                generate, label = _generate_ppt_outline_response, "PPT大纲生成"
            else:
                generate, label = _generate_ai_writing_response, "AI写作"
            try:
                response_text = await generate(message)
                _, ai_message = _persist_message_pair(
                    db=db,
                    conv_id=conv_id,
                    user=user,
                    user_content=message,
                    ai_content=response_text,
                )
                _refresh_conversation_snapshot(db, conv_id, user.user_id)
                db.commit()
                return _send_msg_payload(
                    conv_id, user.user_id, data.business_type, response_text,
                    ai_message_id=ai_message.id,
                )
            except Exception as e:
                # Discard any half-written rows so the session stays usable.
                db.rollback()
                error_detail = _send_msg_error_detail(e)
                logger.error(
                    f"[send_deepseek_message] {label}异常: {type(e).__name__}: {error_detail}")
                return _send_msg_payload(
                    conv_id, user.user_id, data.business_type, f"处理失败: {error_detail}")

        # business_type == 3: exam workshop
        try:
            response_text, ai_message_id = await _send_msg_exam_answer(
                db, user, data, conv_id, message)
        except Exception as e:
            db.rollback()
            error_detail = _send_msg_error_detail(e)
            logger.error(
                f"[send_deepseek_message] 考试工坊异常: {type(e).__name__}: {error_detail}")
            return _send_msg_payload(
                conv_id, user.user_id, data.business_type, f"处理失败: {error_detail}")
        return _send_msg_payload(
            conv_id, user.user_id, data.business_type, response_text,
            ai_message_id=ai_message_id,
        )
    except Exception as e:
        db.rollback()
        logger.error(f"[send_deepseek_message] 异常: {e}")
        return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}
@router.get("/get_history_record")
async def get_history_record(
    request: Request,
    ai_conversation_id: int = 0,
    business_type: Optional[int] = None,
    db: Session = Depends(get_db),
):
    """Frontend-compatible history lookup.

    - ``ai_conversation_id > 0``: the caller's non-deleted messages in that conversation,
      oldest first; ``total`` is the number of messages returned.
    - otherwise: up to 50 of the caller's non-deleted conversations (optionally filtered
      by ``business_type``), most recently updated first; ``total`` is the unlimited count.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}

    def serialize_message(row):
        return {
            "id": row.id,
            "ai_conversation_id": row.ai_conversation_id,
            "user_id": row.user_id,
            "type": row.type,
            "content": row.content,
            "user_feedback": row.user_feedback,
            "prev_user_id": row.prev_user_id,
            "search_source": row.search_source or "",
            "guess_you_want": row.guess_you_want or "",
            "created_at": _to_frontend_timestamp(row.created_at),
            "updated_at": _to_frontend_timestamp(row.updated_at),
        }

    def serialize_conversation(conv):
        return {
            "id": conv.id,
            "title": _build_conversation_title(conv),
            "content": conv.content or "",
            "business_type": conv.business_type,
            "exam_name": conv.exam_name or "",
            "created_at": _to_frontend_timestamp(conv.created_at),
            "updated_at": _to_frontend_timestamp(conv.updated_at),
        }

    if ai_conversation_id > 0:
        rows = (
            db.query(AIMessage)
            .filter(
                AIMessage.ai_conversation_id == ai_conversation_id,
                AIMessage.user_id == user.user_id,
                AIMessage.is_deleted == 0,
            )
            .order_by(AIMessage.id.asc())
            .all()
        )
        return {
            "statusCode": 200,
            "msg": "success",
            "total": len(rows),
            "data": [serialize_message(row) for row in rows],
        }

    base_query = db.query(AIConversation).filter(
        AIConversation.user_id == user.user_id,
        AIConversation.is_deleted == 0,
    )
    if business_type is not None:
        base_query = base_query.filter(AIConversation.business_type == business_type)
    total = base_query.count()
    recent = (
        base_query
        .order_by(AIConversation.updated_at.desc(), AIConversation.id.desc())
        .limit(50)
        .all()
    )
    return {
        "statusCode": 200,
        "msg": "success",
        "total": total,
        "data": [serialize_conversation(conv) for conv in recent],
    }
class DeleteConversationRequest(BaseModel):
    """Request body for ``/delete_conversation``."""

    # Conversation to soft-delete entirely; used only when ai_message_id is 0.
    ai_conversation_id: int = 0
    # When non-zero, only this AI message and the user message it answered are deleted.
    ai_message_id: int = 0
@router.post("/delete_conversation")
async def delete_conversation(
    request: Request, data: DeleteConversationRequest, db: Session = Depends(get_db)
):
    """Soft-delete a single exchange or a whole conversation.

    - ``ai_message_id`` set: soft-delete that AI message (the caller's, type ``"ai"``,
      not already deleted) and the user message it answered, then refresh the
      conversation snapshot; 404 if the AI message is not found.
    - else ``ai_conversation_id`` set: soft-delete the conversation and all of the
      caller's messages in it.
    - neither set: 400.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}

    now_ts = int(time.time())

    def soft_delete_values():
        # A fresh values dict for every UPDATE.
        return {"is_deleted": 1, "updated_at": now_ts}

    if not data.ai_message_id:
        if not data.ai_conversation_id:
            return {"statusCode": 400, "msg": "缺少删除参数"}
        db.query(AIConversation).filter(
            AIConversation.id == data.ai_conversation_id,
            AIConversation.user_id == user.user_id,
        ).update(soft_delete_values())
        db.query(AIMessage).filter(
            AIMessage.ai_conversation_id == data.ai_conversation_id,
            AIMessage.user_id == user.user_id,
        ).update(soft_delete_values())
        db.commit()
        return {"statusCode": 200, "msg": "删除成功"}

    target = (
        db.query(AIMessage)
        .filter(
            AIMessage.id == data.ai_message_id,
            AIMessage.user_id == user.user_id,
            AIMessage.type == "ai",
            AIMessage.is_deleted == 0,
        )
        .first()
    )
    if target is None:
        return {"statusCode": 404, "msg": "消息不存在"}

    conversation_id = target.ai_conversation_id
    answered_user_msg_id = target.prev_user_id
    db.query(AIMessage).filter(
        AIMessage.id == target.id,
        AIMessage.user_id == user.user_id,
    ).update(soft_delete_values())
    if answered_user_msg_id:
        db.query(AIMessage).filter(
            AIMessage.id == answered_user_msg_id,
            AIMessage.user_id == user.user_id,
            AIMessage.ai_conversation_id == conversation_id,
        ).update(soft_delete_values())
    _refresh_conversation_snapshot(db, conversation_id, user.user_id)
    db.commit()
    return {"statusCode": 200, "msg": "删除成功"}
class DeleteHistoryRequest(BaseModel):
    """Request body for ``/delete_history_record``."""

    # Id of the conversation to soft-delete (must belong to the caller).
    ai_conversation_id: int
@router.post("/delete_history_record")
async def delete_history_record(
    request: Request, data: DeleteHistoryRequest, db: Session = Depends(get_db)
):
    """Soft-delete one of the caller's conversations together with its messages.

    Mirrors the whole-conversation branch of ``/delete_conversation``. Previously only
    the conversation row was flagged, so its messages stayed live and were still
    returned by ``/get_history_record`` when queried by conversation id.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}
    now_ts = int(time.time())
    db.query(AIConversation).filter(
        AIConversation.id == data.ai_conversation_id,
        AIConversation.user_id == user.user_id,
    ).update({"is_deleted": 1, "updated_at": now_ts})
    db.query(AIMessage).filter(
        AIMessage.ai_conversation_id == data.ai_conversation_id,
        AIMessage.user_id == user.user_id,
        AIMessage.is_deleted == 0,
    ).update({"is_deleted": 1, "updated_at": now_ts})
    db.commit()
    return {"statusCode": 200, "msg": "删除成功"}
- # ─────────────────────────────────────────────────────────────────────────
- # 流式接口 /stream/chat(无 DB,意图识别 + RAG)
- # ─────────────────────────────────────────────────────────────────────────
# Intent labels (lower-cased) for which the knowledge base is searched before answering.
_STREAM_KB_INTENTS = ("query_knowledge_base", "知识库查询", "技术咨询")
# Prompt context used when neither RAG nor online search produced anything.
_STREAM_NO_CONTEXT_TEXT = "暂无相关知识库内容"
# Tags delimiting the model's reasoning section; matched case-insensitively via regex so
# match offsets always index the original text (str.lower() may change its length).
_STREAM_THINK_OPEN = "<think>"
_STREAM_THINK_CLOSE = "</think>"
_STREAM_THINK_OPEN_RE = re.compile(re.escape(_STREAM_THINK_OPEN), re.IGNORECASE)
_STREAM_THINK_CLOSE_RE = re.compile(re.escape(_STREAM_THINK_CLOSE), re.IGNORECASE)


def _stream_sse_json(payload: dict) -> str:
    """Format *payload* as one SSE ``data:`` event with a JSON body."""
    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"


def _stream_sse_escaped(text: str) -> str:
    """Format *text* as one SSE ``data:`` event, writing newlines as the two characters backslash-n.

    A raw newline would end the SSE field, so the client is expected to un-escape it.
    """
    return "data: " + text.replace("\n", "\\n") + "\n\n"


def _stream_tag_tail_len(text: str, tag: str) -> int:
    """Return the length of the longest suffix of *text* that is a proper prefix of *tag*.

    Used to hold back a possibly split tag (e.g. ``"<thi"``) until the next chunk shows
    whether it completes. Comparison is case-insensitive; *tag* must be lower-case.
    """
    for size in range(min(len(tag) - 1, len(text)), 0, -1):
        if text[-size:].lower() == tag[:size]:
            return size
    return 0


class _StreamThinkSplitter:
    """Incrementally separate a ``<think>...</think>`` section from streamed model output.

    ``feed`` and ``finish`` return ordered ``(kind, text)`` events:
      - ``("answer", text)``: text to forward to the client unchanged;
      - ``("thinking", raw)``: the thinking section closed; ``raw`` is its text, capped
        at ``max_thinking_chars`` (a falsy cap means nothing is captured);
      - ``("thinking_eof", raw)``: the stream ended inside an unclosed, non-empty
        thinking section.
    Not-yet-emitted text that preceded ``<think>`` is held back and emitted, left-stripped
    together with the text after ``</think>``, once thinking closes, so a summary can be
    shown first. Only the first thinking section is parsed; later tags pass through.
    Tag fragments split across chunks are buffered until the next chunk arrives.
    """

    def __init__(self, max_thinking_chars):
        self._max_thinking_chars = max_thinking_chars
        self._buffer = ""
        self._pre_answer = ""
        self._thinking = ""
        self._in_think = False
        self._thinking_done = False

    def _collect_thinking(self, text: str) -> None:
        limit = self._max_thinking_chars
        if limit and len(self._thinking) < limit:
            self._thinking += text[: limit - len(self._thinking)]

    def feed(self, chunk: str) -> list:
        """Consume one streamed chunk and return the events it completes."""
        events = []
        self._buffer += chunk
        while self._buffer:
            if self._thinking_done:
                events.append(("answer", self._buffer))
                self._buffer = ""
                break
            if not self._in_think:
                match = _STREAM_THINK_OPEN_RE.search(self._buffer)
                if match is None:
                    keep = _stream_tag_tail_len(self._buffer, _STREAM_THINK_OPEN)
                    ready = self._buffer[: len(self._buffer) - keep]
                    if ready:
                        events.append(("answer", ready))
                    self._buffer = self._buffer[len(ready):]
                    break
                self._pre_answer += self._buffer[: match.start()]
                self._buffer = self._buffer[match.end():]
                self._in_think = True
                continue
            match = _STREAM_THINK_CLOSE_RE.search(self._buffer)
            if match is None:
                # Keep a possible "</thi..." fragment; everything before it is thinking.
                cut = len(self._buffer) - _stream_tag_tail_len(self._buffer, _STREAM_THINK_CLOSE)
                self._collect_thinking(self._buffer[:cut])
                self._buffer = self._buffer[cut:]
                break
            self._collect_thinking(self._buffer[: match.start()])
            remainder = self._buffer[match.end():]
            self._buffer = ""
            self._in_think = False
            self._thinking_done = True
            events.append(("thinking", self._thinking))
            answer = (self._pre_answer + remainder).lstrip()
            self._pre_answer = ""
            if answer:
                events.append(("answer", answer))
        return events

    def finish(self) -> list:
        """Flush buffered state at end of stream and return the remaining events."""
        events = []
        if self._in_think:
            self._collect_thinking(self._buffer)
            if self._thinking:
                events.append(("thinking_eof", self._thinking))
            if self._pre_answer:
                events.append(("answer", self._pre_answer))
        elif self._buffer:
            events.append(("answer", self._buffer))
        self._buffer = ""
        self._pre_answer = ""
        return events


async def _stream_render_events(events: list, message: str, context: str) -> list:
    """Turn splitter events into client-facing text pieces.

    Thinking events are summarised with ``summarize_thinking_content`` using ``context``
    (closed section) or ``context + "_eof"`` (unterminated section); a non-empty summary
    becomes a "思考过程/回答" prefix piece.
    """
    pieces = []
    for kind, text in events:
        if kind == "answer":
            pieces.append(text)
            continue
        thinking_summary = await summarize_thinking_content(
            user_question=message,
            raw_thinking=text,
            final_answer="",
            chat_service=qwen_service,
            context=context if kind == "thinking" else f"{context}_eof",
        )
        if thinking_summary:
            pieces.append(f"思考过程:\n{thinking_summary}\n\n回答:\n")
    return pieces


class StreamChatRequest(BaseModel):
    """Request body for ``/stream/chat``."""

    message: str
    # NOTE(review): not read by stream_chat; presumably accepted for client compatibility.
    model: str = ""


@router.post("/stream/chat")
async def stream_chat(request: Request, data: StreamChatRequest):
    """Streaming chat over SSE without touching the database.

    Runs intent recognition (failures are only logged), an optional RAG search, then
    streams the model reply as ``{"content": ...}`` JSON events. A thinking section is
    replaced by its summary. Errors are sent as an ``{"error": ...}`` event, and the
    stream always ends with ``[DONE]``.
    """
    message = data.message.strip()
    if not message:
        return JSONResponse(content={"statusCode": 400, "msg": "消息不能为空"})

    async def event_generator():
        try:
            intent_type = ""
            try:
                intent_result = await qwen_service.intent_recognition(message)
                if isinstance(intent_result, dict):
                    intent_type = (
                        intent_result.get("intent_type") or intent_result.get("intent") or ""
                    ).lower()
            except Exception as ie:
                logger.warning(f"[stream/chat] 意图识别异常: {ie}")

            # RAG and prompt building are inside the outer try so their failures are
            # reported to the client instead of aborting the stream without [DONE].
            rag_context = ""
            if intent_type in _STREAM_KB_INTENTS:
                rag_context = await _rag_search(message)
            system_content = load_prompt(
                "final_answer",
                userMessage=message,
                contextJSON=rag_context if rag_context else _STREAM_NO_CONTEXT_TEXT,
            )
            messages = [{"role": "user", "content": system_content}]

            splitter = _StreamThinkSplitter(
                getattr(settings.thinking_summary, "max_input_chars", 1500))
            async for chunk in qwen_service.stream_chat(messages):
                for piece in await _stream_render_events(splitter.feed(chunk), message, "stream_chat"):
                    yield _stream_sse_json({"content": piece})
            for piece in await _stream_render_events(splitter.finish(), message, "stream_chat"):
                yield _stream_sse_json({"content": piece})
        except Exception as e:
            logger.error(f"[stream/chat] 流式输出异常: {e}")
            yield _stream_sse_json({"error": str(e)})
        # Deliberately not in `finally`: if the generator is closed early (client
        # disconnect), yielding during aclose() raises RuntimeError.
        yield "data: [DONE]\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")
# ─────────────────────────────────────────────────────────────────────────
# 流式接口 /stream/chat-with-db(前端主聊天接口)
# ─────────────────────────────────────────────────────────────────────────


class StreamChatWithDBRequest(BaseModel):
    """Request body for ``/stream/chat-with-db``."""

    message: str
    # 0 (or an id the caller does not own) starts a new conversation.
    ai_conversation_id: int = 0
    # 0 = AI Q&A, 1 = PPT outline, 2 = AI writing, 3 = exam workshop
    business_type: int = 0
    # Stored on the conversation only when business_type == 3.
    exam_name: str = ""
    # NOTE(review): not read by stream_chat_with_db.
    ai_message_id: int = 0
    # Online-search text added to the prompt context for business types other than 1/2.
    online_search_content: str = ""


def _stream_resolve_conversation(db: Session, user, data: StreamChatWithDBRequest, message: str) -> int:
    """Return the caller's conversation id for this message, creating one if needed.

    An existing, non-deleted conversation owned by ``user`` is reused and its preview,
    business type and exam name are refreshed. Otherwise a new conversation is created.
    """
    now_ts = int(time.time())
    preview = _build_conversation_preview(message, limit=100)
    exam_name = data.exam_name if data.business_type == 3 else ""
    if data.ai_conversation_id:
        existing = (
            db.query(AIConversation)
            .filter(
                AIConversation.id == data.ai_conversation_id,
                AIConversation.user_id == user.user_id,
                AIConversation.is_deleted == 0,
            )
            .first()
        )
        if existing:
            conv_id = existing.id
            db.query(AIConversation).filter(
                AIConversation.id == conv_id,
                AIConversation.user_id == user.user_id,
            ).update({
                "content": preview,
                "business_type": data.business_type,
                "exam_name": exam_name,
                "updated_at": now_ts,
            })
            db.commit()
            return conv_id
    conversation = AIConversation(
        user_id=user.user_id,
        content=preview,
        business_type=data.business_type,
        # Gated on business_type == 3 in every path (the new-conversation path
        # previously stored exam_name unconditionally).
        exam_name=exam_name,
        created_at=now_ts,
        updated_at=now_ts,
        is_deleted=0,
    )
    db.add(conversation)
    db.commit()
    db.refresh(conversation)
    return conversation.id


def _stream_build_prompt_messages(db: Session, conv_id: int, before_message_id: int,
                                  data: StreamChatWithDBRequest, message: str, rag_context: str) -> list:
    """Build the chat messages for the model.

    Business types 1/2 use their dedicated prompt with RAG context only. Everything else
    uses ``final_answer`` with RAG and online-search context plus up to the last four
    messages (two turns) older than ``before_message_id``.
    """
    if data.business_type in (1, 2):
        prompt_name = "ppt_outline" if data.business_type == 1 else "document_writing"
        system_content = load_prompt(
            prompt_name,
            userMessage=message,
            contextJSON=rag_context if rag_context else _STREAM_NO_CONTEXT_TEXT,
        )
        return [{"role": "user", "content": system_content}]

    # Order by id, not updated_at: updated_at changes when a message is edited or
    # annotated later, which would reorder history. before_message_id is the current
    # user message, so the question (passed separately as userMessage) is not duplicated.
    history_rows = (
        db.query(AIMessage)
        .filter(
            AIMessage.ai_conversation_id == conv_id,
            AIMessage.id < before_message_id,
            AIMessage.is_deleted == 0,
        )
        .order_by(AIMessage.id.desc())
        .limit(4)
        .all()
    )
    history_lines = []
    for row in reversed(history_rows):
        if not row.content:
            continue  # skip empty AI placeholders left by interrupted turns
        # NOTE(review): these labels (and the context headings below) were mojibake
        # ("??") in the previous revision; wording reconstructed — confirm with the prompt.
        role = "用户" if row.type == "user" else "助手"
        history_lines.append(f"{role}: {row.content}\n\n")
    history_context = "".join(history_lines)

    context_parts = []
    if rag_context:
        context_parts.append(f"知识库内容：\n{rag_context}")
    if data.online_search_content:
        context_parts.append(f"联网搜索结果：\n{data.online_search_content}")
    context_json = "\n\n".join(context_parts) if context_parts else _STREAM_NO_CONTEXT_TEXT
    system_content = load_prompt(
        "final_answer",
        userMessage=message,
        contextJSON=context_json,
        historyContext=history_context,
    )
    return [{"role": "user", "content": system_content}]


@router.post("/stream/chat-with-db")
async def stream_chat_with_db(request: Request, data: StreamChatWithDBRequest):
    """Streaming chat over SSE with persistence (the frontend's main chat endpoint).

    Flow:
      1. create or reuse the conversation;
      2. store the user message and an empty AI placeholder;
      3. send an ``initial`` JSON event with the conversation and AI message ids;
      4. RAG search, history context and prompt building;
      5. stream the reply as newline-escaped ``data:`` events (thinking replaced by a
         summary when ``settings.thinking_summary.enabled``);
      6. write the full reply into the placeholder and end with ``[DONE]``.
    """
    user = request.state.user
    if not user:
        return JSONResponse(content={"statusCode": 401, "msg": "未授权"})
    message = data.message.strip()
    if not message:
        return JSONResponse(content={"statusCode": 400, "msg": "消息不能为空"})

    async def event_generator():
        db = SessionLocal()
        try:
            conv_id = _stream_resolve_conversation(db, user, data, message)
            user_msg, ai_msg = _persist_message_pair(
                db=db,
                conv_id=conv_id,
                user=user,
                user_content=message,
                ai_content="",
            )
            user_msg_id, ai_msg_id = user_msg.id, ai_msg.id
            yield _stream_sse_json(
                {"type": "initial", "ai_conversation_id": conv_id, "ai_message_id": ai_msg_id})

            rag_context = await _rag_search(message, top_k=10)
            messages = _stream_build_prompt_messages(
                db, conv_id, user_msg_id, data, message, rag_context)

            summary_enabled = getattr(settings.thinking_summary, "enabled", True)
            splitter = _StreamThinkSplitter(
                getattr(settings.thinking_summary, "max_input_chars", 1500))
            response_parts = []
            try:
                async for chunk in qwen_service.stream_chat(messages):
                    if summary_enabled:
                        pieces = await _stream_render_events(
                            splitter.feed(chunk), message, "stream_chat_with_db")
                    else:
                        pieces = [chunk]
                    for piece in pieces:
                        response_parts.append(piece)
                        yield _stream_sse_escaped(piece)
            except Exception as e:
                # Keep whatever was streamed so far; it is still persisted below.
                logger.error(f"[stream/chat-with-db] 流式输出异常: {e}")
                yield _stream_sse_json({"error": str(e)})
            if summary_enabled:
                for piece in await _stream_render_events(
                        splitter.finish(), message, "stream_chat_with_db"):
                    response_parts.append(piece)
                    yield _stream_sse_escaped(piece)

            full_response = "".join(response_parts)
            if full_response:
                now_ts = int(time.time())
                db.query(AIMessage).filter(AIMessage.id == ai_msg_id).update(
                    {"content": full_response, "updated_at": now_ts}
                )
                db.query(AIConversation).filter(
                    AIConversation.id == conv_id,
                    AIConversation.user_id == user.user_id,
                ).update(
                    {
                        "content": _build_conversation_preview(message, limit=100),
                        "business_type": data.business_type,
                        "exam_name": data.exam_name if data.business_type == 3 else "",
                        "updated_at": now_ts,
                    }
                )
                db.commit()
            yield "data: [DONE]\n\n"
        except Exception as e:
            logger.error(f"[stream/chat-with-db] 处理异常: {e}")
            yield _stream_sse_json({"error": str(e)})
        finally:
            db.close()

    return StreamingResponse(event_generator(), media_type="text/event-stream")
- # ─────────────────────────────────────────────────────────────────────────
- # 猜你想问
- # ─────────────────────────────────────────────────────────────────────────
class GuessYouWantRequest(BaseModel):
    """Request body for ``/guess_you_want``."""

    # AIMessage whose content seeds the follow-up questions.
    ai_message_id: int
@router.post("/guess_you_want")
async def guess_you_want(
    request: Request,
    data: GuessYouWantRequest,
    db: Session = Depends(get_db),
):
    """Generate three "you may also want to ask" questions for one of the caller's messages.

    The result is stored as ``{"questions": [...]}`` JSON in ``AIMessage.guess_you_want``
    and returned. 404 when the message is missing, deleted, or not owned by the caller.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}
    try:
        ai_msg = (
            db.query(AIMessage)
            .filter(
                AIMessage.id == data.ai_message_id,
                # Ownership check: without it any user could read another user's message
                # content via the prompt and overwrite its guess_you_want field.
                AIMessage.user_id == user.user_id,
                AIMessage.is_deleted == 0,
            )
            .first()
        )
        if not ai_msg:
            return {"statusCode": 404, "msg": "消息不存在"}
        # Guard against NULL/empty content before slicing.
        content = ai_msg.content or ""
        system_content = load_prompt(
            "guess_questions",
            currentContent=content[:500]
        )
        response = await qwen_service.chat([{"role": "user", "content": system_content}])

        questions = []
        try:
            # Prefer a {"questions": [...]} object embedded anywhere in the reply.
            json_match = re.search(
                r'\{[^{}]*"questions"[^{}]*\}', response, re.DOTALL)
            response_json = json.loads(json_match.group() if json_match else response)
            questions = response_json.get("questions", [])
        except (TypeError, ValueError, AttributeError):
            # Not JSON: treat each sufficiently long line (minus list numbering) as a question.
            for line in response.split("\n"):
                clean = line.strip().lstrip("0123456789.-、 ").strip()
                if len(clean) > 5:
                    questions.append(clean)
        if not isinstance(questions, list) or not questions:
            questions = ["该话题的具体应用场景?", "有哪些注意事项?", "相关案例分析?"]
        questions = _finalize_related_questions(
            questions, content, limit=3)

        guess_json = json.dumps({"questions": questions}, ensure_ascii=False)
        db.query(AIMessage).filter(
            AIMessage.id == ai_msg.id,
            AIMessage.user_id == user.user_id,
        ).update({"guess_you_want": guess_json, "updated_at": int(time.time())})
        db.commit()
        return {
            "statusCode": 200,
            "msg": "success",
            "data": {"ai_message_id": data.ai_message_id, "questions": questions},
        }
    except Exception as e:
        db.rollback()
        logger.error(f"[guess_you_want] 处理异常: {e}")
        return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}
- # ─────────────────────────────────────────────────────────────────────────
- # 在线搜索(Dify 工作流集成)
- # ─────────────────────────────────────────────────────────────────────────
@router.get("/online_search")
async def online_search(question: str, request: Request, db: Session = Depends(get_db)):
    """Online search: extract keywords with Qwen, run the Dify workflow, return its text.

    Returns ``{"keywords": ..., "result": ...}`` on success. Missing Dify configuration,
    a non-200 Dify response, or any other error yields statusCode 500.
    """
    # NOTE(review): `db` is not used here; kept so the route signature is unchanged.
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}
    try:
        # Check configuration first so an unconfigured deployment does not spend an
        # LLM call on keyword extraction.
        dify_config = getattr(settings, "dify", None)
        if not dify_config or not getattr(dify_config, "workflow_url", None):
            return {"statusCode": 500, "msg": "Dify 配置未设置"}
        keywords = await qwen_service.extract_keywords(question)
        headers = {
            "Authorization": f"Bearer {dify_config.auth_token}",
            "Content-Type": "application/json",
        }
        payload = {
            "workflow_id": dify_config.workflow_id,
            "inputs": {
                "keywords": keywords,
                "num": 5,  # number of search results requested
                "max_text_len": 4000  # maximum text length requested
            },
            "response_mode": "blocking",
            "user": getattr(user, "account", str(user.user_id)),
        }
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(dify_config.workflow_url, headers=headers, json=payload)
        if resp.status_code != 200:
            logger.error(
                f"[online_search] Dify 调用失败: {resp.status_code}, 响应: {resp.text}")
            return {"statusCode": 500, "msg": f"搜索服务异常: {resp.status_code}"}
        result = resp.json()
        # `or {}` also covers explicit nulls in the workflow response.
        outputs = (result.get("data") or {}).get("outputs") or {}
        search_text = outputs.get("text", "")
        return {
            "statusCode": 200,
            "msg": "success",
            "data": {"keywords": keywords, "result": search_text},
        }
    except Exception as e:
        logger.error(f"[online_search] 处理异常: {e}")
        return {"statusCode": 500, "msg": f"搜索失败: {str(e)}"}
class SaveOnlineSearchResultRequest(BaseModel):
    """Request body for ``/save_online_search_result``."""

    # AIMessage to attach the search result to.
    ai_message_id: int
    # Stored verbatim in AIMessage.search_source.
    search_result: str
@router.post("/save_online_search_result")
async def save_online_search_result(
    request: Request,
    data: SaveOnlineSearchResultRequest,
    db: Session = Depends(get_db),
):
    """Store an online-search result in ``AIMessage.search_source`` of the caller's message.

    Returns 404 when no message with that id belongs to the caller.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}
    try:
        # Filter by owner so one user cannot overwrite another user's messages.
        updated = db.query(AIMessage).filter(
            AIMessage.id == data.ai_message_id,
            AIMessage.user_id == user.user_id,
        ).update(
            {"search_source": data.search_result,
             "updated_at": int(time.time())}
        )
        if not updated:
            return {"statusCode": 404, "msg": "消息不存在"}
        db.commit()
        return {"statusCode": 200, "msg": "保存成功"}
    except Exception as e:
        db.rollback()
        logger.error(f"[save_online_search_result] 处理异常: {e}")
        return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}
- # ─────────────────────────────────────────────────────────────────────────
- # 意图识别独立接口
- # ─────────────────────────────────────────────────────────────────────────
class IntentRecognitionRequest(BaseModel):
    """Request body for ``/intent_recognition``."""

    message: str
    # Persist the exchange, but only for greeting/FAQ intents.
    save_to_db: bool = False
    # Conversation to append to when persisting; 0 creates a new conversation.
    ai_conversation_id: int = 0
@router.post("/intent_recognition")
async def intent_recognition(
    request: Request,
    data: IntentRecognitionRequest,
    db: Session = Depends(get_db),
):
    """Standalone intent recognition; optionally persist greeting/FAQ exchanges.

    When ``save_to_db`` is true and the intent is a greeting or FAQ, the message and the
    recognizer's ``response`` are stored as a user/AI message pair. The target is the
    caller's conversation ``ai_conversation_id`` if it exists and is not deleted,
    otherwise a newly created one. Otherwise only the intent and response are returned.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}
    try:
        intent_result = await qwen_service.intent_recognition(data.message)
        intent_type = ""
        response_text = ""
        if isinstance(intent_result, dict):
            intent_type = (
                intent_result.get("intent_type") or intent_result.get(
                    "intent") or ""
            ).lower()
            # `or ""` so an explicit null response is not stored as NULL content.
            response_text = intent_result.get("response") or ""

        if not (data.save_to_db and intent_type in ("greeting", "问候", "faq", "常见问题")):
            return {
                "statusCode": 200,
                "msg": "success",
                "data": {
                    "intent_type": intent_type,
                    "response": response_text,
                    "saved_to_db": False,
                },
            }

        now_ts = int(time.time())
        conv_id = 0
        if data.ai_conversation_id:
            # Only append to a conversation the caller actually owns.
            owned = (
                db.query(AIConversation.id)
                .filter(
                    AIConversation.id == data.ai_conversation_id,
                    AIConversation.user_id == user.user_id,
                    AIConversation.is_deleted == 0,
                )
                .first()
            )
            if owned:
                conv_id = data.ai_conversation_id
                db.query(AIConversation).filter(
                    AIConversation.id == conv_id,
                    AIConversation.user_id == user.user_id,
                ).update({"content": data.message[:100], "updated_at": now_ts})
                db.commit()
        if not conv_id:
            conversation = AIConversation(
                user_id=user.user_id,
                content=data.message[:100],
                business_type=0,
                created_at=now_ts,
                updated_at=now_ts,
                is_deleted=0,
            )
            db.add(conversation)
            db.commit()
            db.refresh(conversation)
            conv_id = conversation.id

        _, ai_msg = _persist_message_pair(
            db=db,
            conv_id=conv_id,
            user=user,
            user_content=data.message,
            ai_content=response_text,
        )
        return {
            "statusCode": 200,
            "msg": "success",
            "data": {
                "intent_type": intent_type,
                "response": response_text,
                "ai_conversation_id": conv_id,
                "ai_message_id": ai_msg.id,
                "saved_to_db": True,
            },
        }
    except Exception as e:
        db.rollback()
        logger.error(f"[intent_recognition] 处理异常: {e}")
        return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}
- # ─────────────────────────────────────────────────────────────────────────
- # 获取用户推荐问题(模糊查询 QA / RecommendQuestion 表)
- # ─────────────────────────────────────────────────────────────────────────
# Upper bound for the page size accepted by /get_user_recommend_question.
_RECOMMEND_QUESTION_MAX_LIMIT = 100


@router.get("/get_user_recommend_question")
async def get_user_recommend_question(
    keyword: str = "",
    limit: int = 10,
    db: Session = Depends(get_db),
):
    """Return non-deleted recommended questions, newest first.

    Args:
        keyword: Optional substring filter; ``%`` and ``_`` are matched literally.
        limit: Maximum rows returned, clamped to ``[0, _RECOMMEND_QUESTION_MAX_LIMIT]``.
    """
    try:
        # Clamp so a request cannot pull the whole table or send a negative LIMIT.
        limit = max(0, min(limit, _RECOMMEND_QUESTION_MAX_LIMIT))
        query = db.query(RecommendQuestion).filter(
            RecommendQuestion.is_deleted == 0)
        if keyword:
            # autoescape makes LIKE wildcards in user input match literally.
            query = query.filter(
                RecommendQuestion.question.contains(keyword, autoescape=True))
        questions = query.order_by(
            RecommendQuestion.id.desc()).limit(limit).all()
        return {
            "statusCode": 200,
            "msg": "success",
            "data": [
                {"id": q.id, "question": q.question, "created_at": q.created_at}
                for q in questions
            ],
        }
    except Exception as e:
        logger.error(f"[get_user_recommend_question] 处理异常: {e}")
        return {"statusCode": 500, "msg": f"查询失败: {str(e)}"}
- # ─────────────────────────────────────────────────────────────────────────
- # PPT 大纲 / 文档编辑保存
- # ─────────────────────────────────────────────────────────────────────────
class SavePPTOutlineRequest(BaseModel):
    """Request body for ``/save_ppt_outline``."""

    # AIMessage whose content is replaced.
    ai_message_id: int
    # New outline text, stored verbatim.
    content: str
@router.post("/save_ppt_outline")
async def save_ppt_outline(
    request: Request,
    data: SavePPTOutlineRequest,
    db: Session = Depends(get_db),
):
    """Replace ``AIMessage.content`` of the caller's message with an edited PPT outline.

    Returns 404 when no message with that id belongs to the caller.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}
    try:
        # Filter by owner so one user cannot overwrite another user's messages.
        updated = db.query(AIMessage).filter(
            AIMessage.id == data.ai_message_id,
            AIMessage.user_id == user.user_id,
        ).update({"content": data.content, "updated_at": int(time.time())})
        if not updated:
            return {"statusCode": 404, "msg": "消息不存在"}
        db.commit()
        return {"statusCode": 200, "msg": "保存成功"}
    except Exception as e:
        db.rollback()
        logger.error(f"[save_ppt_outline] 处理异常: {e}")
        return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}
class SaveEditDocumentRequest(BaseModel):
    """Request body for ``/save_edit_document``."""

    # AI-type AIMessage whose content is replaced.
    ai_message_id: int
    # New document text, stored verbatim.
    content: str
@router.post("/save_edit_document")
async def save_edit_document(
    request: Request,
    data: SaveEditDocumentRequest,
    db: Session = Depends(get_db),
):
    """Replace the content of the caller's AI message after an AI-writing edit.

    Returns 404 when no AI-type message with that id belongs to the caller.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}
    try:
        # Filter by owner so one user cannot overwrite another user's messages.
        updated = db.query(AIMessage).filter(
            AIMessage.id == data.ai_message_id,
            AIMessage.user_id == user.user_id,
            AIMessage.type == "ai",
        ).update({"content": data.content, "updated_at": int(time.time())})
        if not updated:
            return {"statusCode": 404, "msg": "消息不存在"}
        db.commit()
        return {"statusCode": 200, "msg": "保存成功"}
    except Exception as e:
        db.rollback()
        logger.error(f"[save_edit_document] 处理异常: {e}")
        return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}
|