Prechádzať zdrojové kódy

v0.0.4:新增父段补充

ZengChao 1 mesiac pred
rodič
commit
65fda42fec

+ 96 - 9
core/construction_review/component/ai_review_engine.py

@@ -60,7 +60,7 @@ from foundation.ai.rag.retrieval.entities_enhance import entity_enhance
 from core.construction_review.component.reviewers.base_reviewer import BaseReviewer
 from core.construction_review.component.reviewers.outline_reviewer import OutlineReviewer
 from core.construction_review.component.reviewers.utils.text_split import split_text
-
+from core.construction_review.component.infrastructure.milvus import MilvusManager, MilvusConfig
 
 
 
@@ -130,6 +130,8 @@ class AIReviewEngine(BaseReviewer):
         self.milvus_collection = config_handler.get('milvus', 'MILVUS_COLLECTION', 'default')
         self.outline_reviewer = OutlineReviewer()
 
+        self.milvus = MilvusManager(MilvusConfig())
+
     def _process_review_result(self, result):
         """
         处理审查结果,统一转换为字典格式
@@ -348,11 +350,6 @@ class AIReviewEngine(BaseReviewer):
         logger.info("构建查询对")
         query_pairs = query_rewrite_manager.query_extract(query_content)
         bfp_result_lists =entity_enhance.entities_enhance_retrieval(query_pairs)
-        # 使用bfp_result_list 获取 parent_id ,通过parent_id 获取父文档内容 utils_test\Milvus_Test\test_查询接口.py
-        # llm 异步相关度分析  判断父文档是否与query_content 审查条文相关
-        # 如果相关,则追加到 bfp_result,如果不相关则,则跳过
-        # 如果len(bfp_result) > 0 则进行RAG增强,否则 则返回空
-
         logger.info(f"bfp_result_lists{bfp_result_lists}")
         # 检查是否有检索结果
         if not bfp_result_lists:
@@ -364,6 +361,98 @@ class AIReviewEngine(BaseReviewer):
                 'text_content': '',
                 'metadata': {}
             }
+        #todo
+        #异步调用查询。查出所有的
+        
+        #todo
+        # 使用bfp_result_list 获取 parent_id ,通过parent_id 获取父文档内容 utils_test\Milvus_Test\test_查询接口.py
+        # llm 异步相关度分析  判断父文档是否与query_content 审查条文相关
+        # 如果相关,则追加到 bfp_result,如果不相关则,则跳过
+        import asyncio
+        import concurrent.futures
+        from typing import Any, Dict, List, Optional, Sequence
+        from core.construction_review.component.infrastructure.relevance import is_relevant_async
+        PARENT_COLLECTION = "rag_parent_hybrid"  # TODO: 改成你的父段 collection
+        PARENT_TEXT_FIELD = "text"                   # TODO: 改成你的父段字段名
+        PARENT_OUTPUT_FIELDS: Sequence[str] = ["parent_id", PARENT_TEXT_FIELD]
+
+        def run_async(coro):
+            """
+            Run ``coro`` to completion from synchronous code and return its result.
+
+            If no event loop is running in the current thread, the coroutine is
+            driven directly with ``asyncio.run``. If a loop is already running,
+            ``asyncio.run`` cannot be called here, so the coroutine is executed on
+            a fresh loop in a worker thread and this thread blocks until it is done.
+
+            Exceptions raised by the coroutine propagate to the caller unchanged.
+            """
+            try:
+                asyncio.get_running_loop()
+            except RuntimeError:
+                # Only the loop probe is guarded: a RuntimeError raised by the
+                # coroutine itself must not be mistaken for "no running loop".
+                return asyncio.run(coro)
+
+            # A loop is already running in this thread; use a separate thread/loop.
+            with concurrent.futures.ThreadPoolExecutor() as executor:
+                return executor.submit(asyncio.run, coro).result()
+
+        async def _async_condition_query_one(pid: str) -> Optional[Dict[str, Any]]:
+            """
+            Look up one parent segment by ``parent_id`` without blocking the loop.
+
+            ``condition_query`` is synchronous, so it is run through the event
+            loop's default executor. Returns the first matching row reduced to
+            the fields in ``PARENT_OUTPUT_FIELDS``, or ``None`` when nothing matches.
+            """
+            def _fetch_parent_row():
+                matches = self.milvus.condition_query(
+                    collection_name=PARENT_COLLECTION,
+                    filter=f"parent_id == '{pid}'",
+                    output_fields=PARENT_OUTPUT_FIELDS,
+                    limit=1,
+                )
+                if not matches:
+                    return None
+                first = matches[0] or {}
+                # Whitelist projection: drop extra fields such as pk/id
+                projected = {}
+                for name in PARENT_OUTPUT_FIELDS:
+                    if name in first:
+                        projected[name] = first.get(name)
+                return projected
+
+            running_loop = asyncio.get_running_loop()
+            return await running_loop.run_in_executor(None, _fetch_parent_row)
+
+        async def _enhance_all():
+            """
+            Append relevant parent-segment text to the retrieved results, in place.
+
+            Results in ``bfp_result_lists`` are grouped by the ``parent_id`` in
+            their ``metadata``. For each distinct parent id the parent segment is
+            fetched, the LLM is asked whether it is relevant to ``query_content``,
+            and only relevant parent text is appended to the ``text_content`` of
+            every result that shares that parent. Results without a ``parent_id``
+            are left untouched.
+            """
+            # 1) Collect parent_id -> results that should receive the parent text
+            pid_to_results: Dict[str, List[Dict[str, Any]]] = {}
+
+            for result_list in bfp_result_lists:
+                for r in (result_list or []):
+                    md = r.get("metadata") or {}
+                    pid = md.get("parent_id")
+                    if not pid:
+                        continue
+                    pid_to_results.setdefault(str(pid), []).append(r)
+
+            # 2) One parent_id at a time: fetch parent -> LLM verdict -> append
+            for pid, results in pid_to_results.items():
+                parent_doc = await _async_condition_query_one(pid)
+                if not parent_doc:
+                    continue
+
+                parent_text = (parent_doc.get(PARENT_TEXT_FIELD) or "").strip()
+                if not parent_text:
+                    continue
+
+                relevant = await is_relevant_async(query_content, parent_text)
+                logger.info(f"父段相关性判断: parent_id={pid}, relevant={relevant}")
+                # Honour the verdict: an irrelevant parent must not be appended,
+                # otherwise the LLM call above is wasted and the context is polluted.
+                if not relevant:
+                    continue
+
+                extra = f"{parent_text}\n"
+
+                # 3) Append to the text_content of every result under this parent_id
+                for r in results:
+                    r["text_content"] = (r.get("text_content") or "") + extra
+
+        run_async(_enhance_all())
         logger.info(f"RAG检索返回了 {len(bfp_result_lists)} 个查询对结果")
         # 获取第一个查询对的第一个结果
         first_result_list = bfp_result_lists[0]
@@ -1051,6 +1140,4 @@ class AIReviewEngine(BaseReviewer):
                     "execution_time": execution_time,
                     "error_message": error_msg
                 }
-            }
-
-
+            }

+ 97 - 0
core/construction_review/component/infrastructure/milvus.py

@@ -0,0 +1,97 @@
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Sequence
+
+from pymilvus import MilvusClient
+from langchain_core.documents import Document
+
+from foundation.infrastructure.config.config import config_handler
+
+
+@dataclass(frozen=True)
+class MilvusConfig:
+    """
+    Immutable Milvus connection settings read from the project configuration.
+
+    Both fields are resolved when an instance is created (not at import time),
+    so every ``MilvusConfig()`` reflects the configuration available at that
+    moment. Either field can be overridden by passing it explicitly.
+    """
+    # HTTP endpoint built from MILVUS_HOST / MILVUS_PORT (defaults: localhost:19530)
+    uri: str = field(
+        default_factory=lambda: (
+            f"http://{config_handler.get('milvus', 'MILVUS_HOST', 'localhost')}:"
+            f"{int(config_handler.get('milvus', 'MILVUS_PORT', '19530'))}"
+        )
+    )
+    # Database name from MILVUS_DB (default: lq_db); lazy, consistent with ``uri``
+    db_name: str = field(
+        default_factory=lambda: config_handler.get('milvus', 'MILVUS_DB', 'lq_db')
+    )
+
+
+class MilvusManager:
+    """
+    Management wrapper built on ``pymilvus.MilvusClient`` (langchain-milvus is not used).
+
+    - On construction a client is created and switched to ``cfg.db_name``.
+    - The collection is not fixed: every query names its own ``collection_name``.
+    - ``condition_query`` performs a pure filter query via ``MilvusClient.query``.
+    """
+
+    def __init__(self, cfg: MilvusConfig):
+        """Create the client for ``cfg.uri`` and select database ``cfg.db_name``."""
+        self.cfg = cfg
+        self.client = MilvusClient(uri=cfg.uri)
+        self.client.use_database(cfg.db_name)
+
+        # Field returned when the caller gives no output_fields (adjust to the schema)
+        self.text_field = "text"
+
+    def list_collections(self) -> List[str]:
+        """Return the collection names reported by the client."""
+        names = self.client.list_collections()
+        return names
+
+    def condition_query(
+        self,
+        *,
+        collection_name: str,
+        filter: str,
+        output_fields: Optional[Sequence[str]] = None,
+        limit: Optional[int] = None,
+    ) -> List[Dict[str, Any]]:
+        """
+        Run a filter-only query against ``collection_name`` and return the rows.
+
+        ``filter`` examples:
+          parent_id == 'xxx'
+          tenant == 't1' and source == 'pdf'
+
+        ``output_fields`` examples (defaults to ``[self.text_field]`` when None):
+          ["text"]
+          ["text", "parent_id", "chunk_id"]
+
+        Raises ``ValueError`` for an empty collection name and ``RuntimeError``
+        when the collection does not exist in the current database.
+        """
+        if not collection_name:
+            raise ValueError("collection_name 不能为空")
+
+        # Check up front so a missing collection gives a readable error
+        if self.client.has_collection(collection_name):
+            wanted = [self.text_field] if output_fields is None else list(output_fields)
+            return self.client.query(
+                collection_name=collection_name,
+                filter=filter,
+                output_fields=wanted,
+                limit=limit,
+            )
+
+        existing = self.client.list_collections()
+        message = (
+            f"collection not found: {collection_name}\n"
+            f"current db_name={self.cfg.db_name}, uri={self.cfg.uri}\n"
+            f"collections in current db: {existing}"
+        )
+        raise RuntimeError(message)
+
+
+if __name__ == "__main__":
+    # Manual smoke check: fetch up to 10 rows for one known parent_id and print them
+    manager = MilvusManager(MilvusConfig())
+    print(
+        manager.condition_query(
+            collection_name="rag_parent_hybrid",
+            filter="parent_id == '02267e1d-11d7-4a3d-b53f-e205edd6758f'",
+            limit=10,
+        )
+    )

+ 79 - 0
core/construction_review/component/infrastructure/relevance.py

@@ -0,0 +1,79 @@
+import asyncio
+import json
+import re
+import requests
+
+
+# ===============================
+# 1) 最小 async LLM 调用(等价 curl)
+# ===============================
+async def qwen_chat_async(
+    prompt: str,
+    *,
+    url: str = "http://192.168.91.253:8003/v1/chat/completions",
+    api_key: str = "sk-123456",
+    model: str = "qwen3-30b",
+    timeout: float = 60,
+) -> str:
+    """
+    Send ``prompt`` as a single user message to a chat-completions endpoint.
+
+    The blocking ``requests.post`` call runs in the event loop's default
+    executor so the loop is not blocked while waiting for the model.
+
+    Args:
+        prompt: Text sent as the only user message.
+        url: Chat-completions endpoint.
+        api_key: Bearer token sent in the ``Authorization`` header.
+        model: Model name placed in the request payload.
+        timeout: Timeout in seconds passed to ``requests.post``.
+
+    Returns:
+        The ``content`` of the first choice's message.
+
+    Raises:
+        requests.HTTPError: When the server answers with an error status.
+    """
+    # NOTE(review): the defaults keep the previously hard-coded endpoint and
+    # token so existing callers behave the same; move them to configuration.
+    def _call():
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {api_key}",
+        }
+        payload = {
+            "model": model,
+            "messages": [{"role": "user", "content": prompt}],
+        }
+        resp = requests.post(url, json=payload, headers=headers, timeout=timeout)
+        resp.raise_for_status()
+        return resp.json()["choices"][0]["message"]["content"]
+
+    loop = asyncio.get_running_loop()
+    return await loop.run_in_executor(None, _call)
+
+
+# ===============================
+# 2) 相关性判断提示词(只要 true/false)
+# ===============================
+def build_relevance_prompt(text_a: str, text_b: str) -> str:
+    """
+    Build the relevance-judgement prompt for the LLM.
+
+    The prompt asks whether ``text_b`` strongly supports ``text_a`` and tells
+    the model to answer with strict JSON only. Leading and trailing whitespace
+    of the assembled prompt is stripped.
+    """
+    sections = [
+        "你是信息检索与规范审查专家。",
+        "",
+        "任务:",
+        "判断【文本B】是否与【文本A】强相关(可用于支撑审查引用/解释/依据)。",
+        "",
+        "强相关 = 文本B能直接支撑/解释/约束文本A的关键点(要求、条件、步骤、指标、术语定义等)。",
+        "不相关 = 只是出现少量相似词,主题不同,无法支撑审查。",
+        "",
+        "输出要求(非常重要):",
+        "- 只能输出严格 JSON",
+        "- 不要任何解释文字",
+        "- JSON 格式必须严格如下:",
+        '{"relevant": true} 或 {"relevant": false}',
+        "",
+        "【文本A】:",
+        text_a,
+        "",
+        "【文本B】:",
+        text_b,
+    ]
+    assembled = "\n".join(sections)
+    return assembled.strip()
+
+
+# ===============================
+# 3) 对外函数:只返回 True / False
+# ===============================
+async def is_relevant_async(text_a: str, text_b: str) -> bool:
+    """
+    Ask the LLM whether ``text_b`` is strongly relevant to ``text_a``.
+
+    The model is expected to reply with ``{"relevant": true}`` or
+    ``{"relevant": false}``. The reply is parsed as JSON; if that fails, the
+    brace-delimited objects found in the reply are tried from last to first.
+    Anything that cannot be read as an object with a ``relevant`` verdict
+    counts as not relevant.
+
+    Returns:
+        True only when the parsed verdict is truthy (the string "true",
+        case-insensitive, is accepted as well); False otherwise.
+    """
+    prompt = build_relevance_prompt(text_a, text_b)
+    out = await qwen_chat_async(prompt)
+
+    if not out:
+        return False
+
+    obj = None
+    try:
+        obj = json.loads(out)
+    except (TypeError, ValueError):
+        obj = None
+
+    if not isinstance(obj, dict):
+        # The reply may wrap the JSON in extra text; the expected object is flat,
+        # so scan flat {...} groups and prefer the last one (the final answer).
+        obj = None
+        for candidate in reversed(re.findall(r"\{[^{}]*\}", out)):
+            try:
+                parsed = json.loads(candidate)
+            except ValueError:
+                continue
+            if isinstance(parsed, dict) and "relevant" in parsed:
+                obj = parsed
+                break
+        if obj is None:
+            return False
+
+    verdict = obj.get("relevant", False)
+    if isinstance(verdict, str):
+        # bool("false") would be True; read string verdicts explicitly
+        return verdict.strip().lower() == "true"
+    return bool(verdict)

+ 1 - 1
foundation/ai/rag/retrieval/entities_enhance.py

@@ -41,7 +41,7 @@ class EntitiesEnhance():
             server_logger.info(f"bfp_result:{bfp_result}")
             self.bfp_result_lists.append(bfp_result)
             server_logger.info("实体增强召回结束")
-        self.test_file(self.bfp_result_lists,seve=True)
+        #self.test_file(self.bfp_result_lists,seve=True)
         return self.bfp_result_lists
             
 

+ 0 - 4
utils_test/AI_Review_Test/test_rag_enhanced_check.py

@@ -90,10 +90,6 @@ def test_rag_enhanced_check():
     # 创建AIReviewEngine实例
     review_engine = AIReviewEngine(task_file_info)
 
-    # 执行测试
-    print("\n[输入参数]")
-    print(f"  query_content: {query_content}")
-
     start_time = time.time()
     result = review_engine.rag_enhanced_check(unit_content)
     logger.info(f"rag_enhanced_check_result {result}")