|
|
@@ -0,0 +1,462 @@
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+import asyncio
|
|
|
+import json
|
|
|
+from pathlib import Path
|
|
|
+from typing import Any, AsyncIterator, Dict, List, Optional, Sequence
|
|
|
+
|
|
|
+from llm_pipeline.interfaces import DataLoader, ResultSaver
|
|
|
+from foundation.ai.rag.retrieval.retrieval import retrieval_manager
|
|
|
+
|
|
|
+
|
|
|
+def _build_query_text(qa: Dict[str, Any]) -> str:
|
|
|
+ """
|
|
|
+ 用“实体概念 + 背景/证据”等信息构造检索 query 文本。
|
|
|
+
|
|
|
+ 目的:避免只用 name 造成召回歧义,提高命中率。
|
|
|
+ """
|
|
|
+ if not isinstance(qa, dict):
|
|
|
+ return ""
|
|
|
+ name = str(qa.get("name") or "").strip()
|
|
|
+ qa_type = str(qa.get("type") or "").strip()
|
|
|
+ background = str(qa.get("background") or "").strip()
|
|
|
+ evidence = str(qa.get("evidence") or "").strip()
|
|
|
+
|
|
|
+ # 控制长度,避免 query 过长
|
|
|
+ def _clip(s: str, n: int) -> str:
|
|
|
+ s = s.strip()
|
|
|
+ return s if len(s) <= n else s[:n]
|
|
|
+
|
|
|
+ parts: List[str] = []
|
|
|
+ if name:
|
|
|
+ parts.append(name)
|
|
|
+ if qa_type:
|
|
|
+ parts.append(_clip(qa_type, 50))
|
|
|
+ if background:
|
|
|
+ parts.append(_clip(background, 200))
|
|
|
+ if evidence:
|
|
|
+ parts.append(_clip(evidence, 200))
|
|
|
+
|
|
|
+ # 用分隔符保持可读性
|
|
|
+ return ",".join([p for p in parts if p])
|
|
|
+ # return f"{name}"
|
|
|
+
|
|
|
+
|
|
|
class RagEvalFromQaJsonLoader(DataLoader):
    """Load entities from a qa_pairs JSON file and run retrieval recall at load time.

    Expected input JSON shape (compatible with batch_rag_eval_from_qa.py):

        {
            "qa_pairs": [
                {
                    "q": "original text snippet ...",
                    "a": [
                        {
                            "name": "entity name",
                            "type": "entity type",
                            "background": "entity background",
                            "evidence": "evidence snippet"
                        },
                        ...
                    ],
                    "chunk_id": 1,
                    "section_label": "some section"
                },
                ...
            ]
        }

    One item is produced per entity, with fields:
    - source_file / chunk_id / section_label
    - original_text: the raw "q" text
    - qa: the original entity dict
    - query_text: the query actually sent to retrieval
    - candidate_texts: top-k texts produced by multi_stage_recall
    - retrieval_raw_results: the raw retrieval hits
    """

    def __init__(
        self,
        json_path: str | Path,
        collection: str,
        hybrid_top_k: int = 50,
        final_top_k: int = 5,
        retrieval_concurrency: int = 10,
    ) -> None:
        self._json_path = Path(json_path)
        self._collection = collection
        self._hybrid_top_k = int(hybrid_top_k)
        self._final_top_k = int(final_top_k)
        # At least one in-flight retrieval, even for nonsensical inputs.
        self._retrieval_concurrency = max(1, int(retrieval_concurrency))

        # Lazily built item cache; None means _build_items has not run yet.
        self._items: List[Dict[str, Any]] | None = None

    async def _retrieve_one(self, query_text: str) -> List[Dict[str, Any]]:
        """Run a single retrieval; the synchronous multi_stage_recall is wrapped in a worker thread."""
        return await asyncio.to_thread(
            retrieval_manager.multi_stage_recall,
            self._collection,
            query_text,
            self._hybrid_top_k,
            self._final_top_k,
        )

    def _collect_metas(self) -> List[Dict[str, Any]]:
        """Parse the JSON file into per-entity retrieval metadata (no retrieval I/O here).

        Returns an empty list when the file is missing, malformed in shape
        (qa_pairs not a list), or contains no usable entities.
        """
        if not self._json_path.exists():
            return []

        with self._json_path.open("r", encoding="utf-8") as f:
            data = json.load(f)

        qa_pairs = data.get("qa_pairs", [])
        if not isinstance(qa_pairs, list):
            return []

        metas: List[Dict[str, Any]] = []
        for pair in qa_pairs:
            if not isinstance(pair, dict):
                continue
            q_text: str = pair.get("q", "") or ""
            a_list = pair.get("a", []) or []
            if not isinstance(a_list, list):
                continue
            for ent in a_list:
                if not isinstance(ent, dict):
                    continue
                ent_name = ent.get("name")
                if not ent_name:
                    # Entities without a name cannot be queried or evaluated.
                    continue
                metas.append(
                    {
                        "source_file": self._json_path.name,
                        "chunk_id": pair.get("chunk_id"),
                        "section_label": pair.get("section_label"),
                        "original_text": q_text,
                        "qa": ent,
                        # Fall back to the bare name if no richer query can be built.
                        "query_text": _build_query_text(ent) or str(ent_name),
                        "ent_name": ent_name,
                    }
                )
        return metas

    async def _build_items(self) -> None:
        """Build the item cache: parse metas, then run bounded-concurrency retrieval.

        Idempotent — subsequent calls return immediately once the cache exists.
        """
        if self._items is not None:
            return

        metas = self._collect_metas()
        if not metas:
            self._items = []
            return

        # Bound in-flight retrievals with a semaphore; asyncio.gather preserves
        # input order, so no index bookkeeping or re-sorting is needed.
        sem = asyncio.Semaphore(self._retrieval_concurrency)

        async def _bounded_retrieve(meta: Dict[str, Any]) -> List[Dict[str, Any]]:
            async with sem:
                try:
                    return await self._retrieve_one(meta["query_text"])
                except Exception as exc:  # noqa: BLE001
                    # Best-effort: a failed retrieval yields empty candidates
                    # instead of aborting the whole batch.
                    print(
                        f"[RagEvalFromQaJsonLoader] 检索异常 "
                        f"(file={meta.get('source_file')}, chunk_id={meta.get('chunk_id')}, ent={meta.get('ent_name')}): {exc}"
                    )
                    return []

        all_results = await asyncio.gather(*(_bounded_retrieve(m) for m in metas))

        built_items: List[Dict[str, Any]] = []
        for meta, raw_results in zip(metas, all_results):
            candidate_texts: List[str] = [
                r.get("text_content", "") for r in (raw_results or [])[: self._final_top_k]
            ]
            built_items.append(
                {
                    "source_file": meta.get("source_file"),
                    "chunk_id": meta.get("chunk_id"),
                    "section_label": meta.get("section_label"),
                    "original_text": meta.get("original_text"),
                    "qa": meta.get("qa"),
                    "query_text": meta.get("query_text"),
                    "candidate_texts": candidate_texts,
                    "retrieval_raw_results": raw_results,
                }
            )

        self._items = built_items

    async def load_items(self) -> AsyncIterator[Dict[str, Any]]:
        """Yield evaluation items, triggering retrieval on first use."""
        await self._build_items()
        assert self._items is not None
        for item in self._items:
            yield item

    def get_total(self) -> Optional[int]:
        # Only report a total once the cache exists, to avoid triggering
        # retrieval from a synchronous context.
        if self._items is None:
            return None
        return len(self._items)
|
|
|
+
|
|
|
+
|
|
|
class RagEvalFromEntitiesItemsLoader(DataLoader):
    """Run retrieval over entity results already produced by an upstream component.

    Expected upstream structure (similar to the output of `entity_extract_eval_v1`):

        [
            {
                "file_name": ...,
                "section_label": ...,
                "text": "... source snippet ...",
                "entity_extract_result": {
                    "entities": [
                        { "name": ..., "type": ..., "background": ..., "evidence": ... },
                        ...
                    ]
                }
            },
            ...
        ]

    Unlike the JSON-file loader, this loader takes the list (or an equivalent
    structure) directly in memory; each entity is passed through
    multi_stage_recall and turned into an evaluation item.
    """

    def __init__(
        self,
        items: Sequence[Dict[str, Any]],
        collection: str,
        hybrid_top_k: int = 50,
        final_top_k: int = 5,
        retrieval_concurrency: int = 10,
    ) -> None:
        self._source_items: List[Dict[str, Any]] = list(items)
        self._collection = collection
        self._hybrid_top_k = int(hybrid_top_k)
        self._final_top_k = int(final_top_k)
        # At least one in-flight retrieval, even for nonsensical inputs.
        self._retrieval_concurrency = max(1, int(retrieval_concurrency))

        # Lazily built item cache; None means _build_items has not run yet.
        self._built_items: List[Dict[str, Any]] | None = None

    async def _retrieve_one(self, query_text: str) -> List[Dict[str, Any]]:
        """Run a single retrieval; the synchronous multi_stage_recall is wrapped in a worker thread."""
        return await asyncio.to_thread(
            retrieval_manager.multi_stage_recall,
            self._collection,
            query_text,
            self._hybrid_top_k,
            self._final_top_k,
        )

    def _collect_metas(self) -> List[Dict[str, Any]]:
        """Flatten upstream items into per-entity retrieval metadata (no retrieval I/O here)."""
        metas: List[Dict[str, Any]] = []
        for item in self._source_items:
            if not isinstance(item, dict):
                continue

            original_text = item.get("text", "") or ""
            ent_obj = item.get("entity_extract_result") or {}
            entities = ent_obj.get("entities") if isinstance(ent_obj, dict) else None
            if not entities or not isinstance(entities, list):
                continue

            for ent in entities:
                if not isinstance(ent, dict):
                    continue
                ent_name = ent.get("name")
                if not ent_name:
                    # Entities without a name cannot be queried or evaluated.
                    continue
                metas.append(
                    {
                        "source_file": item.get("file_name"),
                        "chunk_id": item.get("chunk_id"),
                        "section_label": item.get("section_label"),
                        "original_text": original_text,
                        "qa": ent,
                        # Fall back to the bare name if no richer query can be built.
                        "query_text": _build_query_text(ent) or str(ent_name),
                        "ent_name": ent_name,
                    }
                )
        return metas

    async def _build_items(self) -> None:
        """Build the item cache with semaphore-bounded concurrent retrieval.

        Idempotent — subsequent calls return immediately once the cache exists.
        """
        if self._built_items is not None:
            return

        metas = self._collect_metas()
        if not metas:
            self._built_items = []
            return

        # Bound in-flight retrievals with a semaphore; asyncio.gather preserves
        # input order, so no index bookkeeping or re-sorting is needed.
        sem = asyncio.Semaphore(self._retrieval_concurrency)

        async def _bounded_retrieve(meta: Dict[str, Any]) -> List[Dict[str, Any]]:
            async with sem:
                try:
                    return await self._retrieve_one(meta["query_text"])
                except Exception as exc:  # noqa: BLE001
                    # Best-effort: a failed retrieval yields empty candidates
                    # instead of aborting the whole batch.
                    print(
                        f"[RagEvalFromEntitiesItemsLoader] 检索异常 "
                        f"(file={meta.get('source_file')}, ent={meta.get('ent_name')}): {exc}"
                    )
                    return []

        all_results = await asyncio.gather(*(_bounded_retrieve(m) for m in metas))

        built_items: List[Dict[str, Any]] = []
        for meta, raw_results in zip(metas, all_results):
            candidate_texts: List[str] = [
                r.get("text_content", "") for r in (raw_results or [])[: self._final_top_k]
            ]
            built_items.append(
                {
                    "source_file": meta.get("source_file"),
                    "chunk_id": meta.get("chunk_id"),
                    "section_label": meta.get("section_label"),
                    "original_text": meta.get("original_text"),
                    "qa": meta.get("qa"),
                    "query_text": meta.get("query_text"),
                    "candidate_texts": candidate_texts,
                    "retrieval_raw_results": raw_results,
                }
            )

        self._built_items = built_items

    async def load_items(self) -> AsyncIterator[Dict[str, Any]]:
        """Yield evaluation items, triggering retrieval on first use."""
        await self._build_items()
        assert self._built_items is not None
        for item in self._built_items:
            yield item

    def get_total(self) -> Optional[int]:
        # Only report a total once the cache exists, to avoid triggering
        # retrieval from a synchronous context.
        if self._built_items is None:
            return None
        return len(self._built_items)
|
|
|
+
|
|
|
+
|
|
|
class RagEvalCsvResultSaver(ResultSaver):
    """Append retrieval + evaluation results to a CSV file.

    Columns largely match write_csv in batch_rag_eval_from_qa.py:
    - source_file / chunk_id / section_label
    - original_text / entity_name / entity_type / entity_background / entity_evidence
    - query_text / candidate_texts (serialized as a JSON string)
    - eval_label / eval_hit / eval_best_answer_index / eval_reason / eval_raw_output
    """

    # Stable column order, defined once instead of being rebuilt on every save().
    _FIELDNAMES: List[str] = [
        "source_file",
        "chunk_id",
        "section_label",
        "original_text",
        "entity_name",
        "entity_type",
        "entity_background",
        "entity_evidence",
        "query_text",
        "candidate_texts",
        "eval_label",
        "eval_hit",
        "eval_best_answer_index",
        "eval_reason",
        "eval_raw_output",
    ]

    def __init__(self, csv_path: str | Path) -> None:
        import csv  # deferred import: csv is only needed by this saver

        self._csv_path = Path(csv_path)
        # True once the header has been written by this instance.
        self._initialized = False
        self._csv_module = csv

    def _ensure_header(self) -> None:
        """Write the CSV header, truncating the file on the first save of this run.

        Re-creates the file (with header) if it disappeared after initialization.
        """
        if self._initialized and self._csv_path.exists():
            return
        # utf-8-sig so spreadsheet tools (e.g. Excel) detect the encoding.
        with self._csv_path.open("w", newline="", encoding="utf-8-sig") as f:
            writer = self._csv_module.DictWriter(f, fieldnames=self._FIELDNAMES)
            writer.writeheader()
        self._initialized = True

    async def save(self, item: Dict[str, Any], result: Dict[str, Any]) -> None:
        """Merge one item with its evaluation result and append a CSV row.

        Keys from `result` take precedence over `item` on collision.
        """
        merged = {**item, **result}

        qa = merged.get("qa") or {}
        row: Dict[str, Any] = {
            "source_file": merged.get("source_file"),
            "chunk_id": merged.get("chunk_id"),
            "section_label": merged.get("section_label"),
            "original_text": merged.get("original_text", ""),
            "entity_name": qa.get("name"),
            "entity_type": qa.get("type"),
            "entity_background": qa.get("background"),
            "entity_evidence": qa.get("evidence"),
            "query_text": merged.get("query_text", ""),
            "candidate_texts": merged.get("candidate_texts", []),
            "eval_label": merged.get("eval_label"),
            "eval_hit": merged.get("eval_hit"),
            "eval_best_answer_index": merged.get("eval_best_answer_index"),
            "eval_reason": merged.get("eval_reason"),
            "eval_raw_output": merged.get("eval_raw_output"),
        }

        # Serialize candidate_texts to a JSON string to avoid embedded
        # newlines and quoting problems in the CSV cell.
        if isinstance(row.get("candidate_texts"), (list, dict)):
            row["candidate_texts"] = json.dumps(row["candidate_texts"], ensure_ascii=False)

        self._ensure_header()

        with self._csv_path.open("a", newline="", encoding="utf-8-sig") as f:
            writer = self._csv_module.DictWriter(f, fieldnames=self._FIELDNAMES)
            writer.writerow(row)
|
|
|
+
|
|
|
+
|
|
|
+
|