Quellcode durchsuchen

v0.0.3-修复milvus入库脚本
- 增加json键值对中列表处理

WangXuMing vor 2 Monaten
Ursprung
Commit
0aff48cd83

+ 2 - 1
.gitignore

@@ -66,4 +66,5 @@ todo.md
 .claude
 .R&D
 temp/
-*.json
+*.json
+test_rawdata/

+ 2 - 2
config/config.ini

@@ -117,8 +117,8 @@ MILVUS_PASSWORD=
 
 [hybrid_search]
 # 混合检索权重配置
-DEFAULT_VECTOR_WEIGHT=0.7
-DEFAULT_BM25_WEIGHT=0.3
+DENSE_WEIGHT=0.7
+SPARSE_WEIGHT=0.3
 
 
 

+ 85 - 50
data_pipeline/milvus_inbound_script/milvus入库脚本.py

@@ -5,12 +5,64 @@
 
 import sys
 import os
+import json
 
 # 添加项目根目录到路径
-sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+# 解决模块导入问题:添加项目根目录到 Python 路径
+current_script_path = os.path.abspath(__file__)
+script_dir = os.path.dirname(current_script_path)
+project_root = os.path.abspath(os.path.join(script_dir, "../../"))
+if project_root not in sys.path:
+    sys.path.insert(0, project_root)
+
+def load_documents_from_file(json_path: str):
+    """
+    从单个 JSON 文件读取所有 chunks,生成:
+        [{'content': str, 'metadata': dict}, ...]
+    只保留 chunk 自身的 content 和 metadata,不混入文件级元数据
+    """
+    documents = []
+
+    if not os.path.isfile(json_path):
+        print(f"[WARN] JSON 文件不存在: {json_path}")
+        return documents
+
+    try:
+        with open(json_path, "r", encoding="utf-8") as f:
+            data = json.load(f)
+    except Exception as e:
+        print(f"[ERROR] 读取 JSON 文件失败: {json_path}, error: {e}")
+        return documents
+
+    chunks = data.get("chunks", [])
+    if not isinstance(chunks, list):
+        print(f"[WARN] 文件 {json_path} 中的 chunks 字段不是 list,跳过")
+        return documents
+
+    for idx, chunk in enumerate(chunks):
+        if not isinstance(chunk, dict):
+            continue
+
+        # 提取 content
+        content = chunk.get("content", "")
+        if not content or not str(content).strip():
+            # 空内容不入库
+            continue
+
+        # 只用 chunk 自己的 metadata
+        metadata = chunk.get("metadata", {}) or {}
+
+        # 如果你也想保留 chunk 的索引,可以打开下面这行
+        # metadata["chunk_index"] = idx
+
+        documents.append({
+            "content": content,
+            "metadata": metadata,
+        })
+
+    print(f"[INFO] 文件 {os.path.basename(json_path)} 提取出 {len(documents)} 条 chunk 文档")
+    return documents
 
-print("测试修复后的 Milvus 向量实现")
-print("=" * 50)
 
 def test_basic_functionality():
     """测试基本功能"""
@@ -28,61 +80,44 @@ def test_basic_functionality():
         vector = manager.text_to_vector(test_text)
         print(f"text_to_vector 测试成功,向量维度: {len(vector)}")
 
-        # 简单测试文档
-        test_documents = [
-            {
-                'content': '四川路桥建设集团专注于桥梁和隧道工程建设',
-                'metadata': {"source":""}
-                
-            },
-            {
-                'content': '高速公路桥梁建设技术包括预应力混凝土和钢结构',
-                'metadata': {'category': 'technology', 'type': 'highway'}
-            }
-        ]
-
-        collection_name = "first_bfp_collection"
+        # ====== 关键改动:从文件夹读取所有 JSON 文件,生成 documents ======
+        # 指定你的 JSON 文件夹路径
+        json_dir = "data_pipeline/test_rawdata"
+
+        all_documents = []
+
+        if not os.path.isdir(json_dir):
+            print(f"[ERROR] 目录不存在: {json_dir}")
+            return False
+
+        # 遍历文件夹下所有文件
+        for filename in os.listdir(json_dir):
+            # 只处理 .json 文件(如果你是其它后缀,改这里)
+            if not filename.lower().endswith(".json"):
+                continue
+
+            json_path = os.path.join(json_dir, filename)
+            docs = load_documents_from_file(json_path)
+            if docs:
+                all_documents.extend(docs)
+
+        print(f"[INFO] 总共从目录 {json_dir} 中解析出 {len(all_documents)} 条文档")
+
+        if not all_documents:
+            print("[ERROR] 未从任何文件中解析到文档,停止测试")
+            return False
+        # ====== 关键改动结束 ======
+
+        collection_name = "first_bfp_collection_test"
 
         print(f"\n测试 create_hybrid_collection 方法...")
         vectorstore = manager.create_hybrid_collection(
             collection_name=collection_name,
-            documents=test_documents
+            documents=all_documents  # ← 用目录里解析出的所有 documents
         )
         print("create_hybrid_collection 执行成功!")
         print(f"返回的 vectorstore 类型: {type(vectorstore)}")
 
-        # 等待索引创建完成
-        import time
-        time.sleep(5)
-
-        print(f"\n测试 hybrid_search 方法...")
-        param = {'collection_name': collection_name}
-
-        # 测试加权搜索
-        results = manager.hybrid_search(
-            param=param,
-            query_text="桥梁建设",
-            top_k=2,
-            ranker_type="weighted",
-            dense_weight=0.7,
-            sparse_weight=0.3
-        )
-        print(f"Hybrid search 执行成功,返回 {len(results)} 个结果")
-
-        for i, result in enumerate(results):
-            content = result.get('text_content', '')[:50]
-            print(f"  {i+1}. {content}...")
-
-        # 清理测试集合
-        print(f"\n清理测试集合...")
-        try:
-            from pymilvus import utility
-            if utility.has_collection(collection_name):
-                utility.drop_collection(collection_name)
-                print(f"成功清理集合: {collection_name}")
-        except Exception as e:
-            print(f"清理集合失败: {e}")
-
         return True
 
     except Exception as e:

+ 0 - 1
foundation/ai/models/__init__.py

@@ -11,7 +11,6 @@ __all__ = [
     "ModelHandler",
     "get_models",
     "model_handler",
-    "BaseApiPlatform",
     "rerank_model"
 
 ]

+ 12 - 7
foundation/ai/models/rerank_model.py

@@ -6,10 +6,8 @@
 用于调用BGE重排序模型进行文档重排序
 """
 import json
-import asyncio
 import requests
 from typing import List, Dict, Any
-from foundation.ai.models.model_handler import model_handler
 from foundation.infrastructure.config.config import config_handler
 from foundation.observability.logger.loggering import server_logger
 
@@ -22,21 +20,28 @@ class LqReranker:
     def __init__(self):
         self.api_url = config_handler.get('rerank_model', 'BGE_RERANKER_SERVER_RUL')
         self.model = config_handler.get('rerank_model', 'BGE_RERANKER_MODEL_ID')
-        self.top_k = config_handler.get('rerank_model', 'BGE_RERANKER_TOP_N')
+        # 确保top_k是整数类型,避免切片错误
+        self.top_k = int(config_handler.get('rerank_model', 'BGE_RERANKER_TOP_N', 5))
         
-    def bge_rerank(self,query: str, candidates: List[str]) -> List[Dict[str, Any]]:
+    def bge_rerank(self,query: str, candidates: List[str],top_k :int = None) -> List[Dict[str, Any]]:
         """
         执行重排序的全局函数
 
         Args:
             query: 查询文本
             candidates: 候选文档列表
-            top_k: 返回结果数量
+            top_k: 调用时指定的返回数量参数,默认为None
+
 
         Returns:
             List[Dict]: 重排序后的结果列表
         """
         try:
+            # self.top_k 来自 config.ini,是生产环境中实际使用的重排序数量;bge_rerank 的 top_k 参数用于开发环境中快速调试效果
+            if not top_k:# 如果调用时未指定 top_k,则使用配置文件中的默认值
+                top_k = self.top_k
+            
+
             server_logger.info(f"开始执行重排序,查询: {query}, 候选文档数量: {len(candidates)}")
 
             # 构建重排序请求
@@ -62,7 +67,7 @@ class LqReranker:
                 server_logger.debug(f"API响应: {json.dumps(result, ensure_ascii=False)}")
 
                 if "results" in result:
-                    return result["results"][:self.top_k]
+                    return result["results"][:top_k]
                 else:
                     server_logger.warning(f"API响应格式异常: {result}")
                     return []
@@ -73,6 +78,6 @@ class LqReranker:
         except Exception as e:
             server_logger.error(f"执行重排序失败: {str(e)}")
             # 返回原始顺序作为fallback
-            return [{"text": doc, "score": "0.0"} for doc in candidates[:self.top_k]]
+            return [{"text": doc, "score": "0.0"} for doc in candidates[:top_k]]
 
 rerank_model = LqReranker()

+ 11 - 9
foundation/ai/rag/retrieval/retrieval.py

@@ -1,11 +1,11 @@
 
 
 
-from foundation.database.base.vector.milvus_vector import MilvusVectorManager
-from foundation.observability.logger.loggering import server_logger
-from foundation.ai.models.rerank_model import rerank_model
 from typing import List, Dict, Any, Optional
-
+from foundation.ai.models.rerank_model import rerank_model
+from foundation.infrastructure.config.config import config_handler
+from foundation.observability.logger.loggering import server_logger
+from foundation.database.base.vector.milvus_vector import MilvusVectorManager
 
 class RetrievalManager:
     """
@@ -18,6 +18,8 @@ class RetrievalManager:
         """
         self.vector_manager = MilvusVectorManager()
         self.logger = server_logger
+        self.dense_weight = config_handler.get('hybrid_search', 'DENSE_WEIGHT', 0.7)
+        self.sparse_weight = config_handler.get('hybrid_search', 'SPARSE_WEIGHT', 0.3)
 
     def hybrid_search_recall(self, collection_name: str, query_text: str,
                            top_k: int = 10, ranker_type: str = "weighted",
@@ -57,7 +59,7 @@ class RetrievalManager:
             return []
 
     def rerank_recall(self, candidates: List[str], query_text: str,
-                     top_k: int = 10) -> List[Dict[str, Any]]:
+                  top_k: int = None  ) -> List[Dict[str, Any]]:
         """
         重排序召回 - 使用BGE重排序模型对候选文档重新排序
 
@@ -73,7 +75,7 @@ class RetrievalManager:
             self.logger.info(f"开始重排序召回,候选文档数量: {len(candidates)}")
 
             # 调用重排序执行器
-            rerank_results = rerank_model.bge_rerank(query_text, candidates)
+            rerank_results = rerank_model.bge_rerank(query_text, candidates, top_k)
 
             # 转换结果格式
             scored_docs = []
@@ -92,7 +94,7 @@ class RetrievalManager:
             return []
 
     def multi_stage_recall(self, collection_name: str, query_text: str,
-                          hybrid_top_k: int = 50, final_top_k: int = 10,
+                          hybrid_top_k: int = 50, top_k: int = 10,
                           ranker_type: str = "weighted") -> List[Dict[str, Any]]:
         """
         多路召回 - 先混合搜索召回,再重排序,只返回重排序结果
@@ -101,7 +103,7 @@ class RetrievalManager:
             collection_name: 集合名称
             query_text: 查询文本
             hybrid_top_k: 混合搜索召回的文档数量
-            final_top_k: 最终返回的文档数量
+            top_k: 最终返回的文档数量
             ranker_type: 混合搜索的重排序类型
 
         Returns:
@@ -129,7 +131,7 @@ class RetrievalManager:
             rerank_results = self.rerank_recall(
                 candidates=candidates,
                 query_text=query_text,
-                top_k=final_top_k
+                top_k=top_k
             )
 
             # 为重排序结果添加混合搜索的原始元数据

+ 14 - 4
foundation/database/base/vector/milvus_vector.py

@@ -369,13 +369,13 @@ class MilvusVectorManager(BaseVectorDB):
             if self.password:
                 connection_args["password"] = self.password
 
-            # 转换为 LangChain Document 格式 (参考 test_hybrid_v2.6.py)
+            
             langchain_docs = []
             for doc in documents:
                 content = doc.get('content', '')
                 metadata = doc.get('metadata', {})
-
-                langchain_doc = Document(page_content=content, metadata=metadata)
+                processed_metadata = self._process_metadata(doc)
+                langchain_doc = Document(page_content=content, metadata=processed_metadata)
                 langchain_docs.append(langchain_doc)
 
             # 创建混合搜索向量存储 (完全按照 test_hybrid_v2.6.py 的逻辑)
@@ -475,4 +475,14 @@ class MilvusVectorManager(BaseVectorDB):
             return self.similarity_search(param, query_text, top_k=top_k)
 
 
- 
+    def _process_metadata(self,metadata):
+        """处理 metadata:将 list 类型的 hierarchy 转换为 Milvus 支持的 string 类型"""
+        processed_metadata = metadata.copy()
+        if "hierarchy" in processed_metadata and isinstance(processed_metadata["hierarchy"], list):
+            processed_metadata["hierarchy"] = " > ".join(processed_metadata["hierarchy"])
+        for key, value in processed_metadata.items():
+            if value is None:
+                processed_metadata[key] = ""
+            elif isinstance(value, dict):
+                processed_metadata[key] = json.dumps(value, ensure_ascii=False)
+        return processed_metadata

+ 6 - 4
test/test_multi_stage_recall.py

@@ -215,7 +215,6 @@ def test_multi_stage_recall(collection_name):
                 collection_name=collection_name,
                 query_text=query,
                 hybrid_top_k=10,
-                final_top_k=5
             )
 
             end_time = time.time()
@@ -241,7 +240,10 @@ def test_multi_stage_recall(collection_name):
             print(f"\n  召回摘要:")
             print(f"  - 总数量: {len(results)}")
             print(f"  - 平均重排序分数: {avg_rerank_score:.4f}")
-            print(f"  - 重排序分数范围: {min(rerank_scores):.4f} - {max(rerank_scores):.4f}")
+            if rerank_scores:
+                print(f"  - 重排序分数范围: {min(rerank_scores):.4f} - {max(rerank_scores):.4f}")
+            else:
+                print(f"  - 重排序分数范围: 无数据")
 
             all_results.extend(results)
 
@@ -289,7 +291,6 @@ def test_different_parameters(collection_name):
                 collection_name=collection_name,
                 query_text=query,
                 hybrid_top_k=config['hybrid_top_k'],
-                final_top_k=config['final_top_k']
             )
 
             end_time = time.time()
@@ -303,7 +304,8 @@ def test_different_parameters(collection_name):
                 similarities = [r.get('similarity', 0) for r in results]
                 rerank_scores = [r.get('rerank_score', 0) for r in results]
 
-                print(f"✓ 相似度范围: {min(similarities):.4f} - {max(similarities):.4f}")
+                if similarities:
+                    print(f"✓ 相似度范围: {min(similarities):.4f} - {max(similarities):.4f}")
                 if rerank_scores:
                     print(f"✓ 重排序分数范围: {min(rerank_scores):.4f} - {max(rerank_scores):.4f}")