from fastapi import APIRouter, Depends, Request
from fastapi.responses import StreamingResponse, JSONResponse
from sqlalchemy.orm import Session
from pydantic import BaseModel
from typing import Optional
from database import get_db, SessionLocal
from models.chat import AIConversation, AIMessage
from models.total import RecommendQuestion
from utils.config import settings
from utils.logger import logger
from services.qwen_service import qwen_service
from services.deepseek_service import deepseek_service
from utils.prompt_loader import load_prompt
from utils.thinking_summary import split_thinking_and_answer, summarize_thinking_content
import time
import json
import httpx
import re
# FastAPI router holding the AI chat endpoints defined in this module
# (send message, history lookup, deletion).
# NOTE(review): presumably mounted by the application via include_router —
# that wiring is not visible here.
router = APIRouter()
def _build_conversation_preview(content: str, limit: int = 50) -> str:
content = (content or "").strip()
if len(content) <= limit:
return content
return content[:limit] + "..."
def _to_frontend_timestamp(timestamp: Optional[int]) -> Optional[int]:
if not timestamp:
return None
return timestamp if timestamp >= 10**12 else timestamp * 1000
def _build_conversation_title(conversation: AIConversation) -> str:
    """Return the title shown in the conversation list.

    Exam-workshop conversations (business_type 3) with a non-blank exam name
    use that name; every other conversation uses a 30-character preview of
    its stored content.
    """
    if conversation.business_type == 3:
        exam_title = (conversation.exam_name or "").strip()
        if exam_title:
            return exam_title
    return _build_conversation_preview(conversation.content or "", limit=30)
def _extract_json_object_from_index(source: str, start_idx: int) -> str:
    """Return the brace-balanced ``{...}`` substring that starts at ``start_idx``.

    Scans forward from ``source[start_idx]`` (which must be ``"{"``), counting
    ``{``/``}`` depth while ignoring braces inside double-quoted strings and
    honouring backslash escapes inside those strings. Returns the slice up to
    and including the brace that brings the depth back to zero.

    Returns ``""`` when ``start_idx`` is out of range, does not point at
    ``"{"``, or the object is never closed before the end of ``source``.
    The result is only brace-balanced; it is not validated as JSON.
    """
    if start_idx < 0 or start_idx >= len(source) or source[start_idx] != "{":
        return ""
    depth = 0
    in_string = False
    escaped = False
    for idx in range(start_idx, len(source)):
        ch = source[idx]
        # The character right after a backslash (inside a string) is literal.
        if escaped:
            escaped = False
            continue
        if in_string:
            if ch == "\\":
                escaped = True
            elif ch == '"':
                in_string = False
            # Braces inside strings never affect depth.
            continue
        if ch == '"':
            in_string = True
            continue
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return source[start_idx: idx + 1]
    return ""
def _extract_balanced_json_objects(text: str) -> list[str]:
    """Return every distinct brace-balanced ``{...}`` substring of ``text``.

    A candidate is extracted starting at each ``{`` (so nested objects appear
    too), in order of start position, with duplicates removed while keeping
    the first occurrence. Scanning from every brace makes this quadratic in
    the worst case.
    """
    source = (text or "").strip()
    if not source:
        return []
    brace_positions = (pos for pos, char in enumerate(source) if char == "{")
    extracted = (_extract_json_object_from_index(source, pos) for pos in brace_positions)
    return list(dict.fromkeys(obj for obj in extracted if obj))
def _extract_trailing_json_candidates(text: str) -> list[str]:
source = (text or "").strip()
if not source:
return []
candidates = []
seen = set()
line_start_indexes = [
match.start()
for match in re.finditer(r"(?m)^[ \t]*\{", source)
]
for start_idx in reversed(line_start_indexes):
candidate = source[start_idx:].strip()
if candidate and candidate not in seen:
candidates.append(candidate)
seen.add(candidate)
return candidates
def _extract_explicit_answer_segment(text: str) -> str:
source = (text or "").strip()
if not source:
return ""
markers = (
"final answer:",
"final output:",
"answer:",
"output:",
"json:",
)
lowered = source.lower()
for marker in markers:
idx = lowered.rfind(marker)
if idx >= 0:
candidate = source[idx + len(marker):].strip()
if candidate:
return candidate
return ""
def _extract_brace_sliced_candidates(text: str) -> list[str]:
source = (text or "").strip()
if not source:
return []
candidates = []
seen = set()
first_brace = source.find("{")
last_brace = source.rfind("}")
if first_brace >= 0 and last_brace > first_brace:
candidate = source[first_brace:last_brace + 1].strip()
if candidate and candidate not in seen:
candidates.append(candidate)
seen.add(candidate)
return candidates
def _looks_like_exam_payload(payload: object) -> bool:
if not isinstance(payload, dict):
return False
questions = payload.get("questions")
return any(
key in payload
for key in (
"singleChoice",
"single_choice",
"单选题",
"judge",
"判断题",
"multiple",
"multiple_choice",
"multipleChoice",
"多选题",
"short",
"short_answer",
"shortAnswer",
"简答题",
)
) or (
isinstance(questions, dict)
and any(
key in questions
for key in (
"singleChoice",
"single_choice",
"单选题",
"judge",
"判断题",
"multiple",
"multiple_choice",
"multipleChoice",
"多选题",
"short",
"short_answer",
"shortAnswer",
"简答题",
)
)
)
def _score_exam_payload_candidate(payload: object) -> int:
if not isinstance(payload, dict):
return 0
score = 0
questions = payload.get("questions") if isinstance(
payload.get("questions"), dict) else {}
strong_keys = (
"singleChoice",
"single_choice",
"单选题",
"judge",
"判断题",
"multiple",
"multiple_choice",
"multipleChoice",
"多选题",
"short",
"short_answer",
"shortAnswer",
"简答题",
)
weak_keys = (
"title",
"exam_name",
"examTitle",
"试卷标题",
"总分",
"totalScore",
"totalQuestions",
)
score += sum(10 for key in strong_keys if key in payload)
score += sum(8 for key in strong_keys if key in questions)
score += sum(2 for key in weak_keys if key in payload)
section_candidates = []
for _, value in payload.items():
if isinstance(value, dict):
section_candidates.append(value)
section_candidates.extend(
value for value in questions.values() if isinstance(value, dict))
for section in section_candidates:
if "questions" in section and isinstance(section.get("questions"), list):
score += 6
question_list = section.get("questions") or []
if question_list and isinstance(question_list[0], dict):
first_question = question_list[0]
if any(k in first_question for k in ("text", "question_text", "question", "title", "content", "题干", "题目")):
score += 4
if "options" in first_question:
score += 3
if any(k in first_question for k in ("answer", "answers", "correct_answer", "correct_answers", "答案", "正确答案")):
score += 3
if any(k in first_question for k in ("analysis", "explanation", "解析")):
score += 2
if any(k in section for k in ("count", "question_count", "数量")):
score += 2
if any(k in section for k in ("scorePerQuestion", "score_per_question", "每题分值")):
score += 1
return score
def _escape_inner_quotes_in_json(text: str) -> str:
    """Heuristically escape stray double quotes inside JSON string values.

    Walks ``text`` tracking whether the cursor is inside a double-quoted
    string. Outside strings every character is copied unchanged. Inside a
    string, a backslash and the character after it are copied unchanged, and
    a ``"`` is treated as the closing quote only when the next non-whitespace
    character is one of ``,``, ``}``, ``]`` or ``:``. Any other ``"``
    (including one at the very end of ``text``) is emitted as ``\\"``.

    This is a repair heuristic for model output, not a parser: an unescaped
    inner quote that happens to be followed by one of those delimiters still
    ends the string early.
    """
    chars = []
    in_string = False
    escaped = False
    for idx, ch in enumerate(text):
        if not in_string:
            chars.append(ch)
            if ch == '"':
                in_string = True
                escaped = False
            continue
        if escaped:
            chars.append(ch)
            escaped = False
            continue
        if ch == "\\":
            chars.append(ch)
            escaped = True
            continue
        if ch == '"':
            # Look ahead past whitespace to decide whether this quote ends the string.
            next_non_space = ""
            for next_idx in range(idx + 1, len(text)):
                if not text[next_idx].isspace():
                    next_non_space = text[next_idx]
                    break
            if next_non_space in {",", "}", "]", ":"}:
                chars.append(ch)
                in_string = False
            else:
                chars.append('\\"')
            continue
        chars.append(ch)
    return "".join(chars)
def _try_parse_exam_json(candidate: str) -> Optional[dict]:
    """Parse ``candidate`` as an exam-paper JSON object, repairing it if needed.

    BOMs and Markdown code fences are removed first. Parsing is then attempted
    on, in order:

    1. the text as-is (so Chinese quotation marks inside string values such as
       ``"佩戴“安全帽”"`` are preserved);
    2. the text with curly quotes “ ” turned into ASCII ``"`` (for output that
       uses them as JSON delimiters);
    3. each of the above after escaping stray inner quotes and dropping
       trailing commas.

    The first attempt that parses decides the result: the parsed value is
    returned when it passes _looks_like_exam_payload, otherwise ``None``.
    Returns ``None`` when nothing parses.
    """
    text = (candidate or "").strip()
    if not text:
        return None
    text = (
        text.replace("\ufeff", "")
        .replace("```json", "")
        .replace("```JSON", "")
        .replace("```", "")
    ).strip()
    straightened = text.replace("“", '"').replace("”", '"')
    variants = [text] if straightened == text else [text, straightened]

    def _attempts():
        # Plain variants first; the repairs are heuristic and only used as a fallback.
        yield from variants
        for variant in variants:
            repaired = _escape_inner_quotes_in_json(variant)
            yield re.sub(r",\s*([}\]])", r"\1", repaired)

    for attempt in _attempts():
        try:
            parsed = json.loads(attempt)
        except (ValueError, RecursionError):
            continue
        return parsed if _looks_like_exam_payload(parsed) else None
    return None
def _sanitize_exam_response(raw_response: str) -> str:
    """Reduce a raw exam-generation reply to exam JSON the frontend can JSON.parse.

    Fast path: the answer part (after split_thinking_and_answer), the text after
    an explicit answer marker, and the whole reply are tried in turn; the first
    that parses as an exam payload is returned re-serialised.

    Slow path: every brace-balanced object, every line-anchored trailing
    suffix and the first-to-last-brace slice are parsed; the candidate with
    the highest (_score_exam_payload_candidate, serialised length) wins (first
    one on ties), with a warning logged when its score is 0.

    When nothing parses, a warning is logged and the stripped raw reply is
    returned unchanged so the frontend can attempt its own parsing.
    Empty input yields ``""``.
    """
    raw_text = (raw_response or "").strip()
    if not raw_text:
        return ""

    _, answer = split_thinking_and_answer(raw_text)
    explicit_answer = _extract_explicit_answer_segment(raw_text)
    for quick_candidate in (answer, explicit_answer, raw_text):
        quick_parsed = _try_parse_exam_json(quick_candidate)
        if quick_parsed:
            return json.dumps(quick_parsed, ensure_ascii=False)

    extractors = (
        _extract_balanced_json_objects,
        _extract_trailing_json_candidates,
        _extract_brace_sliced_candidates,
    )
    scored_pool = []
    for extractor in extractors:
        for snippet in extractor(raw_text):
            parsed_snippet = _try_parse_exam_json(snippet)
            if parsed_snippet:
                scored_pool.append((parsed_snippet, snippet))

    if not scored_pool:
        logger.warning("[exam] 未能从模型响应中提取试卷 JSON,保留原始响应供前端兜底解析")
        return raw_text

    # max() keeps the first of equally-ranked items, like a stable descending sort.
    best_payload, best_raw_candidate = max(
        scored_pool,
        key=lambda entry: (
            _score_exam_payload_candidate(entry[0]),
            len(json.dumps(entry[0], ensure_ascii=False)),
        ),
    )
    best_score = _score_exam_payload_candidate(best_payload)
    if best_score <= 0:
        logger.warning(
            "[exam] 已提取到JSON对象但试卷特征较弱,选择最大候选兜底: score=%s snippet=%s",
            best_score,
            (best_raw_candidate or "")[:200],
        )
    return json.dumps(best_payload, ensure_ascii=False)
def _normalize_related_question(question: str) -> str:
if not isinstance(question, str):
return ""
text = question.strip().strip('"').strip("'")
text = re.sub(r"^[0-9]+[\.\)\]、]\s*", "", text)
text = re.sub(r"^[-*]\s*", "", text)
return text.strip()
def _is_placeholder_related_question(question: str) -> bool:
    """Return True when the normalised question is empty or a template placeholder.

    Placeholders are things like "Q1", "question 2", "问题3", "相关问题1",
    "推荐问题2", "更多问题" (compared case-insensitively).
    """
    candidate = _normalize_related_question(question).lower()
    if not candidate:
        return True
    placeholder_patterns = (
        r"^q\s*\d+$",
        r"^question\s*\d+$",
        r"^questions?\s*\d+$",
        r"^问题\s*\d+$",
        r"^相关问题\s*\d+$",
        r"^推荐问题\s*\d+$",
        r"^更多相关问题$",
        r"^更多问题$",
    )
    for pattern in placeholder_patterns:
        if re.fullmatch(pattern, candidate):
            return True
    return False
def _contains_chinese(text: str) -> bool:
return any("\u4e00" <= char <= "\u9fff" for char in text or "")
def _is_invalid_related_question(question: str) -> bool:
    """Return True when a suggested question should be discarded.

    A question is invalid when, after normalisation, it is empty, shorter
    than 4 characters, a placeholder, contains no Chinese character, or
    contains (case-insensitively) a keyword indicating leaked prompt or
    reasoning text.
    """
    normalized = _normalize_related_question(question)
    if not normalized or len(normalized) < 4:
        return True
    if _is_placeholder_related_question(normalized):
        return True
    if not _contains_chinese(normalized):
        return True
    leaked_markers = (
        "thinking process",
        "analyze the request",
        "role:",
        "**role",
        "professional question recommendation",
        "infrastructure construction technology",
        "output format",
        "json",
        "prompt",
        "system",
        "assistant",
        "角色定义",
        "任务目标",
        "输入内容",
        "生成要求",
        "输出格式",
        "开始生成",
    )
    haystack = normalized.lower()
    for marker in leaked_markers:
        if marker in haystack:
            return True
    return False
def _extract_related_question_topic(content: str) -> str:
if not content:
return "当前话题"
text = re.sub(r"<[^>]+>", " ", str(content))
text = re.sub(r"\s+", " ", text).strip()
text = re.sub(
r"^(好的[!!,, ]*|我理解您提出的问题[,, ]*|这个问题[,, ]*|总的来说[::,, ]*)+",
"",
text,
)
pattern = re.search(
r"(?:主要围绕|围绕|关于|针对|聚焦)([^。!?\n,,;;]{4,32})",
text,
)
if pattern:
topic = pattern.group(1).strip("“”\"' ::,,")
if topic:
return topic
sentence = re.split(r"[。!?\n]", text, maxsplit=1)[0].strip("“”\"' ::,,")
if sentence:
return sentence[:24]
return "当前话题"
def _build_related_question_fallbacks(content: str) -> list[str]:
    """Return three generic follow-up questions built around the topic of ``content``.

    They cover on-site risks, plan/approval/acceptance requirements and
    inspection/monitoring indicators, in that order.
    """
    topic = _extract_related_question_topic(content)
    risk_question = f"{topic}在现场实施时需要重点关注哪些风险点?"
    process_question = f"{topic}相关的方案编制、审批和验收要求有哪些?"
    monitoring_question = f"针对{topic},日常检查和监测应抓住哪些关键指标?"
    return [risk_question, process_question, monitoring_question]
def _finalize_related_questions(questions: list, content: str, limit: int = 3) -> list[str]:
    """Return up to ``limit`` cleaned, de-duplicated related questions.

    Model-suggested ``questions`` are normalised and invalid ones dropped
    (duplicates compared case-insensitively). If fewer than ``limit`` survive,
    generic fallbacks derived from ``content`` are appended.
    """
    accepted: list[str] = []
    seen_keys: set[str] = set()

    def _take(candidate: str) -> bool:
        """Record ``candidate`` and report whether ``limit`` has just been reached."""
        accepted.append(candidate)
        seen_keys.add(candidate.lower())
        return len(accepted) == limit

    for raw in questions or []:
        candidate = _normalize_related_question(raw)
        if _is_invalid_related_question(candidate) or candidate.lower() in seen_keys:
            continue
        if _take(candidate):
            return accepted

    for fallback in _build_related_question_fallbacks(content):
        if fallback.lower() in seen_keys:
            continue
        if _take(fallback):
            break
    return accepted[:limit]
def _refresh_conversation_snapshot(db: Session, conversation_id: int, user_id: int) -> None:
    """Recompute a conversation's preview after its messages changed.

    If the user has no non-deleted message left in the conversation, the
    conversation itself is soft-deleted. Otherwise its ``content`` becomes a
    100-character preview of the latest non-deleted user message (falling
    back to the latest message of any type, and to " " when empty).
    ``updated_at`` is refreshed either way. Does not commit; callers do.
    """
    live_messages = db.query(AIMessage).filter(
        AIMessage.ai_conversation_id == conversation_id,
        AIMessage.user_id == user_id,
        AIMessage.is_deleted == 0,
    )
    newest_first = AIMessage.id.desc()
    conversation_rows = db.query(AIConversation).filter(
        AIConversation.id == conversation_id,
        AIConversation.user_id == user_id,
    )

    latest_message = live_messages.order_by(newest_first).first()
    if not latest_message:
        conversation_rows.update({"is_deleted": 1, "updated_at": int(time.time())})
        return

    latest_user_message = (
        live_messages.filter(AIMessage.type == "user").order_by(newest_first).first()
    )
    if latest_user_message and latest_user_message.content:
        preview_source = latest_user_message.content
    else:
        preview_source = latest_message.content
    preview_content = _build_conversation_preview(preview_source or "", limit=100)
    conversation_rows.update(
        {
            "content": preview_content or " ",
            "updated_at": int(time.time()),
        }
    )
# ─────────────────────────────────────────────────────────────────────────
# 辅助函数
# ─────────────────────────────────────────────────────────────────────────
def _rag_doc_text(doc: object) -> str:
    """Return the text of one search hit, or "" when it carries none.

    Dict hits use their ``content`` or ``text`` field; plain string hits are
    used as-is; anything else is ignored.
    """
    if isinstance(doc, str):
        return doc
    if isinstance(doc, dict):
        value = doc.get("content") or doc.get("text")
        return str(value) if value else ""
    return ""


async def _rag_search(message: str, top_k: int = 5) -> str:
    """Query the configured search API and return retrieved context text.

    Posts ``{"query": message, "n_results": top_k}`` to ``settings.search.api_url``.
    The hits are read from the ``results`` or ``documents`` field of the JSON
    reply. Up to ``top_k`` non-empty hit texts are joined with blank lines.

    Best effort: returns ``""`` when search is not configured, the API answers
    with a non-200 status, or any error occurs. Errors are logged as
    warnings and never raised.
    """
    try:
        search_cfg = getattr(settings, "search", None)
        search_url = getattr(search_cfg, "api_url", None) if search_cfg else None
        if not search_url:
            return ""
        async with httpx.AsyncClient(timeout=10.0) as client:
            resp = await client.post(
                search_url,
                json={"query": message, "n_results": top_k},
            )
            if resp.status_code != 200:
                # Previously dropped silently; surface it so outages are visible.
                logger.warning(f"[RAG] 检索接口返回非200状态码(可忽略): {resp.status_code}")
                return ""
            data = resp.json()
        docs = (data.get("results") or data.get("documents") or []) if isinstance(data, dict) else []
        # String hits used to crash on d.get(...) and were swallowed; accept them.
        texts = (_rag_doc_text(doc) for doc in docs[:top_k])
        return "\n\n".join(text for text in texts if text)
    except Exception as e:
        logger.warning(f"[RAG] 检索失败(可忽略): {e}")
    return ""
# System prompt used by _infer_safety_training_plan: it asks the model to turn a
# free-form safety-training request into a single JSON object with the fields
# listed at the end of the prompt (parsed later by _parse_json_object).
# The prompt text is runtime behavior — edit it only to change model instructions.
SAFETY_TRAINING_PLAN_SYSTEM_PROMPT = """
你是安全培训需求整理助手。请把用户的自然语言输入整理成安全培训PPT大纲生成任务。
规则:
1. 只输出一个 JSON 对象,不要输出 Markdown、解释或额外文字。
2. 即使用户说“通知”“材料”“文档”,也必须理解为安全培训模块中的 PPT 大纲需求,不要切换到其他文档生成任务。
3. 如果字段缺失,请根据安全培训场景合理补全,但不要编造具体制度编号、人员姓名或不存在的事实。
4. template 字段用于选择大纲模板,默认填“标准安全培训PPT大纲”。
5. content_focus 至少给出 3 个要点。
JSON 字段:
{
"topic": "培训主题",
"template": "模板名称",
"content_focus": ["内容要点1", "内容要点2", "内容要点3"],
"audience": "参训对象",
"time": "培训时间",
"location": "培训地点",
"goal": "培训目标",
"notes": "其他要求",
"normalized_request": "归一化后的安全培训PPT大纲生成需求"
}
"""
def _extract_tag_value(message: str, tag: str) -> str:
match = re.search(fr"<{tag}>(.*?){tag}>", message or "", re.DOTALL)
return match.group(1).strip() if match else ""
def _strip_document_tags(message: str) -> str:
text = message or ""
for tag in ("word", "filename", "filesize"):
text = re.sub(fr"<{tag}>.*?{tag}>", " ", text, flags=re.DOTALL)
return re.sub(r"\s+", " ", text).strip()
def _extract_safety_training_request_payload(message: str) -> dict:
    """Split a raw request into uploaded-document fields and the free-text request.

    Returns ``document_content``/``filename``/``filesize`` taken from the
    ``<word>``/``<filename>``/``<filesize>`` tags ("" when absent), plus
    ``request``: the message with those tag blocks removed.
    """
    tag_for_field = {
        "document_content": "word",
        "filename": "filename",
        "filesize": "filesize",
    }
    payload = {field: _extract_tag_value(message, tag) for field, tag in tag_for_field.items()}
    payload["request"] = _strip_document_tags(message)
    return payload
def _clean_safety_training_topic(message: str) -> str:
    """Derive a short training topic from the free-text part of ``message``.

    Takes the first clause of the request, removes filler/instruction words
    (请, 帮我, 生成, PPT大纲, 通知, ...), whitespace and edge punctuation, and
    makes sure the result mentions 培训 (appending 安全培训 otherwise).
    Defaults to "安全培训".
    """
    request_text = _extract_safety_training_request_payload(message)["request"]
    head_clause = re.split(r"[,。;;,\n]", request_text, maxsplit=1)[0].strip()
    topic = head_clause or request_text or "安全培训"
    noise_tokens = ("请", "帮我", "帮忙", "生成", "制作", "输出", "一份", "一个", "一下", "PPT大纲", "ppt大纲", "大纲", "通知", "文档", "材料")
    for token in noise_tokens:
        topic = topic.replace(token, "")
    topic = re.sub(r"\s+", "", topic).strip(" ::,,。;;") or "安全培训"
    return topic if "培训" in topic else f"{topic}安全培训"
def _parse_json_object(text: str) -> dict:
if not text:
return {}
cleaned = re.sub(r"```(?:json)?\s*", "", str(text)
).replace("```", "").strip()
match = re.search(r"\{.*\}", cleaned, re.DOTALL)
if not match:
return {}
try:
parsed = json.loads(match.group(0))
return parsed if isinstance(parsed, dict) else {}
except json.JSONDecodeError:
return {}
def _build_fallback_safety_training_plan(message: str) -> dict:
    """Return a default training plan used when the model's plan is missing or partial.

    The topic comes from _clean_safety_training_topic; ``notes`` carries the
    free-text request; time and location are left empty.
    """
    request_text = _extract_safety_training_request_payload(message)["request"]
    topic = _clean_safety_training_topic(message)
    return dict(
        topic=topic,
        template="标准安全培训PPT大纲",
        content_focus=["安全生产责任", "现场风险识别", "安全意识提升", "培训纪律与行为规范"],
        audience="参训员工",
        time="",
        location="",
        goal="提升参训人员安全意识和施工现场风险防控能力",
        notes=request_text,
        normalized_request=f"围绕{topic}生成安全培训PPT大纲",
    )
def _normalize_safety_training_plan(message: str, raw_plan: dict) -> dict:
    """Merge a model-produced plan over the fallback plan for ``message``.

    Starts from _build_fallback_safety_training_plan(message). Each scalar
    field is overridden by the matching non-blank string in ``raw_plan``
    (stripped). ``content_focus`` is taken from ``raw_plan`` when it is a list
    or a delimiter-separated string, but only if at least one non-blank item
    remains; otherwise the fallback focus list is kept. Finally the topic is
    made to mention 培训 and the template to mention PPT大纲.
    Non-dict ``raw_plan`` returns the fallback plan unchanged.
    """
    plan = _build_fallback_safety_training_plan(message)
    if not isinstance(raw_plan, dict):
        return plan
    for key in ("topic", "template", "audience", "time", "location", "goal", "notes", "normalized_request"):
        value = raw_plan.get(key)
        if isinstance(value, str) and value.strip():
            plan[key] = value.strip()
    focus = raw_plan.get("content_focus")
    if isinstance(focus, list):
        focus_items = [str(item).strip() for item in focus]
    elif isinstance(focus, str):
        focus_items = [item.strip() for item in re.split(r"[、,,;\n]", focus)]
    else:
        focus_items = []
    # Guard both shapes the same way: a separator-only string such as "、,"
    # used to replace the fallback list with [].
    focus_items = [item for item in focus_items if item]
    if focus_items:
        plan["content_focus"] = focus_items
    if "培训" not in plan["topic"]:
        plan["topic"] = f"{plan['topic']}安全培训"
    if "PPT大纲" not in plan["template"]:
        plan["template"] = f"{plan['template']}PPT大纲"
    return plan
def _build_safety_training_generation_message(message: str, plan: dict) -> str:
    """Render the structured plan (and any uploaded document) as the outline-generation request.

    Produces newline-separated "label:value" lines with defaults for missing
    plan fields, followed by the uploaded document's name, size and content
    when the original message carried a document.
    """
    payload = _extract_safety_training_request_payload(message)
    focus_text = "、".join(plan.get("content_focus") or [])
    header_lines = [
        "输出类型:安全培训PPT大纲",
        "请基于以下结构化需求生成安全培训PPT大纲,不要生成通知正文,不要切换到其他文档生成任务。",
    ]
    plan_lines = [
        f"主题:{plan.get('topic') or '安全培训'}",
        f"模板:{plan.get('template') or '标准安全培训PPT大纲'}",
        f"内容要点:{focus_text or '安全生产责任、风险识别、应急处置、安全意识提升'}",
        f"参训对象:{plan.get('audience') or '参训员工'}",
        f"培训时间:{plan.get('time') or '未指定'}",
        f"培训地点:{plan.get('location') or '未指定'}",
        f"培训目标:{plan.get('goal') or '提升参训人员安全意识和风险防控能力'}",
        f"其他要求:{plan.get('notes') or '无'}",
        f"归一化需求:{plan.get('normalized_request') or ''}",
        f"原始需求:{payload['request'] or message}",
    ]
    document_lines = []
    if payload["filename"] or payload["document_content"]:
        document_lines = [
            f"上传文档名称:{payload['filename'] or '未命名文档'}",
            f"上传文档大小:{payload['filesize'] or '未知'}",
            "上传文档内容:",
            payload["document_content"] or "无",
        ]
    return "\n".join(header_lines + plan_lines + document_lines)
async def _infer_safety_training_plan(message: str) -> dict:
    """Ask qwen to structure the training request; fall back to a default plan on failure.

    The planning input is the free-text request plus, when a document was
    uploaded, its name and first 3000 characters. Any exception (model call,
    parsing) is logged as a warning and the fallback plan is returned.
    """
    payload = _extract_safety_training_request_payload(message)
    planning_input = payload["request"] or message
    if payload["document_content"]:
        planning_input = (
            f"{planning_input}\n\n"
            f"上传文档名称:{payload['filename'] or '未命名文档'}\n"
            f"上传文档内容摘要:{payload['document_content'][:3000]}"
        )
    planning_messages = [
        {"role": "system", "content": SAFETY_TRAINING_PLAN_SYSTEM_PROMPT},
        {"role": "user", "content": planning_input},
    ]
    try:
        reply = await qwen_service.chat(planning_messages)
        structured = _parse_json_object(reply)
        return _normalize_safety_training_plan(message, structured)
    except Exception as e:
        logger.warning(
            f"[safety_training] 需求整理失败,使用兜底结构: {type(e).__name__}: {e}")
        return _build_fallback_safety_training_plan(message)
def _clean_ai_writing_response(content: str) -> str:
text = str(content or "").strip()
if not text:
return ""
text = re.sub(r"```(?:html)?\s*", "", text,
flags=re.IGNORECASE).replace("```", "").strip()
body_match = re.search(
r"
]*>(.*?)", text, re.IGNORECASE | re.DOTALL)
if body_match:
text = body_match.group(1).strip()
first_content_tag = re.search(
r"<(?:article|section|main|div|h[1-6]|p|table|ul|ol)\b",
text,
re.IGNORECASE,
)
if first_content_tag and text[:first_content_tag.start()].strip():
text = text[first_content_tag.start():]
cleanup_patterns = (
r"]*>",
r"]*>",
r"",
r"]*>.*?",
r"]*>",
r"",
r"",
r"",
r"]*>",
r"]*>.*?",
)
for pattern in cleanup_patterns:
text = re.sub(pattern, "", text, flags=re.IGNORECASE | re.DOTALL)
return text.strip()
async def _generate_ai_writing_response(message: str) -> str:
    """Generate an official-document HTML fragment for ``message`` with deepseek.

    Retrieves up to 10 knowledge-base snippets, renders the "document_writing"
    prompt, and returns the cleaned HTML answer. Any thinking section is
    dropped so plain text cannot break the HTML structure.
    """
    knowledge = await _rag_search(message, top_k=10)
    writing_prompt = load_prompt(
        "document_writing",
        userMessage=message,
        contextJSON=knowledge if knowledge else "暂无相关知识库内容",
    )
    instruction = (
        "请根据上面的写作规范和我的原始需求,直接生成可放入富文本编辑器的公文正文 HTML 片段。"
        "不要输出道歉、解释、DOCTYPE、html、head、body、style 或 script 标签。\n\n"
        f"原始需求:\n{message}"
    )
    raw_response = await deepseek_service.chat([
        {"role": "system", "content": writing_prompt},
        {"role": "user", "content": instruction},
    ])
    _, answer = split_thinking_and_answer(raw_response)
    return _clean_ai_writing_response(answer or raw_response)
async def _generate_ppt_outline_response(message: str) -> str:
    """Generate a safety-training PPT outline for ``message`` with qwen.

    The request is first structured into a training plan, enriched with up to
    10 knowledge-base snippets and rendered through the "ppt_outline" prompt.
    When the model returns a thinking section that can be summarised, the
    reply is "思考过程:...\\n\\n回答:..."; otherwise it is just the answer.
    """
    plan = await _infer_safety_training_plan(message)
    structured_request = _build_safety_training_generation_message(message, plan)
    knowledge = await _rag_search(structured_request, top_k=10)
    outline_prompt = load_prompt(
        "ppt_outline",
        userMessage=structured_request,
        contextJSON=knowledge if knowledge else "暂无相关知识库内容",
    )
    raw_response = await qwen_service.chat([
        {"role": "system", "content": outline_prompt},
        {"role": "user", "content": "请直接输出安全培训PPT大纲正文,从标题开始,不要解释提示词或安全规则。"},
    ])
    thinking, answer = split_thinking_and_answer(raw_response)
    answer_text = answer or raw_response
    if not thinking:
        return answer_text
    thinking_summary = await summarize_thinking_content(
        user_question=message,
        raw_thinking=thinking,
        final_answer=answer_text,
        chat_service=qwen_service,
        context="ppt_outline",
    )
    if not thinking_summary:
        return answer_text
    return f"思考过程:\n{thinking_summary}\n\n回答:\n{answer_text}"
def _persist_message_pair(db: Session, conv_id: int, user, user_content: str, ai_content: str):
    """Store a user message and the AI reply to it in one transaction.

    The AI message links back to the user message via ``prev_user_id``.
    Both rows share one timestamp and are committed together; on any error
    the transaction is rolled back and the exception re-raised, so a failure
    can no longer leave an orphaned user message behind.

    Returns ``(user_message, ai_message)``, refreshed after the commit.
    """
    now_ts = int(time.time())
    user_message = AIMessage(
        ai_conversation_id=conv_id,
        user_id=user.user_id,
        type="user",
        content=user_content,
        created_at=now_ts,
        updated_at=now_ts,
        is_deleted=0,
    )
    try:
        db.add(user_message)
        # flush (not commit): assigns user_message.id for prev_user_id while
        # keeping both inserts in the same transaction.
        db.flush()
        ai_message = AIMessage(
            ai_conversation_id=conv_id,
            user_id=user.user_id,
            type="ai",
            content=ai_content,
            prev_user_id=user_message.id,
            created_at=now_ts,
            updated_at=now_ts,
            is_deleted=0,
        )
        db.add(ai_message)
        db.commit()
    except Exception:
        db.rollback()
        raise
    db.refresh(user_message)
    db.refresh(ai_message)
    return user_message, ai_message
def _build_history_messages(conv_id: int, limit: int = 10, user_id: Optional[int] = None) -> list:
    """Return the latest conversation turns as chat ``messages`` (oldest first).

    Reads up to ``limit`` most recent non-deleted messages of ``conv_id``
    using its own short-lived session, maps ``type == "user"`` to role "user"
    and everything else to "assistant", and skips empty contents.

    ``user_id`` (optional, new) additionally restricts the messages to that
    user, matching how every other query in this module scopes by owner.
    Pass it whenever ``conv_id`` comes from a client request.
    """
    db = SessionLocal()
    try:
        query = db.query(AIMessage).filter(
            AIMessage.ai_conversation_id == conv_id,
            AIMessage.is_deleted == 0,
        )
        if user_id is not None:
            query = query.filter(AIMessage.user_id == user_id)
        newest_first = query.order_by(AIMessage.id.desc()).limit(limit).all()
        return [
            {"role": "user" if m.type == "user" else "assistant", "content": m.content}
            for m in reversed(newest_first)
            if m.content
        ]
    finally:
        db.close()
# ─────────────────────────────────────────────────────────────────────────
# 非流式接口
# ─────────────────────────────────────────────────────────────────────────
class SendMessageRequest(BaseModel):
    """Request body for ``/send_deepseek_message``."""

    message: str
    # Either id may name an existing conversation; send_deepseek_message uses
    # ``conversation_id or ai_conversation_id`` and creates a new one when both are empty.
    conversation_id: Optional[int] = None
    ai_conversation_id: Optional[int] = None
    business_type: int = 0  # 0=AI assistant, 1=PPT outline, 2=AI writing, 3=exam workshop
    # Stored on the conversation only for business_type 3.
    exam_name: str = ""
    # NOTE(review): not read by send_deepseek_message in the visible code — confirm still needed.
    ai_message_id: int = 0
def _format_error_detail(exc: Exception) -> str:
    """Return ``str(exc)`` stripped, or a placeholder naming the type when that is blank."""
    detail = str(exc).strip()
    return detail if detail else f"未知错误({type(exc).__name__})"


def _unwrap_natural_language_answer(answer_source):
    """Return the ``natural_language_answer`` field when the model replied with a JSON object.

    ``answer_source`` is returned unchanged when it is not a string starting
    with ``{``, is not valid JSON, is not a JSON object, or lacks a non-empty
    string ``natural_language_answer``.
    """
    if not isinstance(answer_source, str) or not answer_source.strip().startswith("{"):
        return answer_source
    try:
        response_json = json.loads(answer_source)
    except ValueError:
        return answer_source
    if not isinstance(response_json, dict):
        return answer_source
    natural = response_json.get("natural_language_answer")
    return natural if isinstance(natural, str) and natural.strip() else answer_source


async def _generate_assistant_response(message: str) -> str:
    """business_type 0: intent recognition, optional RAG, then a final qwen answer."""
    intent_result = await qwen_service.intent_recognition(message)
    intent_type = ""
    if isinstance(intent_result, dict):
        intent_type = (
            intent_result.get("intent_type") or intent_result.get("intent") or ""
        ).lower()
    rag_context = ""
    if intent_type in ("query_knowledge_base", "知识库查询", "技术咨询"):
        rag_context = await _rag_search(message, top_k=10)
    system_content = load_prompt(
        "final_answer",
        userMessage=message,
        contextJSON=rag_context if rag_context else "暂无相关知识库内容",
    )
    # The rendered prompt is sent as the user turn, as in the original implementation.
    qwen_response = await qwen_service.chat([{"role": "user", "content": system_content}])
    raw_thinking, raw_answer = split_thinking_and_answer(qwen_response)
    answer_text = _unwrap_natural_language_answer(raw_answer or qwen_response)
    if not raw_thinking:
        return answer_text
    thinking_summary = await summarize_thinking_content(
        user_question=message,
        raw_thinking=raw_thinking,
        final_answer=answer_text,
        chat_service=qwen_service,
        context="send_message",
    )
    if not thinking_summary:
        return answer_text
    return f"思考过程:\n{thinking_summary}\n\n回答:\n{answer_text}"


async def _generate_exam_response(message: str) -> str:
    """business_type 3: generate exam questions with qwen and sanitise them to JSON."""
    system_content = (
        "你是一个专业的考试题目生成助手,专注于路桥隧轨施工安全领域。\n"
        "请根据用户需求生成专业的考试题目,包括单选题、多选题、判断题、简答题等。\n"
        "用户消息中已经包含考试标题、题型要求和出题依据内容,必须以其中的出题依据内容为核心生成题目,不能脱离依据内容自由发挥。\n"
        "题干、选项、答案和解析都要与出题依据内容中的知识点、专业术语、操作流程、规范要求或培训主题直接相关。\n"
        "输出必须是可直接 JSON.parse 的纯 JSON,不要包含 markdown 代码块、解释文字或额外前后缀。\n"
        "JSON 顶层结构必须包含 singleChoice、judge、multiple、short 四个字段。\n"
        "singleChoice.questions 和 multiple.questions 中每道题必须包含 text、options、answer、analysis。\n"
        "options 必须是数组,元素格式为 {\"key\":\"A\",\"text\":\"具体选项内容\"},禁止输出“选项A”这类占位文本。\n"
        "judge.questions 中每道题必须包含 text、answer、analysis。\n"
        "short.questions 中每道题必须包含 text、outline,其中 outline 至少包含 keyFactors。\n"
        "所有题目内容、选项内容、答案和解析都要结合用户给出的工程类型、题型数量、分值和课件内容具体生成。"
    )
    raw_response_text = await qwen_service.chat([
        {"role": "system", "content": system_content},
        {"role": "user", "content": message},
    ])
    return _sanitize_exam_response(raw_response_text)


def _extract_exam_title(response_text: str) -> str:
    """Return the first non-blank title field of an exam JSON reply, or ""."""
    try:
        exam_data = json.loads(response_text)
    except (TypeError, ValueError):
        return ""
    if not isinstance(exam_data, dict):
        return ""
    for key in ("title", "exam_name", "examTitle", "试卷标题", "标题"):
        value = exam_data.get(key)
        if isinstance(value, str) and value.strip():
            return value
    return ""


def _store_exam_title(db: Session, conv_id: int, user_id: int, exam_name: str, response_text: str) -> None:
    """Save the exam title (explicit name first, else taken from the reply) on the conversation."""
    title = exam_name or _extract_exam_title(response_text)
    if not title:
        return
    db.query(AIConversation).filter(
        AIConversation.id == conv_id,
        AIConversation.user_id == user_id,
    ).update({"exam_name": title, "updated_at": int(time.time())})
    db.commit()


@router.post("/send_deepseek_message")
async def send_deepseek_message(
    request: Request,
    data: SendMessageRequest,
    db: Session = Depends(get_db),
):
    """Send one message and return the model reply (non-streaming).

    business_type selects the pipeline:
    - 0: AI assistant (intent recognition + RAG); the reply is not persisted here
    - 1: PPT outline generation
    - 2: AI writing
    - 3: exam workshop (question generation, title stored on the conversation)

    A new conversation is created when no conversation id is given; a given id
    must belong to the caller. Status is reported in the body: 401 (no user),
    400 (empty message / unsupported type), 404 (foreign or unknown
    conversation), 500 (unexpected error). A failure inside a pipeline is
    returned with statusCode 200 and "处理失败: ..." as the reply.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}
    try:
        message = data.message.strip()
        if not message:
            return {"statusCode": 400, "msg": "消息不能为空"}
        # Validate before touching the database so an unsupported type no
        # longer leaves a freshly created, empty conversation behind.
        if data.business_type not in (0, 1, 2, 3):
            return {"statusCode": 400, "msg": f"不支持的业务类型: {data.business_type}"}

        conversation_id = data.conversation_id or data.ai_conversation_id
        exam_name = data.exam_name if data.business_type == 3 else ""
        now_ts = int(time.time())
        if not conversation_id:
            conversation = AIConversation(
                user_id=user.user_id,
                content=message[:100],
                business_type=data.business_type,
                exam_name=exam_name,
                created_at=now_ts,
                updated_at=now_ts,
                is_deleted=0,
            )
            db.add(conversation)
            db.commit()
            db.refresh(conversation)
            conv_id = conversation.id
        else:
            # The id is client-supplied: refuse to write messages into a
            # conversation owned by someone else.
            owned = (
                db.query(AIConversation.id)
                .filter(
                    AIConversation.id == conversation_id,
                    AIConversation.user_id == user.user_id,
                )
                .first()
            )
            if not owned:
                return {"statusCode": 404, "msg": "对话不存在"}
            conv_id = conversation_id
            db.query(AIConversation).filter(
                AIConversation.id == conv_id,
                AIConversation.user_id == user.user_id,
                AIConversation.is_deleted == 0,
            ).update({
                "content": message[:100],
                "business_type": data.business_type,
                "exam_name": exam_name,
                "updated_at": now_ts,
            })
            db.commit()

        response_text = ""
        ai_message_id = 0
        if data.business_type == 0:
            # NOTE(review): assistant replies are not persisted by this endpoint,
            # unlike business types 1-3 — confirm that is intentional.
            try:
                response_text = await _generate_assistant_response(message)
            except Exception as e:
                error_detail = _format_error_detail(e)
                logger.error(
                    f"[send_deepseek_message] AI助手异常: {type(e).__name__}: {error_detail}")
                response_text = f"处理失败: {error_detail}"
        else:
            generate, label = {
                1: (_generate_ppt_outline_response, "PPT大纲生成"),
                2: (_generate_ai_writing_response, "AI写作"),
                3: (_generate_exam_response, "考试工坊"),
            }[data.business_type]
            try:
                response_text = await generate(message)
                _, ai_message = _persist_message_pair(
                    db=db,
                    conv_id=conv_id,
                    user=user,
                    user_content=message,
                    ai_content=response_text,
                )
                ai_message_id = ai_message.id
                _refresh_conversation_snapshot(db, conv_id, user.user_id)
                db.commit()
                if data.business_type == 3:
                    _store_exam_title(db, conv_id, user.user_id, data.exam_name, response_text)
            except Exception as e:
                # Leave the session usable after a failed flush/commit.
                db.rollback()
                error_detail = _format_error_detail(e)
                logger.error(
                    f"[send_deepseek_message] {label}异常: {type(e).__name__}: {error_detail}")
                response_text = f"处理失败: {error_detail}"

        return {
            "statusCode": 200,
            "msg": "success",
            "data": {
                "conversation_id": conv_id,
                "ai_conversation_id": conv_id,
                "response": response_text,
                "reply": response_text,
                "content": response_text,
                "message": response_text,
                "ai_message_id": ai_message_id,
                "user_id": user.user_id,
                "business_type": data.business_type,
            },
        }
    except Exception as e:
        db.rollback()
        logger.error(f"[send_deepseek_message] 异常: {e}")
        return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}
@router.get("/get_history_record")
async def get_history_record(
    request: Request,
    ai_conversation_id: int = 0,
    business_type: Optional[int] = None,
    db: Session = Depends(get_db),
):
    """Frontend-compatible history lookup.

    With ``ai_conversation_id > 0``: every non-deleted message of that
    conversation owned by the caller, oldest first. Otherwise: the caller's
    non-deleted conversations (optionally filtered by ``business_type``),
    most recently updated first and capped at 50, while ``total`` counts
    all matches. Timestamps are returned in milliseconds.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}

    def message_to_dict(row):
        return {
            "id": row.id,
            "ai_conversation_id": row.ai_conversation_id,
            "user_id": row.user_id,
            "type": row.type,
            "content": row.content,
            "user_feedback": row.user_feedback,
            "prev_user_id": row.prev_user_id,
            "search_source": row.search_source or "",
            "guess_you_want": row.guess_you_want or "",
            "created_at": _to_frontend_timestamp(row.created_at),
            "updated_at": _to_frontend_timestamp(row.updated_at),
        }

    def conversation_to_dict(row):
        return {
            "id": row.id,
            "title": _build_conversation_title(row),
            "content": row.content or "",
            "business_type": row.business_type,
            "exam_name": row.exam_name or "",
            "created_at": _to_frontend_timestamp(row.created_at),
            "updated_at": _to_frontend_timestamp(row.updated_at),
        }

    if ai_conversation_id > 0:
        message_rows = (
            db.query(AIMessage)
            .filter(
                AIMessage.ai_conversation_id == ai_conversation_id,
                AIMessage.user_id == user.user_id,
                AIMessage.is_deleted == 0,
            )
            .order_by(AIMessage.id.asc())
            .all()
        )
        return {
            "statusCode": 200,
            "msg": "success",
            "total": len(message_rows),
            "data": [message_to_dict(row) for row in message_rows],
        }

    listing = db.query(AIConversation).filter(
        AIConversation.user_id == user.user_id,
        AIConversation.is_deleted == 0,
    )
    if business_type is not None:
        listing = listing.filter(AIConversation.business_type == business_type)
    matching_count = listing.count()
    recent_rows = (
        listing
        .order_by(AIConversation.updated_at.desc(), AIConversation.id.desc())
        .limit(50)
        .all()
    )
    return {
        "statusCode": 200,
        "msg": "success",
        "total": matching_count,
        "data": [conversation_to_dict(row) for row in recent_rows],
    }
class DeleteConversationRequest(BaseModel):
    """Request body for ``/delete_conversation``.

    When ``ai_message_id`` is non-zero, only that AI reply (and the user
    message it answered) is deleted; otherwise the whole conversation
    ``ai_conversation_id`` is deleted.
    """

    ai_conversation_id: int = 0
    ai_message_id: int = 0
@router.post("/delete_conversation")
async def delete_conversation(
    request: Request, data: DeleteConversationRequest, db: Session = Depends(get_db)
):
    """Soft-delete a single AI reply or a whole conversation.

    With ``ai_message_id``: the caller's non-deleted AI message and the user
    message it answered (``prev_user_id``) are soft-deleted, then the
    conversation preview is refreshed (which soft-deletes the conversation if
    nothing is left); 404 if the message is not found. Otherwise
    ``ai_conversation_id`` and all of its messages are soft-deleted; 400 when
    neither id is given.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}
    now_ts = int(time.time())

    if not data.ai_message_id:
        if not data.ai_conversation_id:
            return {"statusCode": 400, "msg": "缺少删除参数"}
        target_conversation = data.ai_conversation_id
        db.query(AIConversation).filter(
            AIConversation.id == target_conversation,
            AIConversation.user_id == user.user_id,
        ).update({"is_deleted": 1, "updated_at": now_ts})
        db.query(AIMessage).filter(
            AIMessage.ai_conversation_id == target_conversation,
            AIMessage.user_id == user.user_id,
        ).update({"is_deleted": 1, "updated_at": now_ts})
        db.commit()
        return {"statusCode": 200, "msg": "删除成功"}

    reply = (
        db.query(AIMessage)
        .filter(
            AIMessage.id == data.ai_message_id,
            AIMessage.user_id == user.user_id,
            AIMessage.type == "ai",
            AIMessage.is_deleted == 0,
        )
        .first()
    )
    if not reply:
        return {"statusCode": 404, "msg": "消息不存在"}

    db.query(AIMessage).filter(
        AIMessage.id == reply.id,
        AIMessage.user_id == user.user_id,
    ).update({"is_deleted": 1, "updated_at": now_ts})
    if reply.prev_user_id:
        db.query(AIMessage).filter(
            AIMessage.id == reply.prev_user_id,
            AIMessage.user_id == user.user_id,
            AIMessage.ai_conversation_id == reply.ai_conversation_id,
        ).update({"is_deleted": 1, "updated_at": now_ts})
    _refresh_conversation_snapshot(db, reply.ai_conversation_id, user.user_id)
    db.commit()
    return {"statusCode": 200, "msg": "删除成功"}
class DeleteHistoryRequest(BaseModel):
    # Body of POST /delete_history_record.
    # Conversation to soft-delete; only the caller's own conversation is affected.
    ai_conversation_id: int
@router.post("/delete_history_record")
async def delete_history_record(
    request: Request, data: DeleteHistoryRequest, db: Session = Depends(get_db)
):
    """Soft-delete one of the caller's conversations (a history entry).

    The conversation's messages are soft-deleted in the same commit, as
    ``/delete_conversation`` does. Flagging only the conversation row left its
    messages readable through ``/get_history_record?ai_conversation_id=<id>``,
    because that query filters on each message's own ``is_deleted`` flag.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}
    deleted_at = int(time.time())
    db.query(AIConversation).filter(
        AIConversation.id == data.ai_conversation_id,
        AIConversation.user_id == user.user_id,
    ).update({"is_deleted": 1, "updated_at": deleted_at})
    db.query(AIMessage).filter(
        AIMessage.ai_conversation_id == data.ai_conversation_id,
        AIMessage.user_id == user.user_id,
        # Keep the original timestamp on messages that were already deleted.
        AIMessage.is_deleted == 0,
    ).update({"is_deleted": 1, "updated_at": deleted_at})
    db.commit()
    return {"statusCode": 200, "msg": "删除成功"}
# ─────────────────────────────────────────────────────────────────────────
# 流式接口 /stream/chat(无 DB,意图识别 + RAG)
# ─────────────────────────────────────────────────────────────────────────
class StreamChatRequest(BaseModel):
    # Body of POST /stream/chat.
    message: str  # user question; stripped and required to be non-empty
    model: str = ""  # accepted but not read by /stream/chat
@router.post("/stream/chat")
async def stream_chat(request: Request, data: StreamChatRequest):
    """Streaming chat over SSE that never touches the database.

    Flow: intent recognition -> RAG lookup (knowledge-base intents only) ->
    ``final_answer`` prompt -> streamed model output.

    If the model emits a ``<think>...</think>`` block, its raw text is never
    forwarded. Up to ``settings.thinking_summary.max_input_chars`` characters
    of it are summarised instead, and the summary is sent once as a
    "思考过程" prefix before the answer. Every content event is
    ``data: {"content": ...}``; failures are sent as ``data: {"error": ...}``.
    The stream always ends with ``data: [DONE]`` unless the client
    disconnects first. ``data.model`` is accepted but not used.
    """
    message = data.message.strip()
    if not message:
        return JSONResponse(content={"statusCode": 400, "msg": "消息不能为空"})

    # Reasoning delimiters. Searching for "" instead (find() == 0 on every
    # chunk) made the first chunk look like an empty think block.
    think_open, think_close = "<think>", "</think>"

    def _sse(content: str) -> str:
        # One JSON-encoded SSE content frame.
        return f"data: {json.dumps({'content': content}, ensure_ascii=False)}\n\n"

    def _find_tag(text: str, tag: str) -> int:
        # Case-insensitive search whose index is valid for `text` itself
        # (text.lower() can change the length of some non-ASCII strings).
        found = re.search(re.escape(tag), text, re.IGNORECASE)
        return found.start() if found else -1

    def _partial_tag_len(text: str, tag: str) -> int:
        # Length of the longest suffix of `text` that is a proper prefix of
        # `tag`, i.e. a delimiter possibly split across two stream chunks.
        for size in range(min(len(tag) - 1, len(text)), 0, -1):
            if text[-size:].lower() == tag[:size]:
                return size
        return 0

    async def event_generator():
        intent_type = ""
        try:
            intent_result = await qwen_service.intent_recognition(message)
            if isinstance(intent_result, dict):
                intent_type = (
                    intent_result.get("intent_type")
                    or intent_result.get("intent")
                    or ""
                ).lower()
        except Exception as ie:
            # Best effort: without an intent the RAG lookup is simply skipped.
            logger.warning(f"[stream/chat] 意图识别异常: {ie}")

        try:
            rag_context = ""
            if intent_type in ("query_knowledge_base", "知识库查询", "技术咨询"):
                rag_context = await _rag_search(message)
            system_content = load_prompt(
                "final_answer",
                userMessage=message,
                contextJSON=rag_context if rag_context else "暂无相关知识库内容",
            )
            messages = [{"role": "user", "content": system_content}]

            summary_cfg = getattr(settings, "thinking_summary", None)
            max_input_chars = getattr(summary_cfg, "max_input_chars", 1500)

            def _clip_thinking(current: str, part: str) -> str:
                # Only the first max_input_chars of reasoning reach the
                # summariser; a falsy limit collects nothing.
                if max_input_chars and len(current) < max_input_chars:
                    return current + part[: max_input_chars - len(current)]
                return current

            buffer = ""        # received text not yet classified
            pre_answer = ""    # answer text seen before <think>; sent after the summary
            thinking_buf = ""  # clipped raw reasoning for the summariser
            in_think = False
            thinking_done = False

            async for chunk in qwen_service.stream_chat(messages):
                buffer += chunk
                while buffer:
                    if thinking_done:
                        yield _sse(buffer)
                        buffer = ""
                        break
                    if not in_think:
                        start_idx = _find_tag(buffer, think_open)
                        if start_idx == -1:
                            # Hold back a possibly split "<think>" prefix.
                            keep = _partial_tag_len(buffer, think_open)
                            ready = buffer[: len(buffer) - keep]
                            if ready:
                                yield _sse(ready)
                            buffer = buffer[len(buffer) - keep:]
                            break
                        pre_answer += buffer[:start_idx]
                        buffer = buffer[start_idx + len(think_open):]
                        in_think = True
                        continue
                    end_idx = _find_tag(buffer, think_close)
                    if end_idx == -1:
                        keep = _partial_tag_len(buffer, think_close)
                        thinking_buf = _clip_thinking(
                            thinking_buf, buffer[: len(buffer) - keep])
                        buffer = buffer[len(buffer) - keep:]
                        break
                    thinking_buf = _clip_thinking(thinking_buf, buffer[:end_idx])
                    buffer = buffer[end_idx + len(think_close):]
                    in_think = False
                    thinking_done = True
                    thinking_summary = await summarize_thinking_content(
                        user_question=message,
                        raw_thinking=thinking_buf,
                        final_answer="",
                        chat_service=qwen_service,
                        context="stream_chat",
                    )
                    if thinking_summary:
                        yield _sse(f"思考过程:\n{thinking_summary}\n\n回答:\n")
                    answer_chunk = (pre_answer + buffer).lstrip()
                    if answer_chunk:
                        yield _sse(answer_chunk)
                    pre_answer = ""
                    buffer = ""
                    break

            if in_think and not thinking_done:
                # Stream ended inside an unterminated <think> block: summarise
                # what was collected (raw reasoning is never forwarded).
                thinking_buf = _clip_thinking(thinking_buf, buffer)
                buffer = ""
                if thinking_buf:
                    thinking_summary = await summarize_thinking_content(
                        user_question=message,
                        raw_thinking=thinking_buf,
                        final_answer="",
                        chat_service=qwen_service,
                        context="stream_chat_eof",
                    )
                    if thinking_summary:
                        yield _sse(f"思考过程:\n{thinking_summary}\n\n回答:\n")
            # Flush pending answer text: text that preceded an unterminated
            # <think>, or a held-back tail that turned out not to be a tag.
            leftover = pre_answer + buffer
            if leftover:
                yield _sse(leftover)
        except Exception as e:
            logger.error(f"[stream/chat] 流式输出异常: {e}")
            yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
        # Deliberately not in `finally`: yielding while the generator is being
        # closed (client disconnect) raises RuntimeError.
        yield "data: [DONE]\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")
# ─────────────────────────────────────────────────────────────────────────
# 流式接口 /stream/chat-with-db(前端主聊天接口)
# ─────────────────────────────────────────────────────────────────────────
class StreamChatWithDBRequest(BaseModel):
    # Body of POST /stream/chat-with-db (main frontend chat endpoint).
    message: str  # user question; stripped and required to be non-empty
    ai_conversation_id: int = 0  # 0 starts a new conversation
    # 1 = PPT outline prompt, 2 = document-writing prompt, otherwise the
    # general final_answer prompt with chat history.
    business_type: int = 0
    # Exam title; used as the conversation title when business_type == 3.
    exam_name: str = ""
    ai_message_id: int = 0  # accepted but not read by /stream/chat-with-db
    # Pre-fetched web-search text appended to the prompt context (general
    # prompt only).
    online_search_content: str = ""
@router.post("/stream/chat-with-db")
async def stream_chat_with_db(request: Request, data: StreamChatWithDBRequest):
    """Main frontend chat endpoint: SSE streaming with persistence.

    Flow:
      1. reuse the caller's conversation ``ai_conversation_id``, or create one
         when it is 0, missing, deleted or owned by someone else;
      2. insert the user message and an empty AI placeholder message;
      3. send an ``initial`` event with both ids;
      4. RAG search;
      5. build the prompt (dedicated prompt for business_type 1/2, otherwise
         ``final_answer`` with the two previous turns as history);
      6. stream the model output, replacing any ``<think>`` block with a
         summary when ``settings.thinking_summary.enabled`` (default True);
      7. store the full reply in the placeholder and refresh the conversation.

    Content events are plain text ``data: <text>`` with newlines escaped as a
    literal backslash-n (not JSON); ``initial`` and ``error`` events are JSON.
    ``data: [DONE]`` ends a successful stream.
    """
    user = request.state.user
    if not user:
        return JSONResponse(content={"statusCode": 401, "msg": "未授权"})
    message = data.message.strip()
    if not message:
        return JSONResponse(content={"statusCode": 400, "msg": "消息不能为空"})

    # Reasoning delimiters. Searching for "" instead (find() == 0 on every
    # chunk) made the first chunk look like an empty think block.
    think_open, think_close = "<think>", "</think>"

    def _sse_text(text: str) -> str:
        # Plain-text frame; newlines are escaped so it stays one "data:" line.
        return "data: " + text.replace("\n", "\\n") + "\n\n"

    def _find_tag(text: str, tag: str) -> int:
        # Case-insensitive search whose index is valid for `text` itself
        # (text.lower() can change the length of some non-ASCII strings).
        found = re.search(re.escape(tag), text, re.IGNORECASE)
        return found.start() if found else -1

    def _partial_tag_len(text: str, tag: str) -> int:
        # Length of the longest suffix of `text` that is a proper prefix of
        # `tag`, i.e. a delimiter possibly split across two stream chunks.
        for size in range(min(len(tag) - 1, len(text)), 0, -1):
            if text[-size:].lower() == tag[:size]:
                return size
        return 0

    async def event_generator():
        db = SessionLocal()
        try:
            # 1. Conversation.
            conv_id = _resolve_chat_with_db_conversation(
                db, user.user_id, data, message)

            # 2. User message.
            created_at = int(time.time())
            user_msg = AIMessage(
                ai_conversation_id=conv_id,
                user_id=user.user_id,
                type="user",
                content=message,
                created_at=created_at,
                updated_at=created_at,
                is_deleted=0,
            )
            db.add(user_msg)
            db.commit()
            db.refresh(user_msg)
            user_msg_id = user_msg.id

            # 3. Empty AI placeholder, filled in once streaming finishes.
            created_at = int(time.time())
            ai_msg = AIMessage(
                ai_conversation_id=conv_id,
                user_id=user.user_id,
                type="ai",
                content="",
                prev_user_id=user_msg_id,
                created_at=created_at,
                updated_at=created_at,
                is_deleted=0,
            )
            db.add(ai_msg)
            db.commit()
            db.refresh(ai_msg)
            ai_msg_id = ai_msg.id

            # 4. Tell the client which ids this turn got.
            yield f"data: {json.dumps({'type': 'initial', 'ai_conversation_id': conv_id, 'ai_message_id': ai_msg_id}, ensure_ascii=False)}\n\n"

            # 5. Retrieval and prompt.
            rag_context = await _rag_search(message, top_k=10)
            messages = _build_chat_with_db_prompt_messages(
                db, data, message, conv_id, user_msg_id, rag_context)

            # 6. Stream and collect the full reply. Settings are read before
            # the inner try so the EOF handling below can always use them.
            summary_cfg = getattr(settings, "thinking_summary", None)
            summary_enabled = getattr(summary_cfg, "enabled", True)
            max_input_chars = getattr(summary_cfg, "max_input_chars", 1500)

            def _clip_thinking(current: str, part: str) -> str:
                # Only the first max_input_chars of reasoning reach the
                # summariser; a falsy limit collects nothing.
                if max_input_chars and len(current) < max_input_chars:
                    return current + part[: max_input_chars - len(current)]
                return current

            full_response = ""
            buffer = ""        # received text not yet classified
            pre_answer = ""    # answer text seen before <think>; sent after the summary
            thinking_buf = ""  # clipped raw reasoning for the summariser
            in_think = False
            thinking_done = False
            try:
                async for chunk in qwen_service.stream_chat(messages):
                    if not summary_enabled:
                        full_response += chunk
                        yield _sse_text(chunk)
                        continue
                    buffer += chunk
                    while buffer:
                        if thinking_done:
                            full_response += buffer
                            yield _sse_text(buffer)
                            buffer = ""
                            break
                        if not in_think:
                            start_idx = _find_tag(buffer, think_open)
                            if start_idx == -1:
                                # Hold back a possibly split "<think>" prefix.
                                keep = _partial_tag_len(buffer, think_open)
                                ready = buffer[: len(buffer) - keep]
                                if ready:
                                    full_response += ready
                                    yield _sse_text(ready)
                                buffer = buffer[len(buffer) - keep:]
                                break
                            pre_answer += buffer[:start_idx]
                            buffer = buffer[start_idx + len(think_open):]
                            in_think = True
                            continue
                        end_idx = _find_tag(buffer, think_close)
                        if end_idx == -1:
                            keep = _partial_tag_len(buffer, think_close)
                            thinking_buf = _clip_thinking(
                                thinking_buf, buffer[: len(buffer) - keep])
                            buffer = buffer[len(buffer) - keep:]
                            break
                        thinking_buf = _clip_thinking(
                            thinking_buf, buffer[:end_idx])
                        buffer = buffer[end_idx + len(think_close):]
                        in_think = False
                        thinking_done = True
                        thinking_summary = await summarize_thinking_content(
                            user_question=message,
                            raw_thinking=thinking_buf,
                            final_answer="",
                            chat_service=qwen_service,
                            context="stream_chat_with_db",
                        )
                        if thinking_summary:
                            prefix = f"思考过程:\n{thinking_summary}\n\n回答:\n"
                            full_response += prefix
                            yield _sse_text(prefix)
                        answer_chunk = (pre_answer + buffer).lstrip()
                        if answer_chunk:
                            full_response += answer_chunk
                            yield _sse_text(answer_chunk)
                        pre_answer = ""
                        buffer = ""
                        break
            except Exception as e:
                logger.error(f"[stream/chat-with-db] 流式输出异常: {e}")
                yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"

            if in_think and not thinking_done:
                # Stream ended inside an unterminated <think> block: summarise
                # what was collected (raw reasoning is never forwarded).
                thinking_buf = _clip_thinking(thinking_buf, buffer)
                buffer = ""
                if thinking_buf:
                    thinking_summary = await summarize_thinking_content(
                        user_question=message,
                        raw_thinking=thinking_buf,
                        final_answer="",
                        chat_service=qwen_service,
                        context="stream_chat_with_db_eof",
                    )
                    if thinking_summary:
                        prefix = f"思考过程:\n{thinking_summary}\n\n回答:\n"
                        full_response += prefix
                        yield _sse_text(prefix)
            # Flush pending answer text: text that preceded an unterminated
            # <think>, or a held-back tail that turned out not to be a tag.
            leftover = pre_answer + buffer
            if leftover:
                full_response += leftover
                yield _sse_text(leftover)

            # 7. Persist the reply and refresh the conversation preview.
            if full_response:
                now_ts = int(time.time())
                db.query(AIMessage).filter(AIMessage.id == ai_msg_id).update(
                    {"content": full_response, "updated_at": now_ts}
                )
                db.query(AIConversation).filter(
                    AIConversation.id == conv_id,
                    AIConversation.user_id == user.user_id,
                ).update(
                    {
                        "content": _build_conversation_preview(message, limit=100),
                        "business_type": data.business_type,
                        "exam_name": data.exam_name if data.business_type == 3 else "",
                        "updated_at": now_ts,
                    }
                )
                db.commit()
            # End marker (successful turns only).
            yield "data: [DONE]\n\n"
        except Exception as e:
            logger.error(f"[stream/chat-with-db] 处理异常: {e}")
            yield f"data: {json.dumps({'error': str(e)}, ensure_ascii=False)}\n\n"
        finally:
            # close() also rolls back any uncommitted work.
            db.close()

    return StreamingResponse(event_generator(), media_type="text/event-stream")


def _resolve_chat_with_db_conversation(
    db: Session, user_id: int, data: StreamChatWithDBRequest, message: str
) -> int:
    """Return (and commit) the conversation id for a /stream/chat-with-db turn.

    ``data.ai_conversation_id`` is reused only when it names a non-deleted
    conversation owned by ``user_id``; its preview, business_type, exam_name
    and updated_at are refreshed. Otherwise a new conversation is created.
    ``exam_name`` is kept only for business_type 3 on every path; a brand-new
    conversation used to keep it for any business_type.
    """
    now_ts = int(time.time())
    preview = _build_conversation_preview(message, limit=100)
    exam_name = data.exam_name if data.business_type == 3 else ""
    if data.ai_conversation_id:
        existing = (
            db.query(AIConversation)
            .filter(
                AIConversation.id == data.ai_conversation_id,
                AIConversation.user_id == user_id,
                AIConversation.is_deleted == 0,
            )
            .first()
        )
        if existing:
            conv_id = existing.id
            db.query(AIConversation).filter(
                AIConversation.id == conv_id,
                AIConversation.user_id == user_id,
            ).update(
                {
                    "content": preview,
                    "business_type": data.business_type,
                    "exam_name": exam_name,
                    "updated_at": now_ts,
                }
            )
            db.commit()
            return conv_id
    conversation = AIConversation(
        user_id=user_id,
        content=preview,
        business_type=data.business_type,
        exam_name=exam_name,
        created_at=now_ts,
        updated_at=now_ts,
        is_deleted=0,
    )
    db.add(conversation)
    db.commit()
    db.refresh(conversation)
    return conversation.id


def _build_chat_with_db_prompt_messages(
    db: Session,
    data: StreamChatWithDBRequest,
    message: str,
    conv_id: int,
    user_msg_id: int,
    rag_context: str,
) -> list[dict]:
    """Build the single-message LLM input for a /stream/chat-with-db turn.

    business_type 1/2 use the ``ppt_outline`` / ``document_writing`` prompts
    with RAG context only. Other types use ``final_answer`` with RAG context,
    the optional online-search text, and the last 4 earlier messages.
    """
    # NOTE(review): these literals replace text that was mis-encoded as "???"
    # in the source. "暂无相关知识库内容" matches /stream/chat; the role and
    # section labels are reconstructed, so confirm them against the prompt.
    no_context = "暂无相关知识库内容"
    if data.business_type in (1, 2):
        prompt_name = "ppt_outline" if data.business_type == 1 else "document_writing"
        system_content = load_prompt(
            prompt_name,
            userMessage=message,
            contextJSON=rag_context if rag_context else no_context,
        )
        return [{"role": "user", "content": system_content}]

    # History: the last 4 messages (2 turns) *before* the current question, in
    # chronological (id) order. updated_at is unsuitable because later edits
    # bump it, and the current question is already passed as userMessage.
    recent = (
        db.query(AIMessage)
        .filter(
            AIMessage.ai_conversation_id == conv_id,
            AIMessage.id < user_msg_id,
            AIMessage.is_deleted == 0,
        )
        .order_by(AIMessage.id.desc())
        .limit(4)
        .all()
    )
    history_context = "".join(
        f"{'用户' if msg.type == 'user' else '助手'}: {msg.content}\n\n"
        for msg in reversed(recent)
    )

    context_parts = []
    if rag_context:
        context_parts.append(f"知识库内容:\n{rag_context}")
    if data.online_search_content:
        context_parts.append(f"联网搜索结果:\n{data.online_search_content}")
    context_json = "\n\n".join(context_parts) if context_parts else no_context
    system_content = load_prompt(
        "final_answer",
        userMessage=message,
        contextJSON=context_json,
        historyContext=history_context,
    )
    return [{"role": "user", "content": system_content}]
# ─────────────────────────────────────────────────────────────────────────
# 猜你想问
# ─────────────────────────────────────────────────────────────────────────
class GuessYouWantRequest(BaseModel):
    # Body of POST /guess_you_want.
    # AI message whose content seeds the follow-up questions.
    ai_message_id: int
@router.post("/guess_you_want")
async def guess_you_want(
    request: Request,
    data: GuessYouWantRequest,
    db: Session = Depends(get_db),
):
    """Generate 3 "guess you want to ask" questions for one of the caller's AI messages.

    The message must belong to the caller; without this check any user could
    read another user's message through the prompt and overwrite its
    ``guess_you_want``. The model reply is parsed as a JSON
    ``{"questions": [...]}`` object when possible, otherwise line by line.
    Fixed fallback questions are used when nothing usable comes back. The
    result goes through ``_finalize_related_questions`` (limit 3) and is
    stored as JSON in ``AIMessage.guess_you_want``.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}
    try:
        ai_msg = (
            db.query(AIMessage)
            .filter(
                AIMessage.id == data.ai_message_id,
                AIMessage.user_id == user.user_id,
                AIMessage.is_deleted == 0,
            )
            .first()
        )
        if not ai_msg:
            return {"statusCode": 404, "msg": "消息不存在"}
        source_content = ai_msg.content or ""
        system_content = load_prompt(
            "guess_questions",
            currentContent=source_content[:500],
        )
        response = await qwen_service.chat(
            [{"role": "user", "content": system_content}])
        response = response or ""
        try:
            # Prefer a flat {"questions": [...]} object embedded in the reply.
            json_match = re.search(
                r'\{[^{}]*"questions"[^{}]*\}', response, re.DOTALL)
            response_json = json.loads(
                json_match.group() if json_match else response)
            questions = response_json.get("questions", [])
        except Exception:
            # Not JSON: take numbered/bulleted lines longer than 5 chars.
            questions = []
            for line in response.split("\n"):
                clean = line.strip().lstrip("0123456789.-、 ").strip()
                if clean and len(clean) > 5:
                    questions.append(clean)
        if not isinstance(questions, list):
            questions = []
        if not questions:
            questions = ["该话题的具体应用场景?", "有哪些注意事项?", "相关案例分析?"]
        questions = _finalize_related_questions(
            questions, ai_msg.content, limit=3)
        guess_json = json.dumps({"questions": questions}, ensure_ascii=False)
        db.query(AIMessage).filter(
            AIMessage.id == data.ai_message_id,
            AIMessage.user_id == user.user_id,
        ).update({"guess_you_want": guess_json, "updated_at": int(time.time())})
        db.commit()
        return {
            "statusCode": 200,
            "msg": "success",
            "data": {"ai_message_id": data.ai_message_id, "questions": questions},
        }
    except Exception as e:
        db.rollback()
        logger.error(f"[guess_you_want] 处理异常: {e}")
        return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}
# ─────────────────────────────────────────────────────────────────────────
# 在线搜索(Dify 工作流集成)
# ─────────────────────────────────────────────────────────────────────────
@router.get("/online_search")
async def online_search(question: str, request: Request, db: Session = Depends(get_db)):
    """Online search through a Dify workflow.

    Flow: check the Dify config -> Qwen extracts keywords from ``question`` ->
    blocking Dify workflow run -> return the workflow's ``outputs.text``.
    ``db`` is unused but kept so the signature stays unchanged.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}
    try:
        # Validate configuration first so a misconfigured deployment does not
        # spend an LLM call on keyword extraction.
        dify_config = getattr(settings, "dify", None)
        if not dify_config or not getattr(dify_config, "workflow_url", None):
            return {"statusCode": 500, "msg": "Dify 配置未设置"}
        keywords = await qwen_service.extract_keywords(question)
        headers = {
            "Authorization": f"Bearer {dify_config.auth_token}",
            "Content-Type": "application/json",
        }
        payload = {
            "workflow_id": dify_config.workflow_id,
            "inputs": {
                "keywords": keywords,
                "num": 5,  # number of search results
                "max_text_len": 4000,  # maximum text length
            },
            "response_mode": "blocking",
            # Fall back to the user id when `account` is missing or empty.
            "user": getattr(user, "account", None) or str(user.user_id),
        }
        async with httpx.AsyncClient(timeout=30.0) as client:
            resp = await client.post(dify_config.workflow_url, headers=headers, json=payload)
            if resp.status_code != 200:
                logger.error(
                    f"[online_search] Dify 调用失败: {resp.status_code}, 响应: {resp.text}")
                return {"statusCode": 500, "msg": f"搜索服务异常: {resp.status_code}"}
            result = resp.json()
        # Tolerate `data` / `outputs` keys that are present but null.
        outputs = (result.get("data") or {}).get("outputs") or {}
        search_text = outputs.get("text") or ""
        return {
            "statusCode": 200,
            "msg": "success",
            "data": {"keywords": keywords, "result": search_text},
        }
    except Exception as e:
        logger.error(f"[online_search] 处理异常: {e}")
        return {"statusCode": 500, "msg": f"搜索失败: {str(e)}"}
class SaveOnlineSearchResultRequest(BaseModel):
    # Body of POST /save_online_search_result.
    ai_message_id: int  # message the search result is attached to
    search_result: str  # stored verbatim in AIMessage.search_source
@router.post("/save_online_search_result")
async def save_online_search_result(
    request: Request,
    data: SaveOnlineSearchResultRequest,
    db: Session = Depends(get_db),
):
    """Store a web-search result in ``AIMessage.search_source``.

    Only the caller's own message is updated; without the ``user_id`` filter
    any authenticated user could overwrite another user's message by id.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}
    try:
        db.query(AIMessage).filter(
            AIMessage.id == data.ai_message_id,
            AIMessage.user_id == user.user_id,
        ).update({"search_source": data.search_result, "updated_at": int(time.time())})
        db.commit()
        return {"statusCode": 200, "msg": "保存成功"}
    except Exception as e:
        db.rollback()
        logger.error(f"[save_online_search_result] 处理异常: {e}")
        return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}
# ─────────────────────────────────────────────────────────────────────────
# 意图识别独立接口
# ─────────────────────────────────────────────────────────────────────────
class IntentRecognitionRequest(BaseModel):
    # Body of POST /intent_recognition.
    message: str  # text to classify
    # Persist greeting/FAQ turns as a user message plus an AI message.
    save_to_db: bool = False
    ai_conversation_id: int = 0  # target conversation when saving; 0 creates one
    # "module_dispatch" (case-insensitive) switches to module routing instead
    # of intent recognition.
    scene: str = "default"
@router.post("/intent_recognition")
async def intent_recognition(
    request: Request,
    data: IntentRecognitionRequest,
    db: Session = Depends(get_db),
):
    """Standalone intent recognition endpoint.

    - ``scene == "module_dispatch"`` (case-insensitive): run module routing
      and return route_mode / business_type / confidence / reason.
    - Otherwise run intent recognition. When ``save_to_db`` is set and the
      intent is a greeting/FAQ, the user message and the model's ``response``
      are stored as a user/AI message pair.

    When saving, ``ai_conversation_id`` is reused only if it names a
    non-deleted conversation owned by the caller; otherwise a new one is
    created. Using the id unchecked let a caller write into other users'
    conversations.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}
    try:
        if (data.scene or "").strip().lower() == "module_dispatch":
            dispatch_result = await qwen_service.module_dispatch_recognition(data.message)
            return {
                "statusCode": 200,
                "msg": "success",
                "data": {
                    "route_mode": dispatch_result.get("route_mode", "ai-qa"),
                    "business_type": dispatch_result.get("business_type", 0),
                    "confidence": dispatch_result.get("confidence", 0.5),
                    "reason": dispatch_result.get("reason", ""),
                    "saved_to_db": False,
                },
            }
        intent_result = await qwen_service.intent_recognition(data.message)
        intent_type = ""
        response_text = ""
        if isinstance(intent_result, dict):
            intent_type = (
                intent_result.get("intent_type") or intent_result.get("intent") or ""
            ).lower()
            response_text = intent_result.get("response", "")
        if data.save_to_db and intent_type in ("greeting", "问候", "faq", "常见问题"):
            now_ts = int(time.time())
            conv_id = 0
            if data.ai_conversation_id:
                owned = (
                    db.query(AIConversation)
                    .filter(
                        AIConversation.id == data.ai_conversation_id,
                        AIConversation.user_id == user.user_id,
                        AIConversation.is_deleted == 0,
                    )
                    .first()
                )
                if owned:
                    conv_id = owned.id
            if not conv_id:
                conversation = AIConversation(
                    user_id=user.user_id,
                    content=data.message[:100],
                    business_type=0,
                    created_at=now_ts,
                    updated_at=now_ts,
                    is_deleted=0,
                )
                db.add(conversation)
                db.commit()
                db.refresh(conversation)
                conv_id = conversation.id
            user_msg = AIMessage(
                ai_conversation_id=conv_id,
                user_id=user.user_id,
                type="user",
                content=data.message,
                created_at=now_ts,
                updated_at=now_ts,
                is_deleted=0,
            )
            db.add(user_msg)
            db.commit()
            db.refresh(user_msg)
            ai_msg = AIMessage(
                ai_conversation_id=conv_id,
                user_id=user.user_id,
                type="ai",
                content=response_text,
                prev_user_id=user_msg.id,
                created_at=now_ts,
                updated_at=now_ts,
                is_deleted=0,
            )
            db.add(ai_msg)
            db.commit()
            db.refresh(ai_msg)
            return {
                "statusCode": 200,
                "msg": "success",
                "data": {
                    "intent_type": intent_type,
                    "response": response_text,
                    "ai_conversation_id": conv_id,
                    "ai_message_id": ai_msg.id,
                    "saved_to_db": True,
                },
            }
        return {
            "statusCode": 200,
            "msg": "success",
            "data": {
                "intent_type": intent_type,
                "response": response_text,
                "saved_to_db": False,
            },
        }
    except Exception as e:
        db.rollback()
        logger.error(f"[intent_recognition] 处理异常: {e}")
        return {"statusCode": 500, "msg": f"处理失败: {str(e)}"}
# ─────────────────────────────────────────────────────────────────────────
# Recommended questions (keyword fuzzy search over the RecommendQuestion table)
# ─────────────────────────────────────────────────────────────────────────
@router.get("/get_user_recommend_question")
async def get_user_recommend_question(
    keyword: str = "",
    limit: int = 10,
    db: Session = Depends(get_db),
):
    """List recommended questions, newest first, optionally filtered by keyword.

    ``keyword`` (whitespace-stripped) is matched as a literal substring: LIKE
    wildcards ``%`` / ``_`` in it are escaped via ``autoescape``. ``limit`` is
    clamped to 1..100 so a client cannot request an unbounded or negative
    page. This endpoint does not check ``request.state.user``.
    """
    max_limit = 100
    try:
        query = db.query(RecommendQuestion).filter(
            RecommendQuestion.is_deleted == 0)
        keyword = (keyword or "").strip()
        if keyword:
            query = query.filter(
                RecommendQuestion.question.contains(keyword, autoescape=True))
        page_size = min(max(limit, 1), max_limit)
        questions = query.order_by(
            RecommendQuestion.id.desc()).limit(page_size).all()
        return {
            "statusCode": 200,
            "msg": "success",
            "data": [
                {"id": q.id, "question": q.question, "created_at": q.created_at}
                for q in questions
            ],
        }
    except Exception as e:
        logger.error(f"[get_user_recommend_question] 处理异常: {e}")
        return {"statusCode": 500, "msg": f"查询失败: {str(e)}"}
# ─────────────────────────────────────────────────────────────────────────
# PPT 大纲 / 文档编辑保存
# ─────────────────────────────────────────────────────────────────────────
class SavePPTOutlineRequest(BaseModel):
    # Body of POST /save_ppt_outline.
    ai_message_id: int  # message whose content is replaced
    content: str  # edited outline, stored as-is in AIMessage.content
@router.post("/save_ppt_outline")
async def save_ppt_outline(
    request: Request,
    data: SavePPTOutlineRequest,
    db: Session = Depends(get_db),
):
    """Save an edited PPT outline into the caller's ``AIMessage.content``.

    Only the caller's own message is updated; without the ``user_id`` filter
    any authenticated user could overwrite another user's message by id.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}
    try:
        db.query(AIMessage).filter(
            AIMessage.id == data.ai_message_id,
            AIMessage.user_id == user.user_id,
        ).update({"content": data.content, "updated_at": int(time.time())})
        db.commit()
        return {"statusCode": 200, "msg": "保存成功"}
    except Exception as e:
        db.rollback()
        logger.error(f"[save_ppt_outline] 处理异常: {e}")
        return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}
class SaveEditDocumentRequest(BaseModel):
    # Body of POST /save_edit_document.
    ai_message_id: int  # AI message whose content is replaced
    content: str  # edited document, stored as-is in AIMessage.content
@router.post("/save_edit_document")
async def save_edit_document(
    request: Request,
    data: SaveEditDocumentRequest,
    db: Session = Depends(get_db),
):
    """Save an edited AI-writing document into the caller's ``ai`` message.

    Only ``type == "ai"`` messages owned by the caller are updated; without
    the ``user_id`` filter any authenticated user could overwrite another
    user's message by id.
    """
    user = request.state.user
    if not user:
        return {"statusCode": 401, "msg": "未授权"}
    try:
        db.query(AIMessage).filter(
            AIMessage.id == data.ai_message_id,
            AIMessage.user_id == user.user_id,
            AIMessage.type == "ai",
        ).update({"content": data.content, "updated_at": int(time.time())})
        db.commit()
        return {"statusCode": 200, "msg": "保存成功"}
    except Exception as e:
        db.rollback()
        logger.error(f"[save_edit_document] 处理异常: {e}")
        return {"statusCode": 500, "msg": f"保存失败: {str(e)}"}