rag_debug_api.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593
"""
RAG pipeline debug API endpoints.

Exposes an independent debug endpoint for each stage of the RAG retrieval
pipeline, supporting both single-step execution and chained execution.
Ported from utils_test/RAG_Test/rag_pipeline_web/rag_pipeline_server.py

Endpoints:
- POST /debug/rag/step — debug a single stage
- POST /debug/rag/chain — chained execution
- POST /debug/rag/pipeline — full RAG pipeline
- POST /debug/rag/native — native RAG
- POST /debug/rag/parent-child — parent/child document mode
- POST /debug/rag/professional-review — professional (compliance) review
- POST /debug/rag/init — initialize Milvus
- GET /debug/rag/data — fetch the latest pipeline data

NOTE(review): register_routes registers these paths as "/rag/..."; the
"/debug" prefix presumably comes from the router passed in — confirm.
"""
  15. import asyncio
  16. import json
  17. import logging
  18. import os
  19. import time
  20. from typing import Any, Dict, List, Optional
  21. from fastapi import APIRouter, HTTPException
  22. from pydantic import BaseModel, Field
  23. from core.construction_review.component.ai_review_engine import AIReviewEngine
  24. from core.construction_review.component.infrastructure.milvus import MilvusConfig, MilvusManager
  25. from core.construction_review.component.infrastructure.parent_tool import (
  26. enhance_with_parent_docs_grouped,
  27. extract_query_pairs_results,
  28. )
  29. from core.base.task_models import TaskFileInfo
  30. from foundation.ai.rag.retrieval.entities_enhance import entity_enhance
  31. from foundation.ai.rag.retrieval.query_rewrite import query_rewrite_manager
  32. from foundation.ai.rag.retrieval.retrieval import retrieval_manager
  33. from foundation.observability.monitoring.rag import rag_monitor
  34. logger = logging.getLogger(__name__)
  35. project_root = os.path.join(os.path.dirname(__file__), "..", "..")
  36. # ============ 懒加载全局实例 ============
  37. _milvus_manager: Optional[MilvusManager] = None
  38. _ai_review_engine: Optional[AIReviewEngine] = None
  39. def _get_milvus():
  40. global _milvus_manager
  41. if _milvus_manager is None:
  42. _milvus_manager = MilvusManager(MilvusConfig())
  43. return _milvus_manager
  44. def _get_ai_review_engine():
  45. global _ai_review_engine
  46. if _ai_review_engine is None:
  47. file_info_dict = {
  48. 'file_id': "test_file_id",
  49. 'callback_task_id': "test_task_id",
  50. 'user_id': "test_user",
  51. 'file_name': "test.docx",
  52. 'file_type': 'docx',
  53. 'file_content': b'',
  54. 'review_config': [],
  55. 'review_item_config': {}
  56. }
  57. _ai_review_engine = AIReviewEngine(TaskFileInfo(file_info_dict))
  58. return _ai_review_engine
# ============ Pydantic models ============
# Request body for POST /rag/step: run a single RAG stage in isolation.
class DebugStepRequest(BaseModel):
    # Stage to run; must be one of the names listed in the description.
    step: str = Field(..., description="环节名称: query_extract, entity_enhance, multi_stage_recall, hybrid_search, parent_doc_enhance, extract_results")
    # Raw text, or a JSON document for stages that consume structured input
    # (query_pairs / bfp_result_lists).
    content: str = Field(default="", description="输入文本或 JSON 数据")
    # Stage-specific overrides such as collection_name, top_k or score_threshold.
    params: Optional[Dict[str, Any]] = Field(default_factory=dict, description="额外参数")
# Request body for POST /rag/chain: runs query_extract -> entity_enhance ->
# parent_doc_enhance -> extract_results in sequence.
class DebugChainRequest(BaseModel):
    # Text to run through the chain; the endpoint rejects an empty value with 400.
    content: str = Field(..., description="输入文本")
    # Optional overrides read by the chain (e.g. score_threshold, max_parents_per_pair).
    params: Optional[Dict[str, Any]] = Field(default_factory=dict, description="额外参数")
# Shared request body for POST /rag/pipeline, /rag/native and /rag/parent-child.
# Each endpoint reads only the fields it needs; /rag/pipeline uses only `content`.
class RagRequest(BaseModel):
    # Query text; the endpoints reject an empty value with 400.
    content: str = Field(..., description="输入文本")
    # Milvus collection searched by /rag/native.
    collection_name: Optional[str] = Field(default="rag_children_hybrid")
    # Passed to retrieval_manager.multi_stage_recall as hybrid_top_k.
    hybrid_top_k: Optional[int] = Field(default=20)
    # Passed to retrieval_manager.multi_stage_recall as top_k.
    top_k: Optional[int] = Field(default=5)
    # Score threshold for parent-document enhancement (/rag/parent-child).
    parent_score_threshold: Optional[float] = Field(default=0.3)
    # Maximum parent documents attached per query pair (/rag/parent-child).
    max_parents: Optional[int] = Field(default=3)
# Request body for POST /rag/professional-review (RAG retrieval + AI review).
class ProfessionalReviewRequest(BaseModel):
    # Text under review; it is also the input to query extraction.
    content: str = Field(..., description="待审查内容")
    # Which compliance checks to run: "non_parameter", "parameter" or "both".
    check_type: Optional[str] = Field(default="both", description="non_parameter, parameter, both")
  77. # ============ 工具函数 ============
  78. def _serialize_results(results):
  79. if not results:
  80. return []
  81. out = []
  82. for item in results:
  83. if not isinstance(item, dict):
  84. out.append(str(item))
  85. continue
  86. d = {}
  87. for k, v in item.items():
  88. if k == 'metadata' and isinstance(v, dict):
  89. d[k] = {mk: str(mv) for mk, mv in v.items()}
  90. elif isinstance(v, (str, int, float, bool, list, type(None))):
  91. d[k] = v
  92. else:
  93. d[k] = str(v)
  94. out.append(d)
  95. return out
  96. def _serialize_parent_docs(parent_docs):
  97. out = []
  98. for p in parent_docs:
  99. d = {}
  100. for k, v in p.items():
  101. if k == 'metadata' and isinstance(v, dict):
  102. d[k] = {mk: str(mv) for mk, mv in v.items()}
  103. elif k == 'text_content' and isinstance(v, str):
  104. d[k] = v[:500] + '...' if len(v) > 500 else v
  105. elif isinstance(v, (str, int, float, bool, list, type(None))):
  106. d[k] = v
  107. else:
  108. d[k] = str(v)
  109. out.append(d)
  110. return out
  111. def _parse_json_param(content: str, params: dict, key: str):
  112. if key in params and params[key] is not None:
  113. return params[key]
  114. if content:
  115. try:
  116. parsed = json.loads(content)
  117. if isinstance(parsed, list):
  118. return parsed
  119. if isinstance(parsed, dict) and key in parsed:
  120. return parsed[key]
  121. except (json.JSONDecodeError, TypeError):
  122. pass
  123. return None
  124. # ============ 核心逻辑(同步) ============
  125. def _debug_step_sync(step_name: str, content: str, params: dict) -> dict:
  126. start_time = time.time()
  127. try:
  128. if step_name == 'query_extract':
  129. input_summary = {"content_length": len(content)}
  130. output = query_rewrite_manager.query_extract(content)
  131. return {"status": "success", "step": step_name, "input_summary": input_summary,
  132. "output": output, "execution_time": round(time.time() - start_time, 3)}
  133. elif step_name == 'entity_enhance':
  134. query_pairs = _parse_json_param(content, params, 'query_pairs')
  135. if query_pairs is None:
  136. return {"status": "error", "step": step_name, "error": "请提供 query_pairs(在 params 或 content 中传入 JSON)"}
  137. input_summary = {"query_pairs_count": len(query_pairs)}
  138. output = entity_enhance.entities_enhance_retrieval(query_pairs)
  139. return {"status": "success", "step": step_name, "input_summary": input_summary,
  140. "output": output, "execution_time": round(time.time() - start_time, 3)}
  141. elif step_name == 'multi_stage_recall':
  142. collection_name = params.get('collection_name', 'rag_children_hybrid')
  143. hybrid_top_k = params.get('hybrid_top_k', 50)
  144. top_k = params.get('top_k', 10)
  145. input_summary = {"content_length": len(content), "collection_name": collection_name,
  146. "hybrid_top_k": hybrid_top_k, "top_k": top_k}
  147. output = retrieval_manager.multi_stage_recall(
  148. collection_name=collection_name, query_text=content,
  149. hybrid_top_k=hybrid_top_k, top_k=top_k)
  150. return {"status": "success", "step": step_name, "input_summary": input_summary,
  151. "output": _serialize_results(output), "execution_time": round(time.time() - start_time, 3)}
  152. elif step_name == 'hybrid_search':
  153. collection_name = params.get('collection_name', 'rag_children_hybrid')
  154. top_k = params.get('top_k', 10)
  155. dense_weight = params.get('dense_weight', 0.7)
  156. sparse_weight = params.get('sparse_weight', 0.3)
  157. input_summary = {"content_length": len(content), "collection_name": collection_name,
  158. "top_k": top_k, "dense_weight": dense_weight, "sparse_weight": sparse_weight}
  159. output = retrieval_manager.hybrid_search_recall(
  160. collection_name=collection_name, query_text=content,
  161. top_k=top_k, dense_weight=dense_weight, sparse_weight=sparse_weight)
  162. return {"status": "success", "step": step_name, "input_summary": input_summary,
  163. "output": _serialize_results(output), "execution_time": round(time.time() - start_time, 3)}
  164. elif step_name == 'parent_doc_enhance':
  165. mgr = _get_milvus()
  166. bfp_result_lists = _parse_json_param(content, params, 'bfp_result_lists')
  167. if bfp_result_lists is None:
  168. return {"status": "error", "step": step_name, "error": "请提供 bfp_result_lists(在 params 或 content 中传入 JSON)"}
  169. score_threshold = params.get('score_threshold', 0.3)
  170. max_parents = params.get('max_parents_per_pair', 3)
  171. input_summary = {"bfp_lists_count": len(bfp_result_lists),
  172. "score_threshold": score_threshold, "max_parents_per_pair": max_parents}
  173. output = enhance_with_parent_docs_grouped(
  174. mgr, bfp_result_lists,
  175. score_threshold=score_threshold, max_parents_per_pair=max_parents)
  176. serialized = {
  177. "enhanced_count": output.get("enhanced_count", 0),
  178. "enhanced_pairs": output.get("enhanced_pairs", 0),
  179. "total_pairs": output.get("total_pairs", 0),
  180. "parent_docs": _serialize_parent_docs(output.get("parent_docs", [])),
  181. "enhanced_results_summary": f"{len(output.get('enhanced_results', []))} 个查询对的结果"
  182. }
  183. return {"status": "success", "step": step_name, "input_summary": input_summary,
  184. "output": serialized, "execution_time": round(time.time() - start_time, 3)}
  185. elif step_name == 'extract_results':
  186. bfp_result_lists = _parse_json_param(content, params, 'bfp_result_lists')
  187. query_pairs = params.get('query_pairs', None)
  188. score_threshold = params.get('score_threshold', 0.5)
  189. if bfp_result_lists is None:
  190. return {"status": "error", "step": step_name, "error": "请提供 bfp_result_lists(在 params 或 content 中传入 JSON)"}
  191. input_summary = {"bfp_lists_count": len(bfp_result_lists),
  192. "has_query_pairs": query_pairs is not None,
  193. "score_threshold": score_threshold}
  194. output = extract_query_pairs_results(bfp_result_lists, query_pairs, score_threshold=score_threshold)
  195. return {"status": "success", "step": step_name, "input_summary": input_summary,
  196. "output": _serialize_results(output), "execution_time": round(time.time() - start_time, 3)}
  197. else:
  198. return {"status": "error", "step": step_name,
  199. "error": f"未知环节: {step_name},可选: query_extract, entity_enhance, multi_stage_recall, hybrid_search, parent_doc_enhance, extract_results"}
  200. except Exception as e:
  201. logger.exception("[rag_debug] step=%s failed", step_name)
  202. return {"status": "error", "step": step_name, "error": str(e),
  203. "execution_time": round(time.time() - start_time, 3)}
  204. def _debug_chain_sync(content: str, params: dict) -> dict:
  205. chain_start = time.time()
  206. steps = {}
  207. # Step 1
  208. t0 = time.time()
  209. try:
  210. query_pairs = query_rewrite_manager.query_extract(content)
  211. steps["query_extract"] = {"status": "success", "execution_time": round(time.time() - t0, 3),
  212. "output": query_pairs, "summary": f"提取到 {len(query_pairs) if query_pairs else 0} 个查询对"}
  213. except Exception as e:
  214. steps["query_extract"] = {"status": "error", "execution_time": round(time.time() - t0, 3), "error": str(e)}
  215. return {"status": "error", "steps": steps, "execution_time": round(time.time() - chain_start, 3),
  216. "error": "query_extract 失败"}
  217. if not query_pairs:
  218. return {"status": "no_results", "steps": steps, "execution_time": round(time.time() - chain_start, 3),
  219. "message": "query_extract 未提取到查询对"}
  220. # Step 2
  221. t0 = time.time()
  222. try:
  223. bfp_result_lists = entity_enhance.entities_enhance_retrieval(query_pairs)
  224. total_bfp = sum(len(r) for r in bfp_result_lists) if bfp_result_lists else 0
  225. steps["entity_enhance"] = {"status": "success", "execution_time": round(time.time() - t0, 3),
  226. "summary": f"召回 {total_bfp} 个BFP结果"}
  227. except Exception as e:
  228. steps["entity_enhance"] = {"status": "error", "execution_time": round(time.time() - t0, 3), "error": str(e)}
  229. return {"status": "error", "steps": steps, "execution_time": round(time.time() - chain_start, 3),
  230. "error": "entity_enhance 失败"}
  231. if not bfp_result_lists:
  232. return {"status": "no_results", "steps": steps, "execution_time": round(time.time() - chain_start, 3),
  233. "message": "entity_enhance 未召回结果"}
  234. # Step 3
  235. t0 = time.time()
  236. try:
  237. mgr = _get_milvus()
  238. score_threshold = params.get('score_threshold', 0.3)
  239. max_parents = params.get('max_parents_per_pair', 3)
  240. enhancement_result = enhance_with_parent_docs_grouped(
  241. mgr, bfp_result_lists,
  242. score_threshold=score_threshold, max_parents_per_pair=max_parents)
  243. enhanced_results = enhancement_result.get('enhanced_results', bfp_result_lists)
  244. steps["parent_doc_enhance"] = {"status": "success", "execution_time": round(time.time() - t0, 3),
  245. "summary": f"增强 {enhancement_result.get('enhanced_pairs', 0)}/{enhancement_result.get('total_pairs', 0)} 个查询对"}
  246. except Exception as e:
  247. steps["parent_doc_enhance"] = {"status": "error", "execution_time": round(time.time() - t0, 3), "error": str(e)}
  248. enhanced_results = bfp_result_lists
  249. # Step 4
  250. t0 = time.time()
  251. try:
  252. extract_threshold = params.get('score_threshold', 0.5)
  253. entity_results = extract_query_pairs_results(enhanced_results, query_pairs, score_threshold=extract_threshold)
  254. steps["extract_results"] = {"status": "success", "execution_time": round(time.time() - t0, 3),
  255. "output": _serialize_results(entity_results),
  256. "summary": f"提取 {len(entity_results) if entity_results else 0} 个高分结果"}
  257. except Exception as e:
  258. steps["extract_results"] = {"status": "error", "execution_time": round(time.time() - t0, 3), "error": str(e)}
  259. return {"status": "success", "steps": steps, "execution_time": round(time.time() - chain_start, 3)}
  260. def _rag_enhanced_check_sync(query_content: str) -> dict:
  261. """完整 RAG 增强检查链路(同步版)"""
  262. trace_id = f"rag_{int(time.time() * 1000)}"
  263. rag_monitor.start_trace(trace_id, metadata={
  264. "content_length": len(query_content),
  265. "stage": "rag_enhanced_check"
  266. })
  267. try:
  268. query_pairs = query_rewrite_manager.query_extract(query_content)
  269. if not query_pairs:
  270. return {"status": "no_results", "trace_id": trace_id, "message": "query_extract 未提取到查询对"}
  271. bfp_result_lists = entity_enhance.entities_enhance_retrieval(query_pairs)
  272. if not bfp_result_lists:
  273. return {"status": "no_results", "trace_id": trace_id, "message": "实体检索未返回结果"}
  274. mgr = _get_milvus()
  275. try:
  276. enhancement_result = enhance_with_parent_docs_grouped(
  277. mgr, bfp_result_lists, score_threshold=0.3, max_parents_per_pair=3)
  278. enhanced_results = enhancement_result['enhanced_results']
  279. except Exception:
  280. enhanced_results = bfp_result_lists
  281. entity_results = extract_query_pairs_results(enhanced_results, query_pairs, score_threshold=0.5) if enhanced_results else []
  282. pipeline_data = {
  283. "trace_id": trace_id,
  284. "stage": "rag_enhanced_check",
  285. "total_execution_time": 0,
  286. "final_result": {
  287. "retrieval_status": "success" if entity_results else "no_results",
  288. "entity_results": entity_results,
  289. "total_entities": len(entity_results) if entity_results else 0
  290. }
  291. }
  292. return {"status": "success", **pipeline_data}
  293. except Exception as e:
  294. logger.exception("[rag_debug] pipeline failed")
  295. return {"status": "error", "trace_id": trace_id, "error": str(e)}
  296. finally:
  297. rag_monitor.end_trace(trace_id)
  298. def _native_rag_check_sync(query_content: str, collection_name: str = "rag_children_hybrid",
  299. hybrid_top_k: int = 20, top_k: int = 5) -> dict:
  300. results = retrieval_manager.multi_stage_recall(
  301. collection_name=collection_name, query_text=query_content,
  302. hybrid_top_k=hybrid_top_k, top_k=top_k)
  303. serialized = _serialize_results(results)
  304. return {"status": "success", "results": serialized, "total_results": len(serialized)}
  305. def _parent_child_rag_check_sync(query_content: str, hybrid_top_k: int = 20, top_k: int = 5,
  306. parent_score_threshold: float = 0.3, max_parents: int = 3) -> dict:
  307. child_results = retrieval_manager.multi_stage_recall(
  308. collection_name="rag_children_hybrid", query_text=query_content,
  309. hybrid_top_k=hybrid_top_k, top_k=top_k)
  310. if not child_results:
  311. return {"status": "no_results", "child_results": [], "parent_results": []}
  312. mgr = _get_milvus()
  313. bfp_formatted = [child_results]
  314. try:
  315. enhancement_result = enhance_with_parent_docs_grouped(
  316. mgr, bfp_formatted, score_threshold=parent_score_threshold, max_parents_per_pair=max_parents)
  317. enhanced_results = enhancement_result.get('enhanced_results', [[]])
  318. parent_docs = enhancement_result.get('parent_docs', [])
  319. except Exception:
  320. enhanced_results = bfp_formatted
  321. parent_docs = []
  322. final_results = enhanced_results[0] if enhanced_results else child_results
  323. return {
  324. "status": "success",
  325. "child_results": _serialize_results(child_results),
  326. "parent_documents": _serialize_parent_docs(parent_docs),
  327. "enhanced_results": _serialize_results(final_results),
  328. "total_children": len(child_results),
  329. "total_parents": len(parent_docs),
  330. "total_enhanced": len(final_results)
  331. }
  332. async def _professional_review_async(review_content: str, check_type: str = "both") -> dict:
  333. """专业性审查完整测试(异步版)"""
  334. engine = _get_ai_review_engine()
  335. trace_id = f"professional_review_{int(time.time() * 1000)}"
  336. rag_monitor.start_trace(trace_id, metadata={
  337. "content_length": len(review_content),
  338. "check_type": check_type,
  339. "stage": "professional_review_test"
  340. })
  341. try:
  342. # Step 1: query_extract
  343. query_pairs = await asyncio.to_thread(query_rewrite_manager.query_extract, review_content)
  344. if not query_pairs:
  345. return {"status": "error", "trace_id": trace_id, "error": "查询提取失败"}
  346. # Step 2: entity enhance
  347. bfp_result_lists = await asyncio.to_thread(entity_enhance.entities_enhance_retrieval, query_pairs)
  348. if not bfp_result_lists:
  349. return {"status": "no_results", "trace_id": trace_id, "message": "未获取到有效的RAG召回结果"}
  350. # Step 3: parent doc enhancement
  351. mgr = _get_milvus()
  352. try:
  353. enhancement_result = await asyncio.to_thread(
  354. enhance_with_parent_docs_grouped, mgr, bfp_result_lists,
  355. **{"score_threshold": 0.3, "max_parents_per_pair": 3})
  356. enhanced_results = enhancement_result['enhanced_results']
  357. except Exception:
  358. enhanced_results = bfp_result_lists
  359. # Step 4: extract results
  360. entity_results = extract_query_pairs_results(enhanced_results, query_pairs, score_threshold=0.5) if enhanced_results else []
  361. if not entity_results:
  362. return {"status": "no_results", "trace_id": trace_id, "message": "没有结果通过阈值过滤"}
  363. # Step 5: AI review for each entity
  364. review_results = []
  365. for idx, entity_result in enumerate(entity_results):
  366. entity = entity_result.get('entity', '')
  367. combined_query = entity_result.get('combined_query', '')
  368. text_content = entity_result.get('text_content', '')
  369. file_name = entity_result.get('file_name', '')
  370. trace_id_idx = f"{trace_id}_entity_{idx}"
  371. entity_review = {
  372. "entity": entity,
  373. "combined_query": combined_query,
  374. "reference_source": file_name,
  375. "rag_score": entity_result.get('final_score', 0),
  376. "non_parameter_result": None,
  377. "parameter_result": None
  378. }
  379. if check_type in ["non_parameter", "both"]:
  380. try:
  381. non_param_result = await engine.check_non_parameter_compliance(
  382. trace_id_idx=trace_id_idx,
  383. review_content=review_content,
  384. review_references=text_content,
  385. reference_source=file_name,
  386. state={"callback_task_id": "test_task_id", "progress_manager": None},
  387. stage_name="专业性审查测试",
  388. entity_query=combined_query
  389. )
  390. entity_review["non_parameter_result"] = {
  391. 'success': non_param_result.success,
  392. 'details': non_param_result.details,
  393. 'error_message': non_param_result.error_message,
  394. 'execution_time': non_param_result.execution_time
  395. }
  396. except Exception as e:
  397. entity_review["non_parameter_result"] = {"error": str(e)}
  398. if check_type in ["parameter", "both"]:
  399. try:
  400. param_result = await engine.check_parameter_compliance(
  401. trace_id_idx=trace_id_idx,
  402. review_content=review_content,
  403. review_references=text_content,
  404. reference_source=file_name,
  405. state={"callback_task_id": "test_task_id", "progress_manager": None},
  406. stage_name="专业性审查测试",
  407. entity_query=combined_query
  408. )
  409. entity_review["parameter_result"] = {
  410. 'success': param_result.success,
  411. 'details': param_result.details,
  412. 'error_message': param_result.error_message,
  413. 'execution_time': param_result.execution_time
  414. }
  415. except Exception as e:
  416. entity_review["parameter_result"] = {"error": str(e)}
  417. review_results.append(entity_review)
  418. return {
  419. "status": "success",
  420. "trace_id": trace_id,
  421. "check_type": check_type,
  422. "rag_summary": {
  423. "query_pairs_count": len(query_pairs),
  424. "entity_results_count": len(entity_results),
  425. "query_pairs": query_pairs,
  426. "entity_results": entity_results
  427. },
  428. "review_results": review_results,
  429. "total_entities_reviewed": len(review_results)
  430. }
  431. except Exception as e:
  432. logger.exception("[rag_debug] professional_review failed")
  433. return {"status": "error", "trace_id": trace_id, "error": str(e)}
  434. finally:
  435. rag_monitor.end_trace(trace_id)
  436. # ============ 路由注册 ============
  437. def register_routes(router: APIRouter):
  438. @router.post("/rag/step")
  439. async def debug_step_endpoint(request: DebugStepRequest):
  440. """执行单个 RAG 环节调试"""
  441. step_name = request.step
  442. if step_name not in ('query_extract', 'entity_enhance', 'multi_stage_recall',
  443. 'hybrid_search', 'parent_doc_enhance', 'extract_results'):
  444. raise HTTPException(status_code=400, detail=f"未知环节: {step_name}")
  445. result = await asyncio.to_thread(_debug_step_sync, step_name, request.content, request.params or {})
  446. return result
  447. @router.post("/rag/chain")
  448. async def debug_chain_endpoint(request: DebugChainRequest):
  449. """链式执行: query_extract → entity_enhance → parent_doc_enhance → extract_results"""
  450. if not request.content:
  451. raise HTTPException(status_code=400, detail="请提供 content 参数")
  452. result = await asyncio.to_thread(_debug_chain_sync, request.content, request.params or {})
  453. return result
  454. @router.post("/rag/pipeline")
  455. async def rag_pipeline_endpoint(request: RagRequest):
  456. """完整 RAG 增强检查链路"""
  457. if not request.content:
  458. raise HTTPException(status_code=400, detail="请提供 content 参数")
  459. result = await asyncio.to_thread(_rag_enhanced_check_sync, request.content)
  460. return result
  461. @router.post("/rag/native")
  462. async def native_rag_endpoint(request: RagRequest):
  463. """Native RAG — 基础召回 + 重排序"""
  464. if not request.content:
  465. raise HTTPException(status_code=400, detail="请提供 content 参数")
  466. result = await asyncio.to_thread(
  467. _native_rag_check_sync, request.content,
  468. request.collection_name or "rag_children_hybrid",
  469. request.hybrid_top_k or 20, request.top_k or 5
  470. )
  471. return result
  472. @router.post("/rag/parent-child")
  473. async def parent_child_rag_endpoint(request: RagRequest):
  474. """父子文档模式 — 子检索 → 父文档增强"""
  475. if not request.content:
  476. raise HTTPException(status_code=400, detail="请提供 content 参数")
  477. result = await asyncio.to_thread(
  478. _parent_child_rag_check_sync, request.content,
  479. request.hybrid_top_k or 20, request.top_k or 5,
  480. request.parent_score_threshold or 0.3, request.max_parents or 3
  481. )
  482. return result
  483. @router.post("/rag/professional-review")
  484. async def professional_review_endpoint(request: ProfessionalReviewRequest):
  485. """专业性审查完整测试(RAG + AI 审查)"""
  486. if not request.content:
  487. raise HTTPException(status_code=400, detail="请提供 content 参数")
  488. if request.check_type not in ('non_parameter', 'parameter', 'both'):
  489. raise HTTPException(status_code=400, detail="check_type 必须是 non_parameter, parameter 或 both")
  490. result = await _professional_review_async(request.content, request.check_type or "both")
  491. return result
  492. @router.post("/rag/init")
  493. async def init_milvus_endpoint():
  494. """初始化 Milvus"""
  495. try:
  496. _get_milvus()
  497. return {"status": "ok", "message": "Milvus 初始化成功"}
  498. except Exception as e:
  499. raise HTTPException(status_code=500, detail=str(e))
  500. @router.get("/rag/data")
  501. async def get_rag_data():
  502. """获取最新的 pipeline 数据"""
  503. data_path = os.path.join(project_root, "temp", "rag_pipeline_server", "rag_pipeline_data.json")
  504. if os.path.exists(data_path):
  505. with open(data_path, 'r', encoding='utf-8') as f:
  506. return json.load(f)
  507. raise HTTPException(status_code=404, detail="数据文件不存在")