| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272 |
- from __future__ import annotations
- import asyncio
- from pathlib import Path
- from typing import Any, AsyncIterator, Dict, List, Optional
- from llm_pipeline.core.config import YamlConfigProvider
- from llm_pipeline.core.pipeline import LLMPipeline
- from llm_pipeline.interfaces import DataLoader, ResultSaver
- from llm_pipeline.entity_extract_v1.dataloaders import EntityExtractV1JsonChunksLoader
- from llm_pipeline.entity_extract_v1.prompting import (
- EntityExtractV1JsonResponseParser,
- EntityExtractV1PromptBuilder,
- )
- from llm_pipeline.entity_extract_v1.factory import build_llm_client as build_extract_llm_client
- from llm_pipeline.entity_extract_eval_v1.prompting import (
- EntityEvalV1JsonResponseParser,
- EntityEvalV1PromptBuilder,
- )
- from llm_pipeline.entity_extract_eval_v1.factory import build_llm_client as build_eval_llm_client
- from llm_pipeline.rag_retrieval_eval_v1.factory import (
- build_rag_eval_pipeline_for_entities_items,
- build_rag_eval_pipeline_for_qa_json,
- )
- from llm_pipeline.entity_extract_v1.factory import build_pipeline_for_csv, build_pipeline_for_json
- from llm_pipeline.entity_extract_eval_v1.factory import build_eval_pipeline_for_json
async def run_entity_extract_v1_with_json(
    input_json: str,
    output_json: str = "output_from_json.json",
) -> None:
    """Run the entity_extract_v1 pipeline: JSON in, JSON out."""
    extraction_pipeline, _unused = build_pipeline_for_json(
        input_json=input_json,
        output_json=output_json,
    )
    await extraction_pipeline.run()
async def run_entity_extract_v1_with_csv(
    input_csv: str = "input.csv",
    output_csv: str = "output.csv",
) -> None:
    """Run the entity_extract_v1 pipeline: CSV in, CSV out."""
    extraction_pipeline, _unused = build_pipeline_for_csv(
        input_csv=input_csv,
        output_csv=output_csv,
    )
    await extraction_pipeline.run()
async def run_entity_eval_v1_with_json(
    input_json: str,
    output_json: str = "output_from_json_eval.json",
) -> None:
    """Run the entity_extract_eval_v1 pipeline.

    Scores extraction results for professionalism and filters them,
    writing the surviving records to ``output_json``.
    """
    eval_pipeline, _unused = build_eval_pipeline_for_json(
        input_json=input_json,
        output_json=output_json,
    )
    await eval_pipeline.run()
async def run_full_entity_extract_and_eval() -> None:
    """One-shot driver: extract entities first, then evaluate and filter them."""
    raw_input = (
        "44_四川公路桥梁建设集团有限公司镇巴(川陕界)至广安高速公路通广段C合同段C4项目经理部_完整结果_20251212_155323.json"
    )
    first_output = "output_from_json.json"
    final_output = "output_from_json_eval.json"

    # Stage 1: entity extraction.
    await run_entity_extract_v1_with_json(input_json=raw_input, output_json=first_output)
    # Stage 2: professionalism evaluation and filtering.
    await run_entity_eval_v1_with_json(input_json=first_output, output_json=final_output)
async def run_rag_retrieval_eval_with_qa_json(
    input_json: str,
    output_csv: str = "rag_eval_results.csv",
    collection: str = "first_bfp_collection_test",
    hybrid_top_k: int = 20,
    final_top_k: int = 5,
) -> None:
    """Run the rag_retrieval_eval_v1 pipeline on a qa_pairs JSON file.

    Input is a single JSON containing ``qa_pairs`` (compatible with
    batch_rag_eval_from_qa.py). For every entity name the pipeline performs
    multi-stage retrieval recall and asks the LLM to judge hit rate; the
    aggregated results are written to a CSV for offline analysis.
    """
    rag_pipeline, _unused = build_rag_eval_pipeline_for_qa_json(
        input_json=input_json,
        output_csv=output_csv,
        collection=collection,
        hybrid_top_k=hybrid_top_k,
        final_top_k=final_top_k,
    )
    await rag_pipeline.run()
class InMemoryListSaver(ResultSaver):
    """Keep pipeline results in an in-memory list instead of writing files."""

    def __init__(self) -> None:
        # Each entry is the item merged with its result (result keys win).
        self.items: List[Dict[str, Any]] = []

    async def save(self, item: Dict[str, Any], result: Dict[str, Any]) -> None:
        merged = dict(item)
        merged.update(result)
        self.items.append(merged)
class InMemoryDataLoader(DataLoader):
    """Serve pipeline items directly from an in-memory list."""

    def __init__(self, items: List[Dict[str, Any]]) -> None:
        self._items = items

    async def load_items(self) -> AsyncIterator[Dict[str, Any]]:
        # Yield the stored records one by one, in list order.
        for record in self._items:
            yield record

    def get_total(self) -> Optional[int]:
        return len(self._items)
class InMemoryEntityExtractSaver(ResultSaver):
    """Mirror entity_extract_v1's JSON output shape, but keep results in memory."""

    # Fields retained from the merged item+result, in output order.
    _FIELDS = ("file_name", "chunk_id", "section_label", "text", "entity_extract_result")

    def __init__(self) -> None:
        self.items: List[Dict[str, Any]] = []

    async def save(self, item: Dict[str, Any], result: Dict[str, Any]) -> None:
        merged = {**item, **result}
        # Project the merged record down to the known output fields;
        # missing keys become None, matching the file-based saver's shape.
        self.items.append({key: merged.get(key) for key in self._FIELDS})
class InMemoryEvalFilteredSaver(ResultSaver):
    """Apply entity_extract_eval_v1's filtering logic, but keep results in memory."""

    def __init__(self) -> None:
        self.items: List[Dict[str, Any]] = []

    async def save(self, item: Dict[str, Any], result: Dict[str, Any]) -> None:
        merged = {**item, **result}
        container = merged.get("entity_extract_result") or {}
        entity_list = container.get("entities") if isinstance(container, dict) else None
        # Drop records whose evaluation left no usable (non-empty list) entities.
        if not isinstance(entity_list, list) or not entity_list:
            return
        record = {
            "file_name": merged.get("file_name"),
            "chunk_id": merged.get("chunk_id"),
            "section_label": merged.get("section_label"),
            "text": merged.get("text"),
            "entity_extract_result": container,
        }
        self.items.append(record)
async def run_full_extract_eval_and_rag_eval_in_memory(
    input_json: str,
    output_csv: str = "rag_eval_results.csv",
    collection: str = "first_bfp_collection_test",
    hybrid_top_k: int = 20,
    final_top_k: int = 5,
) -> None:
    """End-to-end pipeline with all intermediate data held in memory.

    1) entity_extract_v1: extract entity concepts + background from the input
       JSON chunk files.
    2) entity_extract_eval_v1: professionalism evaluation and filtering.
    3) rag_retrieval_eval_v1: retrieval recall plus LLM hit-rate evaluation on
       the filtered entities; aggregated results are written to ``output_csv``.

    Args:
        input_json: Path to one JSON file, or a directory scanned recursively
            for ``*.json`` files.
        output_csv: Destination CSV for the RAG evaluation summary.
        collection: Retrieval collection name.
        hybrid_top_k: Candidate count for the hybrid recall stage.
        final_top_k: Candidate count kept after re-ranking.

    Raises:
        FileNotFoundError: If ``input_json`` does not exist.
    """

    def _iter_input_json_files(path_str: str) -> List[Path]:
        """Resolve *path_str* to a deterministic, sorted list of JSON files."""
        p = Path(path_str)
        if not p.exists():
            raise FileNotFoundError(f"输入路径不存在: {p}")
        if p.is_file():
            return [p]
        if p.is_dir():
            # Directory: recurse for *.json; fixed ordering keeps runs reproducible.
            return sorted(p.rglob("*.json"), key=str)
        return []

    input_files = _iter_input_json_files(input_json)
    if not input_files:
        print(f"[INFO] 未找到可处理的 JSON 文件: {input_json}")
        return

    # Build configs, LLM clients, prompt builders and parsers once,
    # outside the per-file loop — they are reusable across files.
    extract_service = Path(__file__).parent / "llm_pipeline" / "entity_extract_v1" / "service.yaml"
    extract_cfg = YamlConfigProvider(service_path=extract_service)
    extract_client = build_extract_llm_client(extract_cfg)
    extract_prompt = EntityExtractV1PromptBuilder(cfg_provider=extract_cfg)
    extract_parser = EntityExtractV1JsonResponseParser(output_field="entity_extract_result")

    eval_service = Path(__file__).parent / "llm_pipeline" / "entity_extract_eval_v1" / "service.yaml"
    eval_cfg = YamlConfigProvider(service_path=eval_service)
    eval_client = build_eval_llm_client(eval_cfg)
    eval_prompt = EntityEvalV1PromptBuilder(cfg_provider=eval_cfg)
    eval_parser = EntityEvalV1JsonResponseParser(output_field="entity_extract_result")

    all_filtered_items: List[Dict[str, Any]] = []
    for fp in input_files:
        # === Stage 1: entity_extract_v1 (in-memory) ===
        extract_saver = InMemoryEntityExtractSaver()
        extract_pipeline = LLMPipeline(
            llm_client=extract_client,
            config_provider=extract_cfg,
            data_loader=EntityExtractV1JsonChunksLoader(str(fp)),
            prompt_builder=extract_prompt,
            response_parser=extract_parser,
            result_saver=extract_saver,
        )
        await extract_pipeline.run()
        extracted_items = extract_saver.items
        if not extracted_items:
            print(f"[INFO] 跳过(抽取阶段无输出): {fp}")
            continue

        # === Stage 2: entity_extract_eval_v1 (in-memory) ===
        eval_saver = InMemoryEvalFilteredSaver()
        eval_pipeline = LLMPipeline(
            llm_client=eval_client,
            config_provider=eval_cfg,
            data_loader=InMemoryDataLoader(extracted_items),
            prompt_builder=eval_prompt,
            response_parser=eval_parser,
            result_saver=eval_saver,
        )
        await eval_pipeline.run()
        filtered_items = eval_saver.items
        if not filtered_items:
            print(f"[INFO] 跳过(评估过滤后无有效实体): {fp}")
            continue
        all_filtered_items.extend(filtered_items)

    # === Stage 3: rag_retrieval_eval_v1 (entities -> retrieval -> hit eval) ===
    if not all_filtered_items:
        print("[INFO] 全部输入处理完成,但未产生任何可用于 RAG 评估的实体。")
        return
    rag_pipeline, _ = build_rag_eval_pipeline_for_entities_items(
        items=all_filtered_items,
        output_csv=output_csv,
        collection=collection,
        hybrid_top_k=hybrid_top_k,
        final_top_k=final_top_k,
    )
    await rag_pipeline.run()
- if __name__ == "__main__":
- # 默认执行“抽取 → 专业评估过滤 → 检索召回 → 命中率评估(CSV)”全流程(内存承接,不依赖中间文件)
- asyncio.run(
- run_full_extract_eval_and_rag_eval_in_memory(
- input_json="./data",
- output_csv="rag_eval_results.csv",
- collection="first_bfp_collection_test",
- hybrid_top_k=20,
- final_top_k=5,
- )
- )
|