| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367 |
- import time
- from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility
- # from sentence_transformers import SentenceTransformer
- import numpy as np
- from typing import List, Dict, Any, Optional
- import json
- from foundation.base.config import config_handler
- from foundation.logger.loggering import server_logger as logger
- from foundation.rag.vector.base_vector import BaseVectorDB
- from foundation.models.base_online_platform import BaseApiPlatform
class MilvusVectorManager(BaseVectorDB):
    """Vector store backed by Milvus; embeddings come from an online API platform."""

    def __init__(self, base_api_platform: BaseApiPlatform):
        """
        Read Milvus connection settings and open the connection.

        Args:
            base_api_platform: platform used to turn text into embedding vectors.
        """
        self.base_api_platform = base_api_platform

        # All connection settings live in the [milvus] section of the config.
        self.host = config_handler.get('milvus', 'MILVUS_HOST', 'localhost')
        self.port = int(config_handler.get('milvus', 'MILVUS_PORT', '19530'))
        self.milvus_db = config_handler.get('milvus', 'MILVUS_DB', 'default')
        self.user = config_handler.get('milvus', 'MILVUS_USER')
        self.password = config_handler.get('milvus', 'MILVUS_PASSWORD')

        # Connect eagerly; connect() raises if the server is unreachable.
        self.connect()
-
- def connect(self):
- """连接到 Milvus 服务器
- ,
- password=self.password
- alias="default",
- """
- try:
- connections.connect(
- alias="default",
- host=self.host,
- port=self.port,
- user=self.user,
- db_name="lq_db"
- )
- logger.info(f"Connected to Milvus at {self.host}:{self.port}")
- except Exception as e:
- logger.error(f"Failed to connect to Milvus: {e}")
- raise
-
- def create_collection(self, collection_name: str, dimension: int = 768,
- description: str = "Vector collection for text embeddings"):
- """
- 创建向量集合
- """
- try:
- # 检查集合是否已存在
- if utility.has_collection(collection_name):
- logger.info(f"Collection {collection_name} already exists")
- utility.drop_collection(collection_name)
- logger.info(f"Collection '{collection_name}' dropped successfully")
-
-
- # 定义字段
- fields = [
- FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
- FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dimension),
- FieldSchema(name="text_content", dtype=DataType.VARCHAR, max_length=65535),
- FieldSchema(name="metadata", dtype=DataType.JSON),
- FieldSchema(name="created_at", dtype=DataType.INT64)
- ]
-
- # 创建集合模式
- schema = CollectionSchema(
- fields=fields,
- description=description
- )
-
- # 创建集合
- collection = Collection(
- name=collection_name,
- schema=schema
- )
-
- # 创建索引
- index_params = {
- "index_type": "IVF_FLAT",
- "metric_type": "COSINE",
- "params": {"nlist": 100}
- }
-
- collection.create_index(field_name="embedding", index_params=index_params)
- logger.info(f"Collection {collection_name} created successfully!")
-
- except Exception as e:
- logger.error(f"Error creating collection: {e}")
- raise
-
-
-
- def add_document(self , param: Dict[str, Any] , document: Dict[str, Any]):
- """
- 插入单个文本及其向量
- """
- try:
- collection_name = param.get('collection_name')
- text = document.get('content')
- metadata = document.get('metadata')
- collection = Collection(collection_name)
- created_at = None
-
- # 转换文本为向量
- embedding = self.text_to_vector(text)
- #logger.info(f"Text converted to embedding:{isinstance(embedding, list)} ,{len(embedding)}")
- #logger.info(f"Text converted to embedding:{embedding}")
- # 准备数据
- data = [
- [embedding], # embedding
- [text], # text_content
- [metadata or {}], # metadata
- [created_at or int(time.time())] # created_at
- ]
- logger.info(f"Preparing to insert text_contents:{len(data[0])} ,{len(data[1])},{len(data[2])},{len(data[3])}")
-
- # 插入数据
- insert_result = collection.insert(data)
- collection.flush() # 确保数据被写入
-
- logger.info(f"Text inserted with ID: {insert_result.primary_keys[0]}")
- return insert_result.primary_keys[0]
-
- except Exception as e:
- logger.error(f"Error inserting text: {e}")
- return None
-
- def add_batch_documents(self , param: Dict[str, Any] , documents: List[Dict[str, Any]]):
- """
- 批量插入文本
- texts: [{'text': '...', 'metadata': {...}}, ...]
- """
- try:
- collection_name = param.get('collection_name')
- collection = Collection(collection_name)
-
- text_contents = []
- embeddings = []
- metadatas = []
- timestamps = []
-
- for item in documents:
- text = item['content']
- metadata = item.get('metadata', {})
-
- # 转换文本为向量
- embedding = self.text_to_vector(text)
-
- text_contents.append(text)
- embeddings.append(embedding)
- metadatas.append(metadata)
- timestamps.append(int(time.time()))
-
-
- # 准备批量数据
- data = [embeddings, text_contents, metadatas, timestamps]
- #logger.info(f"Preparing to insert text_contents:{len(text_contents)} ,{len(embeddings)},{len(metadatas)},{len(timestamps)}")
-
- # 批量插入
- insert_result = collection.insert(data)
- collection.flush() # 确保数据被写入
-
- logger.info(f"Batch inserted {len(text_contents)} records, IDs: {insert_result.primary_keys}")
- return insert_result.primary_keys
-
- except Exception as e:
- logger.error(f"Error batch inserting: {e}")
- return None
-
- def similarity_search(self, param: Dict[str, Any], query_text: str , min_score=0.5 ,
- top_k=5, filters: Dict[str, Any] = None):
- """
- 搜索相似文本
- """
- try:
- collection_name = param.get('collection_name')
- collection = Collection(collection_name)
-
- # 加载集合到内存(如果还没有加载)
- collection.load()
-
- # 转换查询文本为向量
- query_embedding = self.text_to_vector(query_text)
-
- # 搜索参数
- search_params = {
- "metric_type": "COSINE",
- "params": {"nprobe": 10}
- }
- # 构建过滤表达式
- filter_expr = self._create_filter(filters)
-
- # 执行搜索
- results = collection.search(
- data=[query_embedding],
- anns_field="vector",
- param=search_params,
- limit=top_k,
- expr=filter_expr,
- output_fields=["text_content", "metadata"]
- )
-
- # 格式化结果
- formatted_results = []
- for hits in results:
- for hit in hits:
- formatted_results.append({
- 'id': hit.id,
- 'text_content': hit.entity.get('text_content'),
- 'metadata': hit.entity.get('metadata'),
- 'distance': hit.distance,
- 'similarity': 1 - hit.distance # 转换为相似度
- })
-
- return formatted_results
-
- except Exception as e:
- logger.error(f"Error searching: {e}")
- return []
-
- def retriever(self, param: Dict[str, Any], query_text: str,
- top_k: int = 5, filters: Dict[str, Any] = None):
- """
- 带过滤条件的相似搜索
- """
- try:
- collection_name = param.get('collection_name')
- collection = Collection(collection_name)
- collection.load()
-
- query_embedding = self.text_to_vector(query_text)
-
- # 构建过滤表达式
- filter_expr = self._create_filter(filters)
-
- search_params = {
- "metric_type": "COSINE",
- "params": {"nprobe": 10}
- }
-
- results = collection.search(
- data=[query_embedding],
- anns_field="vector",
- param=search_params,
- limit=top_k,
- expr=filter_expr,
- output_fields=["text_content", "metadata"]
- )
-
- formatted_results = []
- for hits in results:
- for hit in hits:
- formatted_results.append({
- 'id': hit.id,
- 'text_content': hit.entity.get('text_content'),
- 'metadata': hit.entity.get('metadata'),
- 'distance': hit.distance,
- 'similarity': 1 - hit.distance
- })
-
- return formatted_results
-
- except Exception as e:
- logger.error(f"Error searching with filter: {e}")
- return []
-
-
- def _create_filter(self, filters: Dict[str, Any]) -> str:
- """
- 创建过滤条件
- """
- # 构建过滤表达式
- filter_expr = ""
- if filters:
- conditions = []
- for key, value in filters.items():
- if isinstance(value, str):
- conditions.append(f'metadata["{key}"] == "{value}"')
- elif isinstance(value, (int, float)):
- conditions.append(f'metadata["{key}"] == {value}')
- else:
- conditions.append(f'metadata["{key}"] == "{json.dumps(value)}"')
- filter_expr = " and ".join(conditions)
-
- return filter_expr
- def db_test(self , query_text):
- query = query_text
- import time
- # 初始化客户端(需提前设置环境变量 SILICONFLOW_API_KEY)
- from foundation.models.silicon_flow import SiliconFlowAPI
- client = SiliconFlowAPI()
- # 初始化 Milvus 管理器
- milvus_manager = MilvusVectorManager(base_api_platform=client)
-
- # 创建集合
- collection_name = 'text_embeddings'
- milvus_manager.create_collection(collection_name, dimension=768)
-
- param = {"collection_name": collection_name}
-
- # 插入单个文本
- sample_text = "这是一个关于人工智能的文档。"
- milvus_manager.add_document(
- param,
- {"content":sample_text , "metadata": {'category': 'AI', 'source': 'example'}}
- )
-
- # 批量插入文本
- sample_texts = [
- {
- 'content': '机器学习是人工智能的一个重要分支。',
- 'metadata': {'category': 'ML', 'author': 'John'}
- },
- {
- 'content': '深度学习在图像识别领域取得了显著成果。',
- 'metadata': {'category': 'Deep Learning', 'author': 'Jane'}
- },
- {
- 'content': '自然语言处理技术在聊天机器人中得到广泛应用。',
- 'metadata': {'category': 'NLP', 'author': 'Bob'}
- }
- ,
- {
- 'content': 'AI发展速度快,但需要更多的计算资源。',
- 'metadata': {'category': 'AI', 'author': 'Bob'}
- }
- ]
-
-
- milvus_manager.add_batch_documents(param=param, documents=sample_texts)
-
- # 搜索相似文本
- query = "人工智能相关的技术"
- similar_docs = milvus_manager.similarity_search(param, query, top_k=5)
-
- logger.info(f"Similar documents found-{len(similar_docs)}:")
- for doc in similar_docs:
- logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['similarity']:.3f}")
- logger.info(f"{'=' *20}")
- # 带过滤条件的搜索
- filtered_docs = milvus_manager.retriever(
- param,
- query,
- top_k=5,
- filters={'category': 'AI'}
- )
-
- logger.info(f"\nFiltered similar documents-{len(filtered_docs)}:")
- for doc in filtered_docs:
- logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['similarity']:.3f}")
|