#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Standalone test tool for the RAG pipeline.

Used to quickly bring up and verify the RAG retrieval + LLM review flow of
the parameter compliance check.

Core functions:
1. rag_enhanced_check() - the complete RAG retrieval logic
2. check_parameter_compliance() - parameter compliance check (delegates to
   the same ``review`` call as the original pipeline)

Usage:
    python test_rag_pipeline.py
"""

import asyncio
import concurrent.futures
import json
import os
import sys
import time
import traceback
from typing import Any, Dict

# Make the project root importable before the project imports below run.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from core.construction_review.component.infrastructure.milvus import MilvusConfig, MilvusManager
from core.construction_review.component.infrastructure.parent_tool import (
    enhance_with_parent_docs,
    extract_first_result
)
from core.construction_review.component.reviewers.base_reviewer import BaseReviewer, ReviewResult
from foundation.ai.rag.retrieval.entities_enhance import entity_enhance
from foundation.ai.rag.retrieval.query_rewrite import query_rewrite_manager
from foundation.ai.agent.generate.model_generate import generate_model_client
from core.construction_review.component.reviewers.utils.prompt_loader import prompt_loader
from foundation.observability.logger.loggering import review_logger as logger

# Directory that receives every debug dump written by this script.
# Built with os.path.join so it resolves to a real sub-directory on every OS
# (the previous hard-coded backslash paths only worked as intended on Windows).
_OUTPUT_DIR = os.path.join("temp", "entity_bfp_recall")


# ============================================================================
# Internal helpers
# ============================================================================

def _dump_json(file_name: str, payload: Any) -> str:
    """Write ``payload`` as pretty-printed UTF-8 JSON into the output directory.

    The directory is created on every call, so a dump can never fail merely
    because an earlier step (which used to create it) was skipped.

    Args:
        file_name: Name of the file inside the output directory.
        payload: JSON-serialisable data.

    Returns:
        str: Path of the file that was written.
    """
    os.makedirs(_OUTPUT_DIR, exist_ok=True)
    target_path = os.path.join(_OUTPUT_DIR, file_name)
    with open(target_path, "w", encoding='utf-8') as f:
        json.dump(payload, f, ensure_ascii=False, indent=4)
    return target_path


def _log_banner(title: str) -> None:
    """Log ``title`` framed by two 80-character separator lines."""
    logger.info("=" * 80)
    logger.info(title)
    logger.info("=" * 80)


def _run_async(coro):
    """Run a coroutine to completion from synchronous code.

    If no event loop is running in the current thread the coroutine is run
    directly with ``asyncio.run``. Otherwise it is run with ``asyncio.run``
    in a worker thread, because ``asyncio.run`` cannot be called from a
    thread that already has a running loop.

    Only the loop probe is guarded by ``except RuntimeError``: a
    ``RuntimeError`` raised by the coroutine itself must propagate instead of
    triggering a second run of an already-consumed coroutine.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        return asyncio.run(coro)

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(asyncio.run, coro).result()


# ============================================================================
# Simplified BaseReviewer subclass - used to invoke the LLM review
# ============================================================================

class SimpleReviewer(BaseReviewer):
    """Simplified reviewer; inherits BaseReviewer and is used to call the LLM review."""

    def __init__(self):
        """Initialise the base reviewer and attach the model client and prompt loader."""
        super().__init__()
        self.model_client = generate_model_client
        self.prompt_loader = prompt_loader


# Module-level reviewer instance used by check_parameter_compliance().
simple_reviewer = SimpleReviewer()


# ============================================================================
# Core RAG pipeline function
# ============================================================================

def rag_enhanced_check(milvus_manager, unit_content: dict) -> dict:
    """
    RAG-enhanced check - full pipeline.

    Flow:
    1. Query extraction (query_rewrite_manager.query_extract)
    2. Entity-enhanced retrieval, unrolled here into entity recall + BFP
       recall per query pair so every intermediate value can be recorded
    3. Parent document enhancement (enhance_with_parent_docs); on failure the
       raw retrieval results are used instead
    4. Extraction of the first result (extract_first_result)

    Every step's input/output is collected and dumped as JSON into
    ``temp/entity_bfp_recall`` for debugging.

    Args:
        milvus_manager: MilvusManager instance, passed to enhance_with_parent_docs.
        unit_content: Dict with a ``content`` field, e.g.
            {"content": "text to retrieve for"}.

    Returns:
        dict: When nothing was retrieved, a dict with the keys
        ``vector_search`` (empty list), ``retrieval_status`` ('no_results'),
        ``file_name``, ``text_content`` and ``metadata``. Otherwise the value
        returned by extract_first_result (callers read ``file_name`` and
        ``text_content`` from it).
    """
    # Data-flow tracking dict, dumped to disk at the end of the run.
    pipeline_data = {
        "stage": "rag_enhanced_check",
        "timestamp": time.time(),
        "steps": {}
    }

    query_content = unit_content['content']
    logger.info(f"[RAG增强] 开始处理, 内容长度: {len(query_content)}")

    # ------------------------------------------------------------------
    # Step 1: query extraction
    # ------------------------------------------------------------------
    _log_banner("Step 1: 查询提取")
    logger.info(f"开始查询提取, 输入内容长度: {len(query_content)}")
    logger.info(f"输入内容预览: {query_content[:200]}...")

    # Treat a None result like "no query pairs" instead of crashing on len().
    query_pairs = query_rewrite_manager.query_extract(query_content) or []
    logger.info(f"[RAG增强] 提取到 {len(query_pairs)} 个查询对")

    for idx, query_pair in enumerate(query_pairs):
        logger.info(f" 查询对 {idx+1}: {query_pair}")

    pipeline_data["steps"]["1_query_extract"] = {
        "input": {
            "content_length": len(query_content),
            "content_full": query_content,
            "content_preview": query_content[:200]
        },
        "output": {
            "query_pairs_count": len(query_pairs),
            # Stringified so the dump stays readable and serialisable.
            "query_pairs": [str(qp) for qp in query_pairs],
            "extraction_timestamp": time.time()
        }
    }

    # ------------------------------------------------------------------
    # Step 2: entity-enhanced retrieval
    # ------------------------------------------------------------------
    _log_banner("Step 2: 实体增强检索")
    logger.info(f"开始实体增强检索, 输入查询对数量: {len(query_pairs)}")

    entity_enhance_input = {
        "query_pairs": [str(qp) for qp in query_pairs],
        "query_pairs_count": len(query_pairs)
    }

    # Per-query-pair record of every sub-step, for the data-flow dump.
    entity_enhance_process_details = []

    # Imported at this point, exactly where the original script did it.
    from foundation.ai.rag.retrieval.retrieval import retrieval_manager

    bfp_result_lists = []

    for idx, query_pair in enumerate(query_pairs):
        logger.info(f"\n{'='*60}")
        logger.info(f"处理查询对 {idx+1}/{len(query_pairs)}")
        logger.info(f"{'='*60}")

        entity = query_pair['entity']
        search_keywords = query_pair['search_keywords']
        background = query_pair['background']

        logger.info(f" 实体(entity): {entity}")
        logger.info(f" 搜索关键词(search_keywords): {search_keywords}")
        logger.info(f" 背景(background): {background}")

        current_query_detail = {
            "index": idx + 1,
            "input": {
                "entity": entity,
                "search_keywords": search_keywords,
                "background": background
            },
            "steps": {}
        }

        # Step 2.1: entity recall
        logger.info(" Step 2.1: 实体召回 (recall_top_k=5, max_results=5)")
        entity_list = _run_async(retrieval_manager.entity_recall(
            entity,
            search_keywords,
            recall_top_k=5,
            max_results=5
        ))
        entity_count = len(entity_list) if entity_list else 0
        logger.info(f" ✅ 实体召回完成, 召回实体数量: {entity_count}")

        current_query_detail["steps"]["2_1_entity_recall"] = {
            "input": {
                "entity": entity,
                "search_keywords": search_keywords,
                "recall_top_k": 5,
                "max_results": 5
            },
            "output": {
                "entity_list": entity_list,
                "entity_count": entity_count
            }
        }

        # Step 2.2: BFP recall, fed with the entities recalled above
        logger.info(" Step 2.2: BFP召回 (top_k=3)")
        bfp_result = _run_async(retrieval_manager.async_bfp_recall(
            entity_list,
            background,
            top_k=3
        ))
        bfp_result_count = len(bfp_result) if bfp_result else 0
        logger.info(f" ✅ BFP召回完成, BFP结果数量: {bfp_result_count}")
        logger.info(f" bfp_result: {bfp_result}")

        current_query_detail["steps"]["2_2_bfp_recall"] = {
            "input": {
                "entity_list": entity_list,
                "background": background,
                "top_k": 3
            },
            "output": {
                "bfp_result": bfp_result,
                "bfp_result_count": bfp_result_count
            }
        }

        bfp_result_lists.append(bfp_result)
        entity_enhance_process_details.append(current_query_detail)
        logger.info(f"✅ 查询对 {idx+1} 处理完成")

    logger.info(f"\n{'='*80}")
    logger.info("实体增强检索全部完成")
    logger.info(f"总查询对数: {len(query_pairs)}")
    logger.info(f"总BFP结果数: {len(bfp_result_lists)}")
    logger.info(f"{'='*80}")

    pipeline_data["steps"]["2_entity_enhance_retrieval"] = {
        "input": entity_enhance_input,
        "output": {
            "results_count": len(bfp_result_lists),
            "results": bfp_result_lists,
            # Detailed processing record of every query pair.
            "process_details": entity_enhance_process_details
        },
        "timestamp": time.time()
    }

    # Guard: nothing retrieved (only possible when there were no query pairs,
    # since one entry is appended per pair).
    if not bfp_result_lists:
        logger.warning("[RAG增强] 实体检索未返回结果")
        _dump_json("rag_pipeline_data.json", pipeline_data)
        return {
            'vector_search': [],
            'retrieval_status': 'no_results',
            'file_name': '',
            'text_content': '',
            'metadata': {}
        }

    logger.info(f"[RAG增强] 实体检索返回 {len(bfp_result_lists)} 个结果")

    # ------------------------------------------------------------------
    # Step 3: parent document enhancement (standalone helper)
    # ------------------------------------------------------------------
    _log_banner("Step 3: 父文档增强")

    try:
        enhancement_result = enhance_with_parent_docs(milvus_manager, bfp_result_lists)
        enhanced_results = enhancement_result['enhanced_results']
        enhanced_count = enhancement_result['enhanced_count']
        parent_docs = enhancement_result['parent_docs']

        pipeline_data["steps"]["3_parent_doc_enhancement"] = {
            "input": {
                "bfp_results_count": len(bfp_result_lists)
            },
            "output": {
                "enhanced_count": enhanced_count,
                "parent_docs_count": len(parent_docs),
                "parent_docs": parent_docs,
                "enhanced_results": enhanced_results
            }
        }

        _dump_json("enhance_with_parent_docs.json", enhanced_results)

        logger.info(f"[RAG增强] 成功增强 {enhanced_count} 个结果")
        logger.info(f"[RAG增强] 使用了 {len(parent_docs)} 个父文档")

        for idx, parent_doc in enumerate(parent_docs):
            logger.info(f" 父文档 {idx+1}: {parent_doc.get('file_name', 'unknown')}")

    except Exception as e:
        # Best effort: any failure in this step (including its debug dump)
        # falls back to the un-enhanced retrieval results.
        logger.error(f"[RAG增强] 父文档增强失败: {e}", exc_info=True)
        pipeline_data["steps"]["3_parent_doc_enhancement"] = {
            "input": {
                "bfp_results_count": len(bfp_result_lists)
            },
            "output": {
                "error": str(e),
                "error_type": type(e).__name__
            }
        }
        enhanced_results = bfp_result_lists

    # ------------------------------------------------------------------
    # Step 4: extract the first result (from the enhanced results)
    # ------------------------------------------------------------------
    _log_banner("Step 4: 提取第一个结果")
    final_result = extract_first_result(enhanced_results)

    pipeline_data["steps"]["4_extract_first_result"] = {
        "input": {
            "enhanced_results_count": len(enhanced_results)
        },
        "output": {
            "final_result": final_result
        }
    }

    # _dump_json creates the output directory itself; previously this write
    # crashed with FileNotFoundError whenever Step 3 failed before creating it.
    _dump_json("extract_first_result.json", final_result)

    pipeline_data["final_result"] = final_result
    _dump_json("rag_pipeline_data.json", pipeline_data)

    logger.info(f"[RAG增强] 最终提取结果文件名: {final_result.get('file_name', '无')}")
    logger.info(f"[RAG增强] 最终提取结果内容长度: {len(final_result.get('text_content', ''))}")
    logger.info("[RAG增强] 完整数据流已保存到: temp/entity_bfp_recall/rag_pipeline_data.json")

    return final_result


# ============================================================================
# Parameter compliance check (same call as the original pipeline)
# ============================================================================

async def check_parameter_compliance(trace_id_idx: str, review_content: str, review_references: str,
                                     reference_source: str, review_location_label: str,
                                     state: str, stage_name: str) -> Dict[str, Any]:
    """
    Parameter compliance check - entity concept / engineering terminology knowledge base.

    Keeps the signature of the original pipeline method and simply forwards
    to ``simple_reviewer.review`` with a 45 second timeout.

    Args:
        trace_id_idx: Trace ID suffix; appended to the prompt name to form the trace ID.
        review_content: Content under review.
        review_references: Reference text for the review.
        reference_source: Source of the reference text.
        review_location_label: Label describing where the reviewed content is located.
        state: Pipeline state forwarded unchanged to the reviewer.
            NOTE(review): annotated ``str`` but main() passes ``None`` -
            confirm the intended type against the original pipeline.
        stage_name: Stage name.

    Returns:
        Dict[str, Any]: Annotated return type; the value is whatever
        ``simple_reviewer.review`` returns. NOTE(review): main() reads
        ``.success`` / ``.details`` attributes from it, so it is presumably a
        ReviewResult - confirm.
    """
    # Stage enum comes from the original pipeline; imported lazily as before.
    from core.construction_review.component.ai_review_engine import Stage

    reviewer_type = Stage.TECHNICAL.value['reviewer_type']
    prompt_name = Stage.TECHNICAL.value['parameter']
    trace_id = prompt_name + trace_id_idx

    # Call the original pipeline's review method directly.
    return await simple_reviewer.review(
        "parameter_compliance_check",
        trace_id,
        reviewer_type,
        prompt_name,
        review_content,
        review_references,
        reference_source,
        review_location_label,
        state,
        stage_name,
        timeout=45
    )


# ============================================================================
# Main test function
# ============================================================================

async def main():
    """
    Main test function - exercises the full parameter compliance check flow.

    Flow:
    1. Initialise the Milvus manager
    2. Prepare the test content
    3. Run the RAG retrieval to obtain reference information
    4. Run the parameter compliance check
    5. Save the complete data flow and a result summary as JSON

    Returns:
        The review result on success; ``None`` when Milvus initialisation or
        the compliance check fails.
    """
    print("\n" + "=" * 80)
    print("RAG链路独立测试工具 - 参数合规性检查".center(80))
    print("=" * 80 + "\n")

    # Initialise the Milvus manager; nothing else can run without it.
    print("📌 初始化Milvus Manager...")
    logger.info("初始化Milvus Manager...")
    try:
        milvus_manager = MilvusManager(MilvusConfig())
        print("✅ Milvus Manager 初始化成功\n")
    except Exception as e:
        print(f"❌ Milvus Manager 初始化失败: {e}")
        logger.error(f"Milvus Manager 初始化失败: {e}", exc_info=True)
        return

    # Test content. Whitespace is kept exactly as it appears in the reviewed
    # source; the adjacent literals below concatenate to that same text.
    test_content = (
        "主要部件说明 1、主梁总成 主梁总成由主梁和导梁构成。"
        "主梁单节长12m,共7节,每节重10.87t,主梁为主要承载受力构件,"
        "其上弦杆上方设有轨道供纵移桁车走行,实现预制梁的纵向移动;"
        "下弦设有反滚轮行走轨道,作为导梁纵移、前中支腿移动纵行轨道。"
        "导梁长18m,主要是为降低过孔挠度和承受中支腿移动荷载,起安全引导、辅助过孔作用。"
        "主梁、导梁为三角桁架构件单元,采用销轴连接,前、后端各设置横联构架。"
        " 图4-1 主梁总成图 注意事项: "
        "(1)更换上、下弦销轴时,应优先向设备供应方购买符合要求的备件。"
        "自行更换时,材料性能必须优于设计零件性能,并按规定进行热处理,否则可能造成人员、设备事故。"
        " (2)销轴不得弯曲受力,不得用销轴作为锤砸工具,不得任意放置及焊接"
    )

    unit_content = {"content": test_content}
    print(f"📝 测试内容长度: {len(test_content)} 字符")
    print(f"📝 测试内容预览:\n{test_content[:200]}...\n")

    # Data-flow tracking dict for the whole run.
    pipeline_data = {
        "stage": "parameter_compliance_check_full_pipeline",
        "timestamp": time.time(),
        "steps": {}
    }

    # ------------------------------------------------------------------
    # Step 1: RAG-enhanced retrieval
    # ------------------------------------------------------------------
    print("=" * 80)
    print("【Step 1】RAG增强检索".center(80))
    print("=" * 80)
    _log_banner("Step 1: RAG增强检索")

    start_time = time.time()
    rag_result = rag_enhanced_check(milvus_manager, unit_content)
    review_references = rag_result.get('text_content', '')
    reference_source = rag_result.get('file_name', '')

    pipeline_data["steps"]["1_rag_retrieval"] = {
        "input": {
            "unit_content": unit_content
        },
        "output": {
            "rag_result": rag_result,
            "review_references_length": len(review_references),
            "reference_source": reference_source
        },
        "execution_time": time.time() - start_time
    }

    if not review_references:
        # The review still runs, just without reference material.
        logger.warning("RAG检索未返回参考信息,将继续使用空参考进行审查")
        print("⚠️ RAG检索未返回参考信息\n")
    else:
        print("✅ RAG检索成功")
        print(f" 参考来源: {reference_source}")
        print(f" 参考内容长度: {len(review_references)} 字符\n")

    # ------------------------------------------------------------------
    # Step 2: parameter compliance check (original pipeline method)
    # ------------------------------------------------------------------
    print("=" * 80)
    print("【Step 2】参数合规性检查 (LLM审查)".center(80))
    print("=" * 80)
    _log_banner("Step 2: 参数合规性检查")

    trace_id_idx = "_test_001"
    review_location_label = "测试文档-第1章"
    state = None
    stage_name = "test_stage"

    logger.info("开始调用参数合规性检查")
    logger.info(f" - trace_id_idx: {trace_id_idx}")
    logger.info(f" - review_content长度: {len(test_content)}")
    logger.info(f" - review_references长度: {len(review_references)}")
    logger.info(f" - reference_source: {reference_source}")

    pipeline_data["steps"]["2_parameter_compliance_check"] = {
        "input": {
            "trace_id_idx": trace_id_idx,
            "review_content_length": len(test_content),
            "review_content_preview": test_content[:200],
            "review_references_length": len(review_references),
            "review_references_preview": review_references[:200] if review_references else "",
            "reference_source": reference_source,
            "review_location_label": review_location_label,
            "stage_name": stage_name
        },
        "output": {}
    }

    start_time = time.time()
    try:
        result = await check_parameter_compliance(
            trace_id_idx=trace_id_idx,
            review_content=test_content,
            review_references=review_references,
            reference_source=reference_source,
            review_location_label=review_location_label,
            state=state,
            stage_name=stage_name
        )
        elapsed_time = time.time() - start_time

        result_payload = {
            "success": result.success,
            "execution_time": result.execution_time,
            "error_message": result.error_message,
            "details": result.details
        }
        pipeline_data["steps"]["2_parameter_compliance_check"]["output"] = result_payload
        pipeline_data["final_result"] = dict(result_payload)

        _dump_json("parameter_compliance_full_pipeline.json", pipeline_data)

        logger.info(f"✅ 参数合规性检查完成, 总耗时: {elapsed_time:.2f}秒")
        logger.info("📁 完整数据流已保存到: temp/entity_bfp_recall/parameter_compliance_full_pipeline.json")

    except Exception as e:
        error_msg = f"参数合规性检查失败: {str(e)}"
        logger.error(error_msg, exc_info=True)

        pipeline_data["steps"]["2_parameter_compliance_check"]["output"] = {
            "error": error_msg,
            "error_type": type(e).__name__,
            # Real stack trace; this field previously only repeated str(e).
            "traceback": traceback.format_exc()
        }
        pipeline_data["error"] = {
            "error_message": error_msg,
            "error_type": type(e).__name__
        }

        _dump_json("parameter_compliance_full_pipeline.json", pipeline_data)

        # error_msg already carries the "参数合规性检查失败: " prefix.
        print(f"❌ {error_msg}\n")
        return

    # ------------------------------------------------------------------
    # Report
    # ------------------------------------------------------------------
    # A failed review may carry no details; fall back to an empty dict so the
    # report and summary below never crash on ``None.get``.
    details = result.details or {}

    print("\n" + "=" * 80)
    print("测试结果".center(80))
    print("=" * 80)

    status_icon = "✅" if result.success else "❌"
    print(f"\n{status_icon} 参数合规性检查")
    print(f" 执行时间: {result.execution_time:.2f}秒")

    if result.success:
        print(" 审查成功!")
        print(f" 详细信息: {details.get('name', 'N/A')}")

        # Show the RAG reference information when the reviewer reported it.
        if 'rag_reference_source' in details:
            print("\n 📚 RAG参考信息:")
            print(f" 参考来源: {details['rag_reference_source']}")
            print(f" 参考内容长度: {len(details.get('rag_review_references', ''))} 字符")

        # Show the first 500 characters of the review response.
        response = details.get('response', '')
        if response:
            print("\n 📋 审查响应 (前500字符):")
            print(f" {response[:500]}...")
    else:
        print(f" 错误信息: {result.error_message}")

    print("\n" + "=" * 80)
    print("详细结果已保存到:".center(80))
    print(" 📁 temp/entity_bfp_recall/rag_pipeline_data.json - RAG检索完整数据流")
    print(" 📁 temp/entity_bfp_recall/enhance_with_parent_docs.json - 父文档增强结果")
    print(" 📁 temp/entity_bfp_recall/extract_first_result.json - 最终提取结果")
    print(" 📁 temp/entity_bfp_recall/parameter_compliance_full_pipeline.json - 参数检查完整数据流")
    print("=" * 80 + "\n")

    print("✅ 测试完成!")

    # Save a compact summary of the test result.
    test_summary = {
        "test_type": "parameter_compliance",
        "check_display_name": "参数合规性检查",
        "timestamp": time.time(),
        "result": {
            'success': result.success,
            'execution_time': result.execution_time,
            'error_message': result.error_message,
            'details_summary': {
                'name': details.get('name'),
                'has_rag_reference': 'rag_reference_source' in details,
                'response_length': len(details.get('response', '')),
                'response_preview': details.get('response', '')[:200]
            }
        }
    }
    _dump_json("test_summary.json", test_summary)

    return result


if __name__ == "__main__":
    # Run the async main function.
    asyncio.run(main())