# main.py
  1. from __future__ import annotations
  2. import asyncio
  3. from pathlib import Path
  4. from typing import Any, AsyncIterator, Dict, List, Optional
  5. from llm_pipeline.core.config import YamlConfigProvider
  6. from llm_pipeline.core.pipeline import LLMPipeline
  7. from llm_pipeline.interfaces import DataLoader, ResultSaver
  8. from llm_pipeline.entity_extract_v1.dataloaders import EntityExtractV1JsonChunksLoader
  9. from llm_pipeline.entity_extract_v1.prompting import (
  10. EntityExtractV1JsonResponseParser,
  11. EntityExtractV1PromptBuilder,
  12. )
  13. from llm_pipeline.entity_extract_v1.factory import build_llm_client as build_extract_llm_client
  14. from llm_pipeline.entity_extract_eval_v1.prompting import (
  15. EntityEvalV1JsonResponseParser,
  16. EntityEvalV1PromptBuilder,
  17. )
  18. from llm_pipeline.entity_extract_eval_v1.factory import build_llm_client as build_eval_llm_client
  19. from llm_pipeline.rag_retrieval_eval_v1.factory import (
  20. build_rag_eval_pipeline_for_entities_items,
  21. build_rag_eval_pipeline_for_qa_json,
  22. )
  23. from llm_pipeline.entity_extract_v1.factory import build_pipeline_for_csv, build_pipeline_for_json
  24. from llm_pipeline.entity_extract_eval_v1.factory import build_eval_pipeline_for_json
  25. async def run_entity_extract_v1_with_json(
  26. input_json: str,
  27. output_json: str = "output_from_json.json",
  28. ) -> None:
  29. """使用 entity_extract_v1 版本:JSON → JSON 处理。"""
  30. pipeline, _ = build_pipeline_for_json(input_json=input_json, output_json=output_json)
  31. await pipeline.run()
  32. async def run_entity_extract_v1_with_csv(
  33. input_csv: str = "input.csv",
  34. output_csv: str = "output.csv",
  35. ) -> None:
  36. """使用 entity_extract_v1 版本:CSV → CSV 处理。"""
  37. pipeline, _ = build_pipeline_for_csv(input_csv=input_csv, output_csv=output_csv)
  38. await pipeline.run()
  39. async def run_entity_eval_v1_with_json(
  40. input_json: str,
  41. output_json: str = "output_from_json_eval.json",
  42. ) -> None:
  43. """使用 entity_extract_eval_v1 版本:对抽取结果做专业性评估与过滤。"""
  44. pipeline, _ = build_eval_pipeline_for_json(
  45. input_json=input_json,
  46. output_json=output_json,
  47. )
  48. await pipeline.run()
  49. async def run_full_entity_extract_and_eval() -> None:
  50. """一键运行:先抽取实体,再对结果进行评估过滤。"""
  51. raw_input = (
  52. "44_四川公路桥梁建设集团有限公司镇巴(川陕界)至广安高速公路通广段C合同段C4项目经理部_完整结果_20251212_155323.json"
  53. )
  54. first_output = "output_from_json.json"
  55. final_output = "output_from_json_eval.json"
  56. # 第一步:实体抽取
  57. await run_entity_extract_v1_with_json(input_json=raw_input, output_json=first_output)
  58. # 第二步:专业性评估与过滤
  59. await run_entity_eval_v1_with_json(input_json=first_output, output_json=final_output)
  60. async def run_rag_retrieval_eval_with_qa_json(
  61. input_json: str,
  62. output_csv: str = "rag_eval_results.csv",
  63. collection: str = "first_bfp_collection_test",
  64. hybrid_top_k: int = 20,
  65. final_top_k: int = 5,
  66. ) -> None:
  67. """
  68. 使用 rag_retrieval_eval_v1 版本:
  69. - 输入:单个包含 qa_pairs 的 JSON(与 batch_rag_eval_from_qa.py 兼容);
  70. - 过程:对每个实体 name 进行检索召回(multi_stage_recall),并调用 LLM 做命中率评估;
  71. - 输出:汇总结果写入 CSV,便于统计分析。
  72. """
  73. pipeline, _ = build_rag_eval_pipeline_for_qa_json(
  74. input_json=input_json,
  75. output_csv=output_csv,
  76. collection=collection,
  77. hybrid_top_k=hybrid_top_k,
  78. final_top_k=final_top_k,
  79. )
  80. await pipeline.run()
  81. class InMemoryListSaver(ResultSaver):
  82. """将流水线结果保存在内存列表中(不落地文件)。"""
  83. def __init__(self) -> None:
  84. self.items: List[Dict[str, Any]] = []
  85. async def save(self, item: Dict[str, Any], result: Dict[str, Any]) -> None:
  86. self.items.append({**item, **result})
  87. class InMemoryDataLoader(DataLoader):
  88. """从内存列表提供数据的 DataLoader。"""
  89. def __init__(self, items: List[Dict[str, Any]]) -> None:
  90. self._items = items
  91. async def load_items(self) -> AsyncIterator[Dict[str, Any]]:
  92. for it in self._items:
  93. yield it
  94. def get_total(self) -> Optional[int]:
  95. return len(self._items)
  96. class InMemoryEntityExtractSaver(ResultSaver):
  97. """对齐 entity_extract_v1 的 JSON 输出结构,但保存在内存。"""
  98. def __init__(self) -> None:
  99. self.items: List[Dict[str, Any]] = []
  100. async def save(self, item: Dict[str, Any], result: Dict[str, Any]) -> None:
  101. merged = {**item, **result}
  102. simplified = {
  103. "file_name": merged.get("file_name"),
  104. "chunk_id": merged.get("chunk_id"),
  105. "section_label": merged.get("section_label"),
  106. "text": merged.get("text"),
  107. "entity_extract_result": merged.get("entity_extract_result"),
  108. }
  109. self.items.append(simplified)
  110. class InMemoryEvalFilteredSaver(ResultSaver):
  111. """对齐 entity_extract_eval_v1 的过滤逻辑,但保存在内存。"""
  112. def __init__(self) -> None:
  113. self.items: List[Dict[str, Any]] = []
  114. async def save(self, item: Dict[str, Any], result: Dict[str, Any]) -> None:
  115. merged = {**item, **result}
  116. entities_obj = merged.get("entity_extract_result") or {}
  117. entities = entities_obj.get("entities") if isinstance(entities_obj, dict) else None
  118. if not entities or not isinstance(entities, list):
  119. return
  120. self.items.append(
  121. {
  122. "file_name": merged.get("file_name"),
  123. "chunk_id": merged.get("chunk_id"),
  124. "section_label": merged.get("section_label"),
  125. "text": merged.get("text"),
  126. "entity_extract_result": entities_obj,
  127. }
  128. )
  129. async def run_full_extract_eval_and_rag_eval_in_memory(
  130. input_json: str,
  131. output_csv: str = "rag_eval_results.csv",
  132. collection: str = "first_bfp_collection_test",
  133. hybrid_top_k: int = 20,
  134. final_top_k: int = 5,
  135. ) -> None:
  136. """
  137. 全流程(不依赖中间文件):
  138. 1) entity_extract_v1:从 input_json(chunks) 抽取实体概念+背景
  139. 2) entity_extract_eval_v1:专业性评估与过滤
  140. 3) rag_retrieval_eval_v1:用过滤后的实体(name+背景/证据拼 query)做检索召回 + 命中率评估,输出 CSV
  141. """
  142. def _iter_input_json_files(path_str: str) -> List[Path]:
  143. p = Path(path_str)
  144. if not p.exists():
  145. raise FileNotFoundError(f"输入路径不存在: {p}")
  146. if p.is_file():
  147. return [p]
  148. if p.is_dir():
  149. # 目录:递归找 json,固定排序保证可复现
  150. return sorted(p.rglob("*.json"), key=lambda x: str(x))
  151. return []
  152. input_files = _iter_input_json_files(input_json)
  153. if not input_files:
  154. print(f"[INFO] 未找到可处理的 JSON 文件: {input_json}")
  155. return
  156. all_filtered_items: List[Dict[str, Any]] = []
  157. # === Stage 1 + 2: per-file extract + eval filter (in-memory) ===
  158. extract_service = Path(__file__).parent / "llm_pipeline" / "entity_extract_v1" / "service.yaml"
  159. extract_cfg = YamlConfigProvider(service_path=extract_service)
  160. extract_client = build_extract_llm_client(extract_cfg)
  161. extract_prompt = EntityExtractV1PromptBuilder(cfg_provider=extract_cfg)
  162. extract_parser = EntityExtractV1JsonResponseParser(output_field="entity_extract_result")
  163. eval_service = Path(__file__).parent / "llm_pipeline" / "entity_extract_eval_v1" / "service.yaml"
  164. eval_cfg = YamlConfigProvider(service_path=eval_service)
  165. eval_client = build_eval_llm_client(eval_cfg)
  166. eval_prompt = EntityEvalV1PromptBuilder(cfg_provider=eval_cfg)
  167. eval_parser = EntityEvalV1JsonResponseParser(output_field="entity_extract_result")
  168. for fp in input_files:
  169. # === Stage 1: entity_extract_v1 (in-memory) ===
  170. extract_loader = EntityExtractV1JsonChunksLoader(str(fp))
  171. extract_saver = InMemoryEntityExtractSaver()
  172. extract_pipeline = LLMPipeline(
  173. llm_client=extract_client,
  174. config_provider=extract_cfg,
  175. data_loader=extract_loader,
  176. prompt_builder=extract_prompt,
  177. response_parser=extract_parser,
  178. result_saver=extract_saver,
  179. )
  180. await extract_pipeline.run()
  181. extracted_items = extract_saver.items
  182. if not extracted_items:
  183. print(f"[INFO] 跳过(抽取阶段无输出): {fp}")
  184. continue
  185. # === Stage 2: entity_extract_eval_v1 (in-memory) ===
  186. eval_loader = InMemoryDataLoader(extracted_items)
  187. eval_saver = InMemoryEvalFilteredSaver()
  188. eval_pipeline = LLMPipeline(
  189. llm_client=eval_client,
  190. config_provider=eval_cfg,
  191. data_loader=eval_loader,
  192. prompt_builder=eval_prompt,
  193. response_parser=eval_parser,
  194. result_saver=eval_saver,
  195. )
  196. await eval_pipeline.run()
  197. filtered_items = eval_saver.items
  198. if not filtered_items:
  199. print(f"[INFO] 跳过(评估过滤后无有效实体): {fp}")
  200. continue
  201. all_filtered_items.extend(filtered_items)
  202. # === Stage 3: rag_retrieval_eval_v1 (entities -> retrieval -> hit eval) ===
  203. if not all_filtered_items:
  204. print("[INFO] 全部输入处理完成,但未产生任何可用于 RAG 评估的实体。")
  205. return
  206. rag_pipeline, _ = build_rag_eval_pipeline_for_entities_items(
  207. items=all_filtered_items,
  208. # items=extracted_items,
  209. output_csv=output_csv,
  210. collection=collection,
  211. hybrid_top_k=hybrid_top_k,
  212. final_top_k=final_top_k,
  213. )
  214. await rag_pipeline.run()
  215. if __name__ == "__main__":
  216. # 默认执行“抽取 → 专业评估过滤 → 检索召回 → 命中率评估(CSV)”全流程(内存承接,不依赖中间文件)
  217. asyncio.run(
  218. run_full_extract_eval_and_rag_eval_in_memory(
  219. input_json="./data",
  220. output_csv="rag_eval_results.csv",
  221. collection="first_bfp_collection_test",
  222. hybrid_top_k=20,
  223. final_top_k=5,
  224. )
  225. )