from __future__ import annotations import asyncio from pathlib import Path from typing import Any, AsyncIterator, Dict, List, Optional from llm_pipeline.core.config import YamlConfigProvider from llm_pipeline.core.pipeline import LLMPipeline from llm_pipeline.interfaces import DataLoader, ResultSaver from llm_pipeline.entity_extract_v1.dataloaders import EntityExtractV1JsonChunksLoader from llm_pipeline.entity_extract_v1.prompting import ( EntityExtractV1JsonResponseParser, EntityExtractV1PromptBuilder, ) from llm_pipeline.entity_extract_v1.factory import build_llm_client as build_extract_llm_client from llm_pipeline.entity_extract_eval_v1.prompting import ( EntityEvalV1JsonResponseParser, EntityEvalV1PromptBuilder, ) from llm_pipeline.entity_extract_eval_v1.factory import build_llm_client as build_eval_llm_client from llm_pipeline.rag_retrieval_eval_v1.factory import ( build_rag_eval_pipeline_for_entities_items, build_rag_eval_pipeline_for_qa_json, ) from llm_pipeline.entity_extract_v1.factory import build_pipeline_for_csv, build_pipeline_for_json from llm_pipeline.entity_extract_eval_v1.factory import build_eval_pipeline_for_json async def run_entity_extract_v1_with_json( input_json: str, output_json: str = "output_from_json.json", ) -> None: """使用 entity_extract_v1 版本:JSON → JSON 处理。""" pipeline, _ = build_pipeline_for_json(input_json=input_json, output_json=output_json) await pipeline.run() async def run_entity_extract_v1_with_csv( input_csv: str = "input.csv", output_csv: str = "output.csv", ) -> None: """使用 entity_extract_v1 版本:CSV → CSV 处理。""" pipeline, _ = build_pipeline_for_csv(input_csv=input_csv, output_csv=output_csv) await pipeline.run() async def run_entity_eval_v1_with_json( input_json: str, output_json: str = "output_from_json_eval.json", ) -> None: """使用 entity_extract_eval_v1 版本:对抽取结果做专业性评估与过滤。""" pipeline, _ = build_eval_pipeline_for_json( input_json=input_json, output_json=output_json, ) await pipeline.run() async def run_full_entity_extract_and_eval() -> None: """一键运行:先抽取实体,再对结果进行评估过滤。""" raw_input = ( "44_四川公路桥梁建设集团有限公司镇巴(川陕界)至广安高速公路通广段C合同段C4项目经理部_完整结果_20251212_155323.json" ) first_output = "output_from_json.json" final_output = "output_from_json_eval.json" # 第一步:实体抽取 await run_entity_extract_v1_with_json(input_json=raw_input, output_json=first_output) # 第二步:专业性评估与过滤 await run_entity_eval_v1_with_json(input_json=first_output, output_json=final_output) async def run_rag_retrieval_eval_with_qa_json( input_json: str, output_csv: str = "rag_eval_results.csv", collection: str = "first_bfp_collection_test", hybrid_top_k: int = 20, final_top_k: int = 5, ) -> None: """ 使用 rag_retrieval_eval_v1 版本: - 输入:单个包含 qa_pairs 的 JSON(与 batch_rag_eval_from_qa.py 兼容); - 过程:对每个实体 name 进行检索召回(multi_stage_recall),并调用 LLM 做命中率评估; - 输出:汇总结果写入 CSV,便于统计分析。 """ pipeline, _ = build_rag_eval_pipeline_for_qa_json( input_json=input_json, output_csv=output_csv, collection=collection, hybrid_top_k=hybrid_top_k, final_top_k=final_top_k, ) await pipeline.run() class InMemoryListSaver(ResultSaver): """将流水线结果保存在内存列表中(不落地文件)。""" def __init__(self) -> None: self.items: List[Dict[str, Any]] = [] async def save(self, item: Dict[str, Any], result: Dict[str, Any]) -> None: self.items.append({**item, **result}) class InMemoryDataLoader(DataLoader): """从内存列表提供数据的 DataLoader。""" def __init__(self, items: List[Dict[str, Any]]) -> None: self._items = items async def load_items(self) -> AsyncIterator[Dict[str, Any]]: for it in self._items: yield it def get_total(self) -> Optional[int]: return len(self._items) class InMemoryEntityExtractSaver(ResultSaver): """对齐 entity_extract_v1 的 JSON 输出结构,但保存在内存。""" def __init__(self) -> None: self.items: List[Dict[str, Any]] = [] async def save(self, item: Dict[str, Any], result: Dict[str, Any]) -> None: merged = {**item, **result} simplified = { "file_name": merged.get("file_name"), "chunk_id": merged.get("chunk_id"), "section_label": merged.get("section_label"), "text": merged.get("text"), "entity_extract_result": merged.get("entity_extract_result"), } self.items.append(simplified) class InMemoryEvalFilteredSaver(ResultSaver): """对齐 entity_extract_eval_v1 的过滤逻辑,但保存在内存。""" def __init__(self) -> None: self.items: List[Dict[str, Any]] = [] async def save(self, item: Dict[str, Any], result: Dict[str, Any]) -> None: merged = {**item, **result} entities_obj = merged.get("entity_extract_result") or {} entities = entities_obj.get("entities") if isinstance(entities_obj, dict) else None if not entities or not isinstance(entities, list): return self.items.append( { "file_name": merged.get("file_name"), "chunk_id": merged.get("chunk_id"), "section_label": merged.get("section_label"), "text": merged.get("text"), "entity_extract_result": entities_obj, } ) async def run_full_extract_eval_and_rag_eval_in_memory( input_json: str, output_csv: str = "rag_eval_results.csv", collection: str = "first_bfp_collection_test", hybrid_top_k: int = 20, final_top_k: int = 5, ) -> None: """ 全流程(不依赖中间文件): 1) entity_extract_v1:从 input_json(chunks) 抽取实体概念+背景 2) entity_extract_eval_v1:专业性评估与过滤 3) rag_retrieval_eval_v1:用过滤后的实体(name+背景/证据拼 query)做检索召回 + 命中率评估,输出 CSV """ def _iter_input_json_files(path_str: str) -> List[Path]: p = Path(path_str) if not p.exists(): raise FileNotFoundError(f"输入路径不存在: {p}") if p.is_file(): return [p] if p.is_dir(): # 目录:递归找 json,固定排序保证可复现 return sorted(p.rglob("*.json"), key=lambda x: str(x)) return [] input_files = _iter_input_json_files(input_json) if not input_files: print(f"[INFO] 未找到可处理的 JSON 文件: {input_json}") return all_filtered_items: List[Dict[str, Any]] = [] # === Stage 1 + 2: per-file extract + eval filter (in-memory) === extract_service = Path(__file__).parent / "llm_pipeline" / "entity_extract_v1" / "service.yaml" extract_cfg = YamlConfigProvider(service_path=extract_service) extract_client = build_extract_llm_client(extract_cfg) extract_prompt = EntityExtractV1PromptBuilder(cfg_provider=extract_cfg) extract_parser = EntityExtractV1JsonResponseParser(output_field="entity_extract_result") eval_service = Path(__file__).parent / "llm_pipeline" / "entity_extract_eval_v1" / "service.yaml" eval_cfg = YamlConfigProvider(service_path=eval_service) eval_client = build_eval_llm_client(eval_cfg) eval_prompt = EntityEvalV1PromptBuilder(cfg_provider=eval_cfg) eval_parser = EntityEvalV1JsonResponseParser(output_field="entity_extract_result") for fp in input_files: # === Stage 1: entity_extract_v1 (in-memory) === extract_loader = EntityExtractV1JsonChunksLoader(str(fp)) extract_saver = InMemoryEntityExtractSaver() extract_pipeline = LLMPipeline( llm_client=extract_client, config_provider=extract_cfg, data_loader=extract_loader, prompt_builder=extract_prompt, response_parser=extract_parser, result_saver=extract_saver, ) await extract_pipeline.run() extracted_items = extract_saver.items if not extracted_items: print(f"[INFO] 跳过(抽取阶段无输出): {fp}") continue # === Stage 2: entity_extract_eval_v1 (in-memory) === eval_loader = InMemoryDataLoader(extracted_items) eval_saver = InMemoryEvalFilteredSaver() eval_pipeline = LLMPipeline( llm_client=eval_client, config_provider=eval_cfg, data_loader=eval_loader, prompt_builder=eval_prompt, response_parser=eval_parser, result_saver=eval_saver, ) await eval_pipeline.run() filtered_items = eval_saver.items if not filtered_items: print(f"[INFO] 跳过(评估过滤后无有效实体): {fp}") continue all_filtered_items.extend(filtered_items) # === Stage 3: rag_retrieval_eval_v1 (entities -> retrieval -> hit eval) === if not all_filtered_items: print("[INFO] 全部输入处理完成,但未产生任何可用于 RAG 评估的实体。") return rag_pipeline, _ = build_rag_eval_pipeline_for_entities_items( items=all_filtered_items, # items=extracted_items, output_csv=output_csv, collection=collection, hybrid_top_k=hybrid_top_k, final_top_k=final_top_k, ) await rag_pipeline.run() if __name__ == "__main__": # 默认执行“抽取 → 专业评估过滤 → 检索召回 → 命中率评估(CSV)”全流程(内存承接,不依赖中间文件) asyncio.run( run_full_extract_eval_and_rag_eval_in_memory( input_json="./data", output_csv="rag_eval_results.csv", collection="first_bfp_collection_test", hybrid_top_k=20, final_top_k=5, ) )