import time
import json
from typing import List, Dict, Any, Optional

import numpy as np
from pymilvus import (
    connections,
    Collection,
    FieldSchema,
    CollectionSchema,
    DataType,
    utility,
    Function,
    AnnSearchRequest,
    RRFRanker,
    WeightedRanker,
)
from pymilvus.client.types import FunctionType
# from sentence_transformers import SentenceTransformer

# LangChain Milvus hybrid-search integration
from langchain_milvus import Milvus, BM25BuiltInFunction
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings

from foundation.infrastructure.config.config import config_handler
from foundation.database.base.vector.base_vector import BaseVectorDB

# logger and model_handler are imported lazily to avoid circular dependencies.
logger = None
model_handler = None


def _get_logger():
    """Lazily import the project logger to avoid a circular import.

    Falls back to a stdlib ``logging`` logger when the project logger
    cannot be imported.
    """
    global logger
    if logger is None:
        try:
            from foundation.observability.logger.loggering import server_logger
            logger = server_logger
        except ImportError:
            # Fall back to a plain stdlib logger substitute.
            import logging
            logger = logging.getLogger(__name__)
    return logger


def _get_model_handler():
    """Lazily import model_handler to avoid a circular import.

    Returns:
        The project model handler, or ``None`` when the import fails.
    """
    global model_handler
    if model_handler is None:
        try:
            from foundation.ai.models.model_handler import model_handler as mh
            model_handler = mh
        except ImportError:
            model_handler = None
    return model_handler


class MilvusVectorManager(BaseVectorDB):
    """Milvus-backed vector store supporting dense, filtered, and hybrid search."""

    def __init__(self):
        """Initialize connection settings, the embedding model, and the
        Milvus connection, and pre-create common vectorstore handles.

        Raises:
            ImportError: when model_handler cannot be imported and no
                embedding model can be initialized.
        """
        self.host = config_handler.get('milvus', 'MILVUS_HOST', 'localhost')
        self.port = int(config_handler.get('milvus', 'MILVUS_PORT', '19530'))
        self.milvus_db = config_handler.get('milvus', 'MILVUS_DB', 'default')
        self.user = config_handler.get('milvus', 'MILVUS_USER')
        self.password = config_handler.get('milvus', 'MILVUS_PASSWORD')

        # Embedding model used for all text -> vector conversion.
        mh = _get_model_handler()
        if mh:
            self.emdmodel = mh.get_embedding_model()
        else:
            raise ImportError("无法导入model_handler,无法初始化嵌入模型")

        # Cached connection parameters reused by LangChain Milvus vectorstores.
        # NOTE(review): db_name is hard-coded to "lq_db" rather than using
        # self.milvus_db from config -- confirm this is intentional.
        self.connection_args = {
            "uri": f"http://{self.host}:{self.port}",
            "user": self.user,
            "db_name": "lq_db"
        }
        if self.password:
            self.connection_args["password"] = self.password

        # Connect to Milvus.
        self.connect()

        # Pre-create commonly used vectorstore connections to avoid
        # contention at request time.
        self._vectorstore_cache = {}
        self._create_common_connections()

    def _create_common_connections(self):
        """Pre-create vectorstore handles for frequently used collections.

        Failures are logged and skipped; hybrid_search falls back to
        creating the handle on demand.
        """
        common_collections = [
            "first_bfp_collection_entity",
            "first_bfp_collection_test"
        ]
        for collection_name in common_collections:
            try:
                _get_logger().info(f"预创建vectorstore连接: {collection_name}")
                self._vectorstore_cache[collection_name] = Milvus(
                    embedding_function=self.emdmodel,
                    collection_name=collection_name,
                    connection_args=self.connection_args,
                    consistency_level="Strong",
                    builtin_function=BM25BuiltInFunction(),
                    vector_field=["dense", "sparse"]
                )
                _get_logger().info(f"成功预创建连接: {collection_name}")
            except Exception as e:
                _get_logger().error(f"预创建连接失败 {collection_name}: {e}")

    def text_to_vector(self, text: str) -> List[float]:
        """Convert text to a dense embedding vector using the shared model.

        Args:
            text: input text to embed.

        Returns:
            The embedding as a plain Python list of floats.

        Raises:
            Exception: re-raises any embedding-model failure after logging.
        """
        try:
            embedding = self.emdmodel.embed_query(text)
            # Normalize numpy arrays and other sequences to list[float].
            return embedding.tolist() if hasattr(embedding, 'tolist') else list(embedding)
        except Exception as e:
            _get_logger().error(f"Error converting text to vector: {e}")
            raise

    def connect(self):
        """Connect to the Milvus server using the configured host/port/user.

        NOTE(review): the password read in __init__ is NOT passed here --
        confirm whether unauthenticated access is intended for this alias.
        """
        try:
            connections.connect(
                alias="default",
                host=self.host,
                port=self.port,
                user=self.user,
                db_name="lq_db"
            )
            _get_logger().info(f"Connected to Milvus at {self.host}:{self.port}")
        except Exception as e:
            _get_logger().error(f"Failed to connect to Milvus: {e}")
            raise

    def create_collection(self, collection_name: str, dimension: int = 768,
                          description: str = "Vector collection for text embeddings"):
        """(Re)create a dense-vector collection with a COSINE IVF_FLAT index.

        WARNING: an existing collection with the same name is dropped first,
        destroying its data.

        Args:
            collection_name: name of the collection to create.
            dimension: dimensionality of the dense vector field.
            description: collection description stored in the schema.
        """
        try:
            # Drop any pre-existing collection of the same name.
            if utility.has_collection(collection_name):
                _get_logger().info(f"Collection {collection_name} already exists")
                utility.drop_collection(collection_name)
                _get_logger().info(f"Collection '{collection_name}' dropped successfully")

            # Schema: auto-id primary key, dense vector, raw text, JSON
            # metadata, and an epoch-seconds creation timestamp.
            fields = [
                FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
                FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
                FieldSchema(name="text_content", dtype=DataType.VARCHAR, max_length=65535),
                FieldSchema(name="metadata", dtype=DataType.JSON),
                FieldSchema(name="created_at", dtype=DataType.INT64)
            ]
            schema = CollectionSchema(fields=fields, description=description)
            collection = Collection(name=collection_name, schema=schema)

            # COSINE metric so search distances can be mapped to similarity.
            index_params = {
                "index_type": "IVF_FLAT",
                "metric_type": "COSINE",
                "params": {"nlist": 100}
            }
            collection.create_index(field_name="vector", index_params=index_params)
            _get_logger().info(f"Collection {collection_name} created successfully!")
        except Exception as e:
            _get_logger().error(f"Error creating collection: {e}")
            raise

    def add_document(self, param: Dict[str, Any], document: Dict[str, Any]):
        """Insert a single document (text + metadata) into a collection.

        Args:
            param: must contain 'collection_name'.
            document: dict with 'content' (text) and optional 'metadata'.

        Returns:
            The primary key of the inserted row, or None on failure.
        """
        try:
            collection_name = param.get('collection_name')
            text = document.get('content')
            metadata = document.get('metadata')
            collection = Collection(collection_name)
            created_at = None

            # Embed the text with the shared model.
            embedding = self.text_to_vector(text)

            # Column-oriented insert payload matching the schema order
            # (vector, text_content, metadata, created_at); id is auto.
            data = [
                [embedding],
                [text],
                [metadata or {}],
                [created_at or int(time.time())]
            ]
            _get_logger().info(f"Preparing to insert text_contents:{len(data[0])} ,{len(data[1])},{len(data[2])},{len(data[3])}")

            insert_result = collection.insert(data)
            collection.flush()  # ensure the row is persisted
            _get_logger().info(f"Text inserted with ID: {insert_result.primary_keys[0]}")
            return insert_result.primary_keys[0]
        except Exception as e:
            _get_logger().error(f"Error inserting text: {e}")
            return None

    def add_batch_documents(self, param: Dict[str, Any], documents: List[Dict[str, Any]]):
        """Insert a batch of documents into a collection.

        Args:
            param: must contain 'collection_name'.
            documents: list of dicts, each with 'content' and optional
                'metadata', e.g. [{'content': '...', 'metadata': {...}}, ...].

        Returns:
            The list of inserted primary keys, or None on failure.
        """
        try:
            collection_name = param.get('collection_name')
            collection = Collection(collection_name)

            text_contents = []
            embeddings = []
            metadatas = []
            timestamps = []
            for item in documents:
                text = item['content']
                metadata = item.get('metadata', {})
                embedding = self.text_to_vector(text)
                text_contents.append(text)
                embeddings.append(embedding)
                metadatas.append(metadata)
                timestamps.append(int(time.time()))

            # Column-oriented payload matching the schema field order.
            data = [embeddings, text_contents, metadatas, timestamps]

            insert_result = collection.insert(data)
            collection.flush()  # ensure the rows are persisted
            _get_logger().info(f"Batch inserted {len(text_contents)} records, IDs: {insert_result.primary_keys}")
            return insert_result.primary_keys
        except Exception as e:
            _get_logger().error(f"Error batch inserting: {e}")
            return None

    def similarity_search(self, param: Dict[str, Any], query_text: str,
                          min_score=0.5, top_k=5, filters: Dict[str, Any] = None):
        """Dense-vector similarity search over a collection.

        Args:
            param: must contain 'collection_name'.
            query_text: query to embed and search with.
            min_score: currently unused threshold kept for interface
                compatibility.  # NOTE(review): never applied -- confirm.
            top_k: maximum number of hits to return.
            filters: optional metadata equality filters (see _create_filter).

        Returns:
            List of result dicts with id, text_content/text, metadata,
            distance and similarity; empty list on failure.
        """
        try:
            collection_name = param.get('collection_name')
            collection = Collection(collection_name)
            collection.load()  # load into memory if not already loaded

            query_embedding = self.text_to_vector(query_text)

            search_params = {
                "metric_type": "COSINE",
                "params": {"nprobe": 10}
            }
            filter_expr = self._create_filter(filters)

            # BUGFIX: the schema field is "text_content", not "text";
            # requesting "text" made the text come back empty.
            results = collection.search(
                data=[query_embedding],
                anns_field="vector",
                param=search_params,
                limit=top_k,
                expr=filter_expr,
                output_fields=["text_content", "metadata"]
            )

            formatted_results = []
            for hits in results:
                for hit in hits:
                    formatted_results.append({
                        'id': hit.id,
                        'text_content': hit.entity.get('text_content'),
                        # duplicate under 'text' for backward compatibility
                        'text': hit.entity.get('text_content'),
                        'metadata': hit.entity.get('metadata'),
                        'distance': hit.distance,
                        # NOTE(review): assumes hit.distance behaves as a
                        # distance (smaller = closer); with COSINE Milvus may
                        # report a similarity score -- verify the mapping.
                        'similarity': 1 - hit.distance
                    })
            return formatted_results
        except Exception as e:
            _get_logger().error(f"Error searching: {e}")
            return []

    def retriever(self, param: Dict[str, Any], query_text: str, top_k: int = 5,
                  filters: Dict[str, Any] = None):
        """Similarity search with metadata filters (retriever-style API).

        Args:
            param: must contain 'collection_name'.
            query_text: query to embed and search with.
            top_k: maximum number of hits to return.
            filters: optional metadata equality filters (see _create_filter).

        Returns:
            List of result dicts; empty list on failure.
        """
        try:
            collection_name = param.get('collection_name')
            collection = Collection(collection_name)
            collection.load()

            query_embedding = self.text_to_vector(query_text)
            filter_expr = self._create_filter(filters)

            search_params = {
                "metric_type": "COSINE",
                "params": {"nprobe": 10}
            }

            # BUGFIX: request "text_content" (the actual schema field);
            # the result formatting below already reads that key.
            results = collection.search(
                data=[query_embedding],
                anns_field="vector",
                param=search_params,
                limit=top_k,
                expr=filter_expr,
                output_fields=["text_content", "metadata"]
            )

            formatted_results = []
            for hits in results:
                for hit in hits:
                    formatted_results.append({
                        'id': hit.id,
                        'text_content': hit.entity.get('text_content'),
                        'metadata': hit.entity.get('metadata'),
                        'distance': hit.distance,
                        'similarity': 1 - hit.distance
                    })
            return formatted_results
        except Exception as e:
            _get_logger().error(f"Error searching with filter: {e}")
            return []

    def _create_filter(self, filters: Dict[str, Any]) -> str:
        """Build a Milvus boolean filter expression from metadata equalities.

        Each key/value becomes metadata["key"] == value; string and numeric
        values are compared directly, anything else is JSON-serialized and
        compared as a string. Conditions are joined with "and".

        Returns:
            The filter expression, or "" when filters is falsy.
        """
        filter_expr = ""
        if filters:
            conditions = []
            for key, value in filters.items():
                if isinstance(value, str):
                    conditions.append(f'metadata["{key}"] == "{value}"')
                elif isinstance(value, (int, float)):
                    conditions.append(f'metadata["{key}"] == {value}')
                else:
                    conditions.append(f'metadata["{key}"] == "{json.dumps(value)}"')
            filter_expr = " and ".join(conditions)
        return filter_expr

    def create_hybrid_collection(self, collection_name: str, documents: List[Dict[str, Any]]):
        """Create a hybrid (dense + BM25 sparse) collection from documents.

        Any existing collection of the same name is dropped (drop_old=True).

        Args:
            collection_name: name of the collection to create.
            documents: list of dicts: [{'content': '...', 'metadata': {...}}, ...]

        Returns:
            The created Milvus vectorstore, or None on failure.
        """
        try:
            # Reuse the cached connection parameters (copied so the
            # vectorstore cannot mutate the shared dict).
            connection_args = dict(self.connection_args)

            langchain_docs = []
            for doc in documents:
                content = doc.get('content', '')
                metadata = doc.get('metadata', {})
                # NOTE(review): the whole doc (content + metadata) is passed
                # through _process_metadata and stored as the Document
                # metadata; the 'metadata' local above is unused. hybrid_search
                # appears to expect this nested layout -- confirm before
                # changing to _process_metadata(metadata).
                processed_metadata = self._process_metadata(doc)
                langchain_doc = Document(page_content=content, metadata=processed_metadata)
                langchain_docs.append(langchain_doc)

            # Dense field from the embedding model, sparse field from the
            # built-in BM25 function.
            vectorstore = Milvus.from_documents(
                documents=langchain_docs,
                embedding=self.emdmodel,
                builtin_function=BM25BuiltInFunction(),
                vector_field=["dense", "sparse"],
                connection_args=connection_args,
                collection_name=collection_name,
                consistency_level="Strong",
                drop_old=True,
            )
            _get_logger().info(f"Created hybrid collection: {collection_name} with {len(documents)} documents")
            return vectorstore
        except Exception as e:
            _get_logger().error(f"Error creating hybrid collection: {e}")
            _get_logger().info("Falling back to traditional vector search")
            return None

    def hybrid_search(self, param: Dict[str, Any], query_text: str, top_k: int,
                      ranker_type: str = "weighted",
                      dense_weight: float = 0.7, sparse_weight: float = 0.3):
        """Hybrid dense + sparse search with weighted or RRF re-ranking.

        Args:
            param: must contain 'collection_name'.
            query_text: query text.
            top_k: number of results to return.
            ranker_type: "weighted" or "rrf".
            dense_weight: dense-vector weight (weighted ranker only).
            sparse_weight: sparse-vector weight (weighted ranker only).

        Returns:
            List of result dicts; falls back to similarity_search on error.
        """
        try:
            collection_name = param.get('collection_name')
            # BUGFIX: use _get_logger() -- the bare module-level `logger`
            # may still be None here.
            _get_logger().info(f"开始 hybrid_search, collection_name: {collection_name}")

            # Prefer the pre-created vectorstore to avoid runtime contention.
            if collection_name in self._vectorstore_cache:
                vectorstore = self._vectorstore_cache[collection_name]
            else:
                # Fallback: build (and cache) a new handle on demand.
                _get_logger().warning(f"缓存中未找到连接: {collection_name},创建新连接")
                vectorstore = Milvus(
                    embedding_function=self.emdmodel,
                    collection_name=collection_name,
                    connection_args=self.connection_args,
                    consistency_level="Strong",
                    builtin_function=BM25BuiltInFunction(),
                    vector_field=["dense", "sparse"]
                )
                self._vectorstore_cache[collection_name] = vectorstore

            _get_logger().info(f"混合召回topk: {top_k}")

            # similarity_search_with_score exposes the ranker score.
            if ranker_type == "weighted":
                results_with_scores = vectorstore.similarity_search_with_score(
                    query=query_text,
                    k=top_k,
                    ranker_type="weighted",
                    ranker_params={"weights": [dense_weight, sparse_weight]}
                )
            else:  # rrf
                results_with_scores = vectorstore.similarity_search_with_score(
                    query=query_text,
                    k=top_k,
                    ranker_type="rrf",
                    ranker_params={"k": 60}
                )

            # Format results consistently with the other search methods.
            formatted_results = []
            for doc, score in results_with_scores:
                # Map the score to a (0, 1] similarity; smaller score is
                # treated as more similar.
                similarity = 1 / (1 + score) if score >= 0 else 0
                formatted_results.append({
                    'id': doc.metadata.get('pk', 0),
                    'text_content': doc.page_content,
                    'metadata': doc.metadata,
                    'distance': float(score),
                    'similarity': float(similarity)
                })

            return formatted_results
        except Exception as e:
            _get_logger().error(f"Error in hybrid search: {e}")
            # Fall back to plain dense-vector search.
            _get_logger().info("Falling back to traditional vector search")
            return self.similarity_search(param, query_text, top_k=top_k)

    def _process_metadata(self, metadata):
        """Normalize metadata values into Milvus-compatible scalar types.

        - list-valued "hierarchy" is joined into a " > " separated string
        - None values become ""
        - dict values are JSON-serialized (non-ASCII preserved)

        Args:
            metadata: dict of metadata to normalize (not mutated).

        Returns:
            A normalized copy of the input dict.
        """
        processed_metadata = metadata.copy()
        if "hierarchy" in processed_metadata and isinstance(processed_metadata["hierarchy"], list):
            processed_metadata["hierarchy"] = " > ".join(processed_metadata["hierarchy"])
        for key, value in processed_metadata.items():
            if value is None:
                processed_metadata[key] = ""
            elif isinstance(value, dict):
                processed_metadata[key] = json.dumps(value, ensure_ascii=False)
        return processed_metadata