import time
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility
# from sentence_transformers import SentenceTransformer
import numpy as np
from typing import List, Dict, Any, Optional
import json

from foundation.infrastructure.config.config import config_handler
from foundation.observability.logger.loggering import server_logger as logger
from foundation.database.base.vector.base_vector import BaseVectorDB
from foundation.ai.models.base_online_platform import BaseApiPlatform


class MilvusVectorManager(BaseVectorDB):
    def __init__(self, base_api_platform: BaseApiPlatform):
        """Initialize the Milvus connection."""
        self.base_api_platform = base_api_platform
        self.host = config_handler.get('milvus', 'MILVUS_HOST', 'localhost')
        self.port = int(config_handler.get('milvus', 'MILVUS_PORT', '19530'))
        self.milvus_db = config_handler.get('milvus', 'MILVUS_DB', 'default')
        self.user = config_handler.get('milvus', 'MILVUS_USER')
        self.password = config_handler.get('milvus', 'MILVUS_PASSWORD')

        # Text embedding model (can be swapped for another model)
        # self.model = SentenceTransformer('all-MiniLM-L6-v2')

        # Connect to Milvus
        self.connect()

    def connect(self):
        """Connect to the Milvus server."""
        try:
            connections.connect(
                alias="default",
                host=self.host,
                port=self.port,
                user=self.user,
                password=self.password,
                db_name=self.milvus_db
            )
            logger.info(f"Connected to Milvus at {self.host}:{self.port}")
        except Exception as e:
            logger.error(f"Failed to connect to Milvus: {e}")
            raise

    def create_collection(self, collection_name: str, dimension: int = 768,
                          description: str = "Vector collection for text embeddings"):
        """Create a vector collection."""
        try:
            # If the collection already exists, drop it and recreate it from scratch
            if utility.has_collection(collection_name):
                logger.info(f"Collection {collection_name} already exists")
                utility.drop_collection(collection_name)
                logger.info(f"Collection '{collection_name}' dropped successfully")

            # Define the fields
            fields = [
                FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
                FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
                FieldSchema(name="text_content", dtype=DataType.VARCHAR, max_length=65535),
                FieldSchema(name="metadata", dtype=DataType.JSON),
                FieldSchema(name="created_at", dtype=DataType.INT64)
            ]

            # Create the collection schema
            schema = CollectionSchema(
                fields=fields,
                description=description
            )

            # Create the collection
            collection = Collection(
                name=collection_name,
                schema=schema
            )

            # Create the index
            index_params = {
                "index_type": "IVF_FLAT",
                "metric_type": "COSINE",
                "params": {"nlist": 100}
            }
            collection.create_index(field_name="vector", index_params=index_params)
            logger.info(f"Collection {collection_name} created successfully!")
        except Exception as e:
            logger.error(f"Error creating collection: {e}")
            raise

    def add_document(self, param: Dict[str, Any], document: Dict[str, Any]):
        """Insert a single text and its vector."""
        try:
            collection_name = param.get('collection_name')
            text = document.get('content')
            metadata = document.get('metadata')
            collection = Collection(collection_name)
            created_at = None

            # Convert the text to a vector
            embedding = self.text_to_vector(text)
            # logger.info(f"Text converted to embedding: {isinstance(embedding, list)}, {len(embedding)}")

            # Prepare the data, one column per schema field
            data = [
                [embedding],                       # vector
                [text],                            # text_content
                [metadata or {}],                  # metadata
                [created_at or int(time.time())]   # created_at
            ]
            logger.info(f"Preparing to insert text_contents: {len(data[0])}, {len(data[1])}, {len(data[2])}, {len(data[3])}")

            # Insert the data
            insert_result = collection.insert(data)
            collection.flush()  # make sure the data is persisted
            logger.info(f"Text inserted with ID: {insert_result.primary_keys[0]}")
            return insert_result.primary_keys[0]
        except Exception as e:
            logger.error(f"Error inserting text: {e}")
            return None
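    # NOTE: `text_to_vector` is called throughout this class but is not defined
    # in this file; it presumably comes from BaseVectorDB or wraps the injected
    # BaseApiPlatform. A minimal sketch of the assumed shape, if it had to be
    # supplied here (`get_embedding` is a hypothetical method name, not a
    # confirmed BaseApiPlatform API):
    #
    # def text_to_vector(self, text: str) -> List[float]:
    #     """Embed `text` into a dense vector via the configured API platform."""
    #     return self.base_api_platform.get_embedding(text)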
    def add_batch_documents(self, param: Dict[str, Any], documents: List[Dict[str, Any]]):
        """
        Insert texts in batch.
        documents: [{'content': '...', 'metadata': {...}}, ...]
        """
        try:
            collection_name = param.get('collection_name')
            collection = Collection(collection_name)

            text_contents = []
            embeddings = []
            metadatas = []
            timestamps = []

            for item in documents:
                text = item['content']
                metadata = item.get('metadata', {})

                # Convert each text to a vector
                embedding = self.text_to_vector(text)

                text_contents.append(text)
                embeddings.append(embedding)
                metadatas.append(metadata)
                timestamps.append(int(time.time()))

            # Prepare the batch data, one column per schema field
            data = [embeddings, text_contents, metadatas, timestamps]
            # logger.info(f"Preparing to insert text_contents: {len(text_contents)}, {len(embeddings)}, {len(metadatas)}, {len(timestamps)}")

            # Batch insert
            insert_result = collection.insert(data)
            collection.flush()  # make sure the data is persisted
            logger.info(f"Batch inserted {len(text_contents)} records, IDs: {insert_result.primary_keys}")
            return insert_result.primary_keys
        except Exception as e:
            logger.error(f"Error batch inserting: {e}")
            return None

    def similarity_search(self, param: Dict[str, Any], query_text: str,
                          min_score=0.5, top_k=5, filters: Dict[str, Any] = None):
        """Search for texts similar to the query."""
        try:
            collection_name = param.get('collection_name')
            collection = Collection(collection_name)

            # Load the collection into memory (if not already loaded)
            collection.load()

            # Convert the query text to a vector
            query_embedding = self.text_to_vector(query_text)

            # Search parameters
            search_params = {
                "metric_type": "COSINE",
                "params": {"nprobe": 10}
            }

            # Build the filter expression
            filter_expr = self._create_filter(filters)

            # Run the search
            results = collection.search(
                data=[query_embedding],
                anns_field="vector",
                param=search_params,
                limit=top_k,
                expr=filter_expr,
                output_fields=["text_content", "metadata"]
            )

            # Format the results. With the COSINE metric, Milvus reports the
            # cosine similarity itself in `hit.distance` (higher = more similar),
            # so it is used directly as the similarity score.
            formatted_results = []
            for hits in results:
                for hit in hits:
                    similarity = hit.distance
                    if similarity < min_score:
                        continue
                    formatted_results.append({
                        'id': hit.id,
                        'text_content': hit.entity.get('text_content'),
                        'metadata': hit.entity.get('metadata'),
                        'distance': hit.distance,
                        'similarity': similarity
                    })
            return formatted_results
        except Exception as e:
            logger.error(f"Error searching: {e}")
            return []

    def retriever(self, param: Dict[str, Any], query_text: str,
                  top_k: int = 5, filters: Dict[str, Any] = None):
        """Similarity search with filter conditions."""
        try:
            collection_name = param.get('collection_name')
            collection = Collection(collection_name)
            collection.load()

            query_embedding = self.text_to_vector(query_text)

            # Build the filter expression
            filter_expr = self._create_filter(filters)

            search_params = {
                "metric_type": "COSINE",
                "params": {"nprobe": 10}
            }

            results = collection.search(
                data=[query_embedding],
                anns_field="vector",
                param=search_params,
                limit=top_k,
                expr=filter_expr,
                output_fields=["text_content", "metadata"]
            )

            # With the COSINE metric, `hit.distance` is already the similarity
            formatted_results = []
            for hits in results:
                for hit in hits:
                    formatted_results.append({
                        'id': hit.id,
                        'text_content': hit.entity.get('text_content'),
                        'metadata': hit.entity.get('metadata'),
                        'distance': hit.distance,
                        'similarity': hit.distance
                    })
            return formatted_results
        except Exception as e:
            logger.error(f"Error searching with filter: {e}")
            return []

    def _create_filter(self, filters: Dict[str, Any]) -> Optional[str]:
        """Build a Milvus filter expression from a metadata filter dict."""
        if not filters:
            return None
        conditions = []
        for key, value in filters.items():
            if isinstance(value, str):
                conditions.append(f'metadata["{key}"] == "{value}"')
            elif isinstance(value, (int, float)):
                conditions.append(f'metadata["{key}"] == {value}')
            else:
                conditions.append(f'metadata["{key}"] == "{json.dumps(value)}"')
        return " and ".join(conditions)
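    # Example of the expressions _create_filter produces (illustrative):
    #   {'category': 'AI', 'year': 2024}
    #   -> 'metadata["category"] == "AI" and metadata["year"] == 2024'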
    def db_test(self, query_text):
        # Initialize the client (requires the SILICONFLOW_API_KEY environment variable)
        from foundation.models.silicon_flow import SiliconFlowAPI
        client = SiliconFlowAPI()

        # Initialize the Milvus manager
        milvus_manager = MilvusVectorManager(base_api_platform=client)

        # Create the collection
        collection_name = 'text_embeddings'
        milvus_manager.create_collection(collection_name, dimension=768)
        param = {"collection_name": collection_name}

        # Insert a single text
        sample_text = "This is a document about artificial intelligence."
        milvus_manager.add_document(
            param,
            {"content": sample_text, "metadata": {'category': 'AI', 'source': 'example'}}
        )

        # Batch insert texts
        sample_texts = [
            {
                'content': 'Machine learning is an important branch of artificial intelligence.',
                'metadata': {'category': 'ML', 'author': 'John'}
            },
            {
                'content': 'Deep learning has achieved remarkable results in image recognition.',
                'metadata': {'category': 'Deep Learning', 'author': 'Jane'}
            },
            {
                'content': 'Natural language processing is widely used in chatbots.',
                'metadata': {'category': 'NLP', 'author': 'Bob'}
            },
            {
                'content': 'AI is developing rapidly but needs more computing resources.',
                'metadata': {'category': 'AI', 'author': 'Bob'}
            }
        ]
        milvus_manager.add_batch_documents(param=param, documents=sample_texts)

        # Search for similar texts
        query = query_text
        similar_docs = milvus_manager.similarity_search(param, query, top_k=5)
        logger.info(f"Similar documents found-{len(similar_docs)}:")
        for doc in similar_docs:
            logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['similarity']:.3f}")

        logger.info(f"{'=' * 20}")

        # Search with filter conditions
        filtered_docs = milvus_manager.retriever(
            param,
            query,
            top_k=5,
            filters={'category': 'AI'}
        )
        logger.info(f"\nFiltered similar documents-{len(filtered_docs)}:")
        for doc in filtered_docs:
            logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['similarity']:.3f}")
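

# A minimal usage sketch (hypothetical entry point, not part of the original
# module). It assumes a Milvus server is reachable at the configured host/port
# and that SILICONFLOW_API_KEY is set for SiliconFlowAPI.
if __name__ == "__main__":
    from foundation.models.silicon_flow import SiliconFlowAPI

    bootstrap_client = SiliconFlowAPI()
    manager = MilvusVectorManager(base_api_platform=bootstrap_client)
    manager.db_test("technologies related to artificial intelligence")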