Эх сурвалжийг харах

dev:milvus添加混合检索

ZengChao 1 сар өмнө
parent
commit
abd5e482e1

+ 5 - 1
requirements/base.txt

@@ -44,4 +44,8 @@ flower==2.0.1
 python-dotenv==1.0.0
 
 # 向量数据库
-pymilvus==2.6.6
+pymilvus==2.6.6
+
+# LangChain相关
+langchain-openai==1.1.7
+langchain-milvus==0.3.3

+ 6 - 0
src/app/base/__init__.py

@@ -14,10 +14,12 @@ from .async_redis_connection import get_redis_connection, init_redis, close_redi
 from .milvus_connection import (
     get_milvus_connection, 
     get_milvus_manager,
+    get_milvus_vectorstore,
     MilvusManager,
     init_milvus, 
     close_milvus
 )
+from .embedding_connection import get_embedding_model, get_embedding_config
 
 __all__ = [
     # MySQL
@@ -35,7 +37,11 @@ __all__ = [
     # Milvus
     "get_milvus_connection",
     "get_milvus_manager",
+    "get_milvus_vectorstore",
     "MilvusManager",
     "init_milvus",
     "close_milvus",
+    # Embedding
+    "get_embedding_model",
+    "get_embedding_config",
 ]

+ 35 - 0
src/app/base/embedding_connection.py

@@ -0,0 +1,35 @@
+"""
+Embedding model connection management (synchronous client construction)
+"""
+import os
+import logging
+from typing import Optional
+from langchain_openai import OpenAIEmbeddings
+
+# 导入配置
+from app.core.config import config_handler
+
+logger = logging.getLogger(__name__)
+
+
+
+def get_embedding_config() -> dict:
+    """Return the embedding settings read from the "admin_app" config section."""
+    section = "admin_app"
+    return {
+        'base_url': config_handler.get(section, "EMBEDDING_BASE_URL", "http://192.168.91.253:9003/v1"),
+        'model': config_handler.get(section, "EMBEDDING_MODEL", "Qwen3-Embedding-8B"),
+        'api_key': config_handler.get(section, "EMBEDDING_API_KEY", "dummy"),
+        'timeout': 30,
+    }
+
+
+def get_embedding_model():
+    """Build an OpenAIEmbeddings client from get_embedding_config() (URL, model, key, timeout)."""
+    config = get_embedding_config()
+    return OpenAIEmbeddings(
+        base_url=config['base_url'],
+        model=config['model'],
+        api_key=config['api_key'],  # local models are used with a placeholder API key
+        timeout=config['timeout'],  # apply the configured request timeout to the client
+    )

+ 42 - 0
src/app/base/milvus_connection.py

@@ -7,12 +7,54 @@ from typing import Optional
 
 # 导入配置
 from app.core.config import config_handler
+from .embedding_connection import get_embedding_model
+
+from langchain_milvus import Milvus, BM25BuiltInFunction
 
 logger = logging.getLogger(__name__)
 
 _milvus_manager = None
 
 
+def get_milvus_vectorstore(collection_name: str, consistency_level: str = "Strong"):
+    """
+    Build a LangChain Milvus vector store configured for hybrid search.
+
+    The store is created with the embedding client from get_embedding_model()
+    and a BM25BuiltInFunction, using the vector fields "dense" and "sparse".
+    Connection settings are read from the manager returned by
+    get_milvus_manager().
+
+    Args:
+        collection_name: Name of the Milvus collection to open.
+        consistency_level: Consistency level passed to Milvus; defaults to "Strong".
+
+    Returns:
+        Milvus: The LangChain Milvus vector store instance.
+
+    Raises:
+        Exception: Any failure is logged and then re-raised unchanged.
+    """
+    try:
+        embeddings = get_embedding_model()
+        mgr = get_milvus_manager()
+        # The password entry is only included when one is configured.
+        conn = {"uri": f"http://{mgr.host}:{mgr.port}", "user": mgr.user, "db_name": mgr.db_name}
+        if mgr.password:
+            conn["password"] = mgr.password
+        return Milvus(
+            embedding_function=embeddings,
+            collection_name=collection_name,
+            connection_args=conn,
+            consistency_level=consistency_level,
+            builtin_function=BM25BuiltInFunction(),
+            vector_field=["dense", "sparse"],
+        )
+    except Exception as exc:
+        logger.error(f"获取 Milvus Vectorstore 失败: {exc}")
+        raise
+
+
 class MilvusManager:
     """Milvus管理器"""
     

+ 7 - 1
src/app/config/config.ini

@@ -76,4 +76,10 @@ SESSION_TTL=86400
 
 # Celery配置
 CELERY_BROKER_URL=redis://localhost:6379/1
-CELERY_RESULT_BACKEND=redis://localhost:6379/2
+CELERY_RESULT_BACKEND=redis://localhost:6379/2
+
+
+# embedding模型配置
+EMBEDDING_BASE_URL=http://192.168.91.253:9003/v1
+EMBEDDING_MODEL=Qwen3-Embedding-8B
+EMBEDDING_API_KEY=dummy

+ 123 - 2
src/app/services/milvus_service.py

@@ -14,7 +14,7 @@ import logging
 from typing import List, Dict, Any
 from datetime import datetime
 
-from app.base import get_milvus_manager
+from app.base import get_milvus_manager, get_milvus_vectorstore, get_embedding_model
 
 logger = logging.getLogger(__name__)
 
@@ -22,6 +22,8 @@ logger = logging.getLogger(__name__)
 class MilvusService:
     def __init__(self):
         self.client = get_milvus_manager().client
+        # Obtain the embedding model client
+        self.emdmodel = get_embedding_model()
 
     def create_collection(self, name: str, dimension: int = 768, description: str = "") -> None:
         """创建 Milvus 集合"""
@@ -256,6 +258,67 @@ class MilvusService:
             "updated_time": updated_time,
         }
 
+    def hybrid_search(self, collection_name: str, query_text: str,
+                     top_k: int = 3, ranker_type: str = "weighted",
+                     dense_weight: float = 0.7, sparse_weight: float = 0.3):
+        """
+        Run a hybrid (dense + sparse/BM25) search against a Milvus collection.
+
+        Args:
+            collection_name: Name of the collection to search.
+            query_text: Query text.
+            top_k: Number of results to return.
+            ranker_type: "weighted" uses the weighted ranker; any other value
+                is treated as "rrf".
+            dense_weight: Weight of the dense vector field ("weighted" only).
+            sparse_weight: Weight of the sparse vector field ("weighted" only).
+
+        Returns:
+            List[Dict]: One dict per hit with the keys 'id', 'text_content',
+            'metadata', 'distance' and 'similarity'. An empty list is returned
+            when the search fails, so callers can always take len() of it.
+        """
+        # Only the ranker settings differ between the two modes, so resolve
+        # them once and issue a single search call.
+        if ranker_type == "weighted":
+            ranker_params = {"weights": [dense_weight, sparse_weight]}
+        else:  # rrf
+            ranker_type = "rrf"
+            ranker_params = {"k": 60}
+
+        try:
+            # Vector store wired with the embedding model and BM25BuiltInFunction.
+            vectorstore = get_milvus_vectorstore(
+                collection_name=collection_name,
+                consistency_level="Strong"
+            )
+            results = vectorstore.similarity_search(
+                query=query_text,
+                k=top_k,
+                ranker_type=ranker_type,
+                ranker_params=ranker_params
+            )
+            # No score is read from the results here, so 'distance' and 'similarity'
+            # are fixed placeholders that keep the shape of the other search methods.
+            formatted_results = [
+                {
+                    'id': doc.metadata.get('pk', 0),
+                    'text_content': doc.page_content,
+                    'metadata': doc.metadata,
+                    'distance': 0.0,
+                    'similarity': 1.0
+                }
+                for doc in results
+            ]
+        except Exception as e:
+            logger.error(f"Error in hybrid search: {e}", exc_info=True)
+            # No alternative search path exists here, so report an empty result
+            # instead of announcing a fallback and implicitly returning None.
+            logger.info("Hybrid search failed; returning empty result list")
+            return []
+        logger.info(f"Hybrid search returned {len(formatted_results)} results")
+        return formatted_results
+
 
 # 可选:单例
 milvus_service = MilvusService()
@@ -266,6 +329,64 @@ if __name__ == "__main__":
     # uv run python -m src.app.services.milvus_service
     import json
 
-    data = MilvusService().get_collection_details()
+    service = MilvusService()
+    
+    # 测试混合搜索 hybrid_search
+    print("=" * 50)
+    print("测试混合检索 (Hybrid Search)")
+    print("=" * 50)
+    
+    try:
+        # 示例参数,需要根据实际情况修改
+        collection_name = "first_bfp_collection_status" 
+        query_text = "《公路水运工程临时用电技术规程》(JTT1499-2024)状态为现行"  # 修改为实际查询内容
+        
+        # 测试 weighted 模式
+        print("\n1. 测试 Weighted 重排序模式:")
+        print(f"   集合: {collection_name}")
+        print(f"   查询: {query_text}")
+        print(f"   密集权重: 0.7, 稀疏权重: 0.3")
+        
+        results_weighted = service.hybrid_search(
+            collection_name=collection_name,
+            query_text=query_text,
+            top_k=5,
+            ranker_type="weighted",
+            dense_weight=0.7,
+            sparse_weight=0.3
+        )
+        
+        print(f"\n   结果数量: {len(results_weighted)}")
+        for i, result in enumerate(results_weighted, 1):
+            print(f"   [{i}] ID: {result.get('id')}, Text: {result.get('text_content')[:50]}...")
+        
+        # 测试 RRF 模式
+        print("\n2. 测试 RRF (Reciprocal Rank Fusion) 重排序模式:")
+        print(f"   集合: {collection_name}")
+        print(f"   查询: {query_text}")
+        
+        results_rrf = service.hybrid_search(
+            collection_name=collection_name,
+            query_text=query_text,
+            top_k=5,
+            ranker_type="rrf"
+        )
+        
+        print(f"\n   结果数量: {len(results_rrf)}")
+        for i, result in enumerate(results_rrf, 1):
+            print(f"   [{i}] ID: {result.get('id')}, Text: {result.get('text_content')[:50]}...")
+        
+        print("\n✓ 混合检索测试完成")
+        
+    except Exception as e:
+        print(f"\n✗ 混合检索测试失败: {e}")
+        import traceback
+        traceback.print_exc()
+    
+    # 也可以查看集合详情
+    print("\n" + "=" * 50)
+    print("获取所有集合信息:")
+    print("=" * 50)
+    data = service.get_collection_details()
     for item in data:
         print(json.dumps(item, ensure_ascii=False, indent=2))