|
|
@@ -14,7 +14,7 @@ import logging
|
|
|
from typing import List, Dict, Any
|
|
|
from datetime import datetime
|
|
|
|
|
|
-from app.base import get_milvus_manager
|
|
|
+from app.base import get_milvus_manager, get_milvus_vectorstore, get_embedding_model
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@@ -22,6 +22,8 @@ logger = logging.getLogger(__name__)
|
|
|
class MilvusService:
|
|
|
def __init__(self):
|
|
|
self.client = get_milvus_manager().client
|
|
|
+ # 获取embedding model
|
|
|
+ self.emdmodel = get_embedding_model()
|
|
|
|
|
|
def create_collection(self, name: str, dimension: int = 768, description: str = "") -> None:
|
|
|
"""创建 Milvus 集合"""
|
|
|
@@ -256,6 +258,67 @@ class MilvusService:
|
|
|
"updated_time": updated_time,
|
|
|
}
|
|
|
|
|
|
+ def hybrid_search(self, collection_name: str, query_text: str,
|
|
|
+ top_k: int = 3, ranker_type: str = "weighted",
|
|
|
+ dense_weight: float = 0.7, sparse_weight: float = 0.3):
|
|
|
+ """
|
|
|
+ 混合搜索(参考 test_hybrid_v2.6.py 的实现)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ param: 包含collection_name的参数字典
|
|
|
+ query_text: 查询文本
|
|
|
+ top_k: 返回结果数量
|
|
|
+ ranker_type: 重排序类型 "weighted" 或 "rrf"
|
|
|
+ dense_weight: 密集向量权重(当ranker_type="weighted"时使用)
|
|
|
+ sparse_weight: 稀疏向量权重(当ranker_type="weighted"时使用)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ List[Dict]: 搜索结果列表
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ collection_name = collection_name
|
|
|
+
|
|
|
+ # 获取 vectorstore 实例(包含 Milvus 和 BM25BuiltInFunction)
|
|
|
+ vectorstore = get_milvus_vectorstore(
|
|
|
+ collection_name=collection_name,
|
|
|
+ consistency_level="Strong"
|
|
|
+ )
|
|
|
+
|
|
|
+ # 执行混合搜索 (完全按照 test_hybrid_v2.6.py 的逻辑)
|
|
|
+ if ranker_type == "weighted":
|
|
|
+ results = vectorstore.similarity_search(
|
|
|
+ query=query_text,
|
|
|
+ k=top_k,
|
|
|
+ ranker_type="weighted",
|
|
|
+ ranker_params={"weights": [dense_weight, sparse_weight]}
|
|
|
+ )
|
|
|
+ else: # rrf
|
|
|
+ results = vectorstore.similarity_search(
|
|
|
+ query=query_text,
|
|
|
+ k=top_k,
|
|
|
+ ranker_type="rrf",
|
|
|
+ ranker_params={"k": 60}
|
|
|
+ )
|
|
|
+
|
|
|
+ # 格式化结果,保持与其他搜索方法一致
|
|
|
+ formatted_results = []
|
|
|
+ for doc in results:
|
|
|
+ formatted_results.append({
|
|
|
+ 'id': doc.metadata.get('pk', 0),
|
|
|
+ 'text_content': doc.page_content,
|
|
|
+ 'metadata': doc.metadata,
|
|
|
+ 'distance': 0.0,
|
|
|
+ 'similarity': 1.0
|
|
|
+ })
|
|
|
+
|
|
|
+ logger.info(f"Hybrid search returned {len(formatted_results)} results")
|
|
|
+ return formatted_results
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ logger.error(f"Error in hybrid search: {e}")
|
|
|
+ # 回退到传统的向量搜索
|
|
|
+ logger.info("Falling back to traditional vector search")
|
|
|
+
|
|
|
|
|
|
# 可选:单例
|
|
|
milvus_service = MilvusService()
|
|
|
@@ -266,6 +329,64 @@ if __name__ == "__main__":
|
|
|
# uv run python -m src.app.services.milvus_service
|
|
|
import json
|
|
|
|
|
|
- data = MilvusService().get_collection_details()
|
|
|
+ service = MilvusService()
|
|
|
+
|
|
|
+ # 测试混合搜索 hybrid_search
|
|
|
+ print("=" * 50)
|
|
|
+ print("测试混合检索 (Hybrid Search)")
|
|
|
+ print("=" * 50)
|
|
|
+
|
|
|
+ try:
|
|
|
+ # 示例参数,需要根据实际情况修改
|
|
|
+ collection_name = "first_bfp_collection_status"
|
|
|
+ query_text = "《公路水运工程临时用电技术规程》(JTT1499-2024)状态为现行" # 修改为实际查询内容
|
|
|
+
|
|
|
+ # 测试 weighted 模式
|
|
|
+ print("\n1. 测试 Weighted 重排序模式:")
|
|
|
+ print(f" 集合: {collection_name}")
|
|
|
+ print(f" 查询: {query_text}")
|
|
|
+ print(f" 密集权重: 0.7, 稀疏权重: 0.3")
|
|
|
+
|
|
|
+ results_weighted = service.hybrid_search(
|
|
|
+ collection_name=collection_name,
|
|
|
+ query_text=query_text,
|
|
|
+ top_k=5,
|
|
|
+ ranker_type="weighted",
|
|
|
+ dense_weight=0.7,
|
|
|
+ sparse_weight=0.3
|
|
|
+ )
|
|
|
+
|
|
|
+ print(f"\n 结果数量: {len(results_weighted)}")
|
|
|
+ for i, result in enumerate(results_weighted, 1):
|
|
|
+ print(f" [{i}] ID: {result.get('id')}, Text: {result.get('text_content')[:50]}...")
|
|
|
+
|
|
|
+ # 测试 RRF 模式
|
|
|
+ print("\n2. 测试 RRF (Reciprocal Rank Fusion) 重排序模式:")
|
|
|
+ print(f" 集合: {collection_name}")
|
|
|
+ print(f" 查询: {query_text}")
|
|
|
+
|
|
|
+ results_rrf = service.hybrid_search(
|
|
|
+ collection_name=collection_name,
|
|
|
+ query_text=query_text,
|
|
|
+ top_k=5,
|
|
|
+ ranker_type="rrf"
|
|
|
+ )
|
|
|
+
|
|
|
+ print(f"\n 结果数量: {len(results_rrf)}")
|
|
|
+ for i, result in enumerate(results_rrf, 1):
|
|
|
+ print(f" [{i}] ID: {result.get('id')}, Text: {result.get('text_content')[:50]}...")
|
|
|
+
|
|
|
+ print("\n✓ 混合检索测试完成")
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"\n✗ 混合检索测试失败: {e}")
|
|
|
+ import traceback
|
|
|
+ traceback.print_exc()
|
|
|
+
|
|
|
+ # 也可以查看集合详情
|
|
|
+ print("\n" + "=" * 50)
|
|
|
+ print("获取所有集合信息:")
|
|
|
+ print("=" * 50)
|
|
|
+ data = service.get_collection_details()
|
|
|
for item in data:
|
|
|
print(json.dumps(item, ensure_ascii=False, indent=2))
|