|
|
@@ -0,0 +1,347 @@
|
|
|
+import time
|
|
|
+from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility
|
|
|
+from sentence_transformers import SentenceTransformer
|
|
|
+import numpy as np
|
|
|
+from typing import List, Dict, Any, Optional
|
|
|
+import json
|
|
|
+from foundation.base.config import config_handler
|
|
|
+from foundation.logger.loggering import server_logger as logger
|
|
|
+from foundation.rag.vector.base_vector import BaseVectorDB
|
|
|
+from foundation.models.base_online_platform import BaseApiPlatform
|
|
|
+
|
|
|
class MilvusVectorManager(BaseVectorDB):
    """Vector-store manager backed by a Milvus server.

    Wraps connection handling, collection lifecycle, document insertion and
    similarity search over text embeddings. Embeddings are produced by
    ``self.text_to_vector`` (not defined in this module — presumably provided
    by ``BaseVectorDB`` or built on ``base_api_platform``; TODO confirm).
    """

    def __init__(self, base_api_platform: BaseApiPlatform):
        """Read Milvus connection settings from config and connect.

        Args:
            base_api_platform: Online model platform injected for embedding calls.

        Raises:
            Exception: Propagated from ``connect()`` if the server is unreachable.
        """
        self.base_api_platform = base_api_platform

        # Connection settings come from the [milvus] section of the app config.
        self.host = config_handler.get('milvus', 'MILVUS_HOST', 'localhost')
        self.port = int(config_handler.get('milvus', 'MILVUS_PORT', '19530'))
        self.user = config_handler.get('milvus', 'MILVUS_USER')
        self.password = config_handler.get('milvus', 'MILVUS_PASSWORD')

        # Connect eagerly so misconfiguration fails fast at construction time.
        self.connect()

    def connect(self):
        """Open the default-alias Milvus connection; log and re-raise on failure."""
        try:
            connections.connect(
                alias="default",
                host=self.host,
                port=self.port,
                user=self.user,
                password=self.password,
            )
            logger.info(f"Connected to Milvus at {self.host}:{self.port}")
        except Exception as e:
            logger.error(f"Failed to connect to Milvus: {e}")
            raise

    def create_collection(self, collection_name: str, dimension: int = 768,
                          description: str = "Vector collection for text embeddings"):
        """Create a vector collection with a COSINE IVF_FLAT index.

        No-op (with an info log) if the collection already exists.

        Args:
            collection_name: Name of the collection to create.
            dimension: Dimensionality of the stored embedding vectors.
            description: Human-readable collection description.

        Raises:
            Exception: Propagates any Milvus error during creation.
        """
        try:
            if utility.has_collection(collection_name):
                logger.info(f"Collection {collection_name} already exists")
                return

            # Schema: auto-generated int64 primary key plus payload fields.
            fields = [
                FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
                FieldSchema(name="text_content", dtype=DataType.VARCHAR, max_length=65535),
                FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dimension),
                FieldSchema(name="metadata", dtype=DataType.JSON),
                FieldSchema(name="created_at", dtype=DataType.INT64),
            ]
            schema = CollectionSchema(fields=fields, description=description)
            collection = Collection(name=collection_name, schema=schema)

            # IVF_FLAT with COSINE similarity; nlist=100 is a modest default.
            index_params = {
                "index_type": "IVF_FLAT",
                "metric_type": "COSINE",
                "params": {"nlist": 100},
            }
            collection.create_index(field_name="embedding", index_params=index_params)
            logger.info(f"Collection {collection_name} created successfully!")

        except Exception as e:
            logger.error(f"Error creating collection: {e}")
            raise

    def add_document(self, param: Dict[str, Any], document: Dict[str, Any]):
        """Insert a single text document and its embedding.

        Args:
            param: Must contain 'collection_name'.
            document: Must contain 'content' (the text); may contain 'metadata'.

        Returns:
            The auto-generated primary key, or None on failure.
        """
        try:
            collection_name = param.get('collection_name')
            text = document.get('content')
            metadata = document.get('metadata')
            collection = Collection(collection_name)

            # Embed the text (helper presumably supplied by the base class).
            embedding = self.text_to_vector(text)

            # Column-ordered data matching the schema; 'id' is auto-generated.
            data = [
                [text],                # text_content
                [embedding],           # embedding
                [metadata or {}],      # metadata
                [int(time.time())],    # created_at (epoch seconds)
            ]

            insert_result = collection.insert(data)
            collection.flush()  # make the write durable/visible

            logger.info(f"Text inserted with ID: {insert_result.primary_keys[0]}")
            return insert_result.primary_keys[0]

        except Exception as e:
            logger.error(f"Error inserting text: {e}")
            return None

    def add_batch_documents(self, param: Dict[str, Any], documents: List[Dict[str, Any]]):
        """Insert multiple documents in one call.

        Args:
            param: Must contain 'collection_name'.
            documents: Items of shape {'content': str, 'metadata': dict}.

        Returns:
            The list of auto-generated primary keys, or None on failure.
        """
        try:
            collection_name = param.get('collection_name')
            collection = Collection(collection_name)

            text_contents = []
            embeddings = []
            metadatas = []
            timestamps = []

            for item in documents:
                text = item['content']
                metadata = item.get('metadata', {})

                text_contents.append(text)
                embeddings.append(self.text_to_vector(text))
                metadatas.append(metadata)
                timestamps.append(int(time.time()))

            # Column-ordered batch matching the schema field order.
            data = [text_contents, embeddings, metadatas, timestamps]

            insert_result = collection.insert(data)
            collection.flush()  # make the writes durable/visible

            logger.info(f"Batch inserted {len(text_contents)} records, IDs: {insert_result.primary_keys}")
            return insert_result.primary_keys

        except Exception as e:
            logger.error(f"Error batch inserting: {e}")
            return None

    def similarity_search(self, param: Dict[str, Any], query_text: str, min_score=0.5,
                          top_k=5, filters: Dict[str, Any] = None):
        """Search for texts similar to ``query_text``.

        Args:
            param: Must contain 'collection_name'.
            query_text: Text whose nearest neighbours are wanted.
            min_score: Minimum similarity a hit must reach to be returned.
                (Bug fix: this parameter was previously accepted but ignored.)
            top_k: Maximum number of hits to retrieve from Milvus.
            filters: Optional metadata equality filters (see ``_create_filter``).

        Returns:
            A list of dicts with id, text_content, metadata, distance and
            similarity; empty list on error.
        """
        try:
            collection_name = param.get('collection_name')
            collection = Collection(collection_name)

            # Ensure the collection is loaded into memory for search.
            collection.load()

            query_embedding = self.text_to_vector(query_text)

            search_params = {
                "metric_type": "COSINE",
                "params": {"nprobe": 10},
            }
            filter_expr = self._create_filter(filters)

            results = collection.search(
                data=[query_embedding],
                anns_field="embedding",
                param=search_params,
                limit=top_k,
                expr=filter_expr,
                output_fields=["text_content", "metadata"],
            )

            formatted_results = []
            for hits in results:
                for hit in hits:
                    # NOTE(review): with metric_type COSINE, recent Milvus
                    # versions report hit.distance as the cosine similarity
                    # (higher = closer); '1 - distance' may therefore be an
                    # inverted conversion — confirm against the deployed
                    # Milvus version before relying on these scores.
                    similarity = 1 - hit.distance
                    if similarity < min_score:
                        continue  # honour the min_score threshold
                    formatted_results.append({
                        'id': hit.id,
                        'text_content': hit.entity.get('text_content'),
                        'metadata': hit.entity.get('metadata'),
                        'distance': hit.distance,
                        'similarity': similarity,
                    })

            return formatted_results

        except Exception as e:
            logger.error(f"Error searching: {e}")
            return []

    def retriever(self, param: Dict[str, Any], query_text: str,
                  top_k: int = 5, filters: Dict[str, Any] = None):
        """Similarity search with optional metadata filters, no score cutoff.

        Delegates to ``similarity_search`` (the two methods were verbatim
        duplicates) with a threshold that admits every hit.

        Args:
            param: Must contain 'collection_name'.
            query_text: Text whose nearest neighbours are wanted.
            top_k: Maximum number of hits to return.
            filters: Optional metadata equality filters.

        Returns:
            Same shape as ``similarity_search``; empty list on error.
        """
        return self.similarity_search(
            param,
            query_text,
            min_score=float("-inf"),  # no threshold: return all top_k hits
            top_k=top_k,
            filters=filters,
        )

    def _create_filter(self, filters: Dict[str, Any]) -> str:
        """Build a Milvus boolean expression over JSON metadata fields.

        Args:
            filters: Mapping of metadata key -> required value. None/empty
                yields "" (Milvus treats an empty expr as "no filter").

        Returns:
            An expression like 'metadata["category"] == "AI" and ...'.
        """
        if not filters:
            return ""

        conditions = []
        for key, value in filters.items():
            if isinstance(value, str):
                # Escape embedded double quotes so the expression stays valid
                # (previously a value containing '"' broke the filter).
                escaped = value.replace('"', '\\"')
                conditions.append(f'metadata["{key}"] == "{escaped}"')
            elif isinstance(value, (int, float)):
                conditions.append(f'metadata["{key}"] == {value}')
            else:
                # Fall back to a JSON-serialized string comparison.
                escaped = json.dumps(value).replace('"', '\\"')
                conditions.append(f'metadata["{key}"] == "{escaped}"')

        return " and ".join(conditions)

    def db_test(self):
        """Manual smoke test: create a collection, insert docs, run searches.

        Requires a reachable Milvus server and SiliconFlow credentials
        (SILICONFLOW_API_KEY environment variable); not an automated test.
        """
        # Initialize the embedding client.
        from foundation.models.silicon_flow import SiliconFlowAPI
        client = SiliconFlowAPI()

        # Initialize the Milvus manager.
        milvus_manager = MilvusVectorManager(base_api_platform=client)

        # Create the collection.
        collection_name = 'text_embeddings'
        milvus_manager.create_collection(collection_name, dimension=384)

        param = {"collection_name": collection_name}

        # Insert a single document. Bug fix: the class exposes
        # add_document(param, document) — insert_text(...) never existed.
        milvus_manager.add_document(param, {
            'content': "这是一个关于人工智能的文档。",
            'metadata': {'category': 'AI', 'source': 'example'},
        })

        # Batch insert. Bug fix: add_batch_documents reads item['content'],
        # so the key must be 'content', not 'text'.
        sample_docs = [
            {
                'content': '机器学习是人工智能的一个重要分支。',
                'metadata': {'category': 'ML', 'author': 'John'},
            },
            {
                'content': '深度学习在图像识别领域取得了显著成果。',
                'metadata': {'category': 'Deep Learning', 'author': 'Jane'},
            },
            {
                'content': '自然语言处理技术在聊天机器人中得到广泛应用。',
                'metadata': {'category': 'NLP', 'author': 'Bob'},
            },
        ]
        milvus_manager.add_batch_documents(param, sample_docs)

        # Plain similarity search.
        query = "人工智能相关的技术"
        similar_docs = milvus_manager.similarity_search(param, query, top_k=3)

        logger.info("Similar documents found:")
        for doc in similar_docs:
            logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['similarity']:.3f}")

        # Filtered search. Bug fix: the filtered-search method is
        # retriever(param, ...) — search_with_filter(...) never existed.
        filtered_docs = milvus_manager.retriever(
            param,
            query,
            top_k=3,
            filters={'category': 'AI'},
        )

        logger.info("\nFiltered similar documents:")
        for doc in filtered_docs:
            logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['similarity']:.3f}")
|
|
|
+
|