"""Milvus-backed vector store manager: collection management, document
insertion, and similarity search using embeddings from an API platform."""
import json
import time
from typing import Any, Dict, List, Optional

import numpy as np
from pymilvus import (
    Collection,
    CollectionSchema,
    DataType,
    FieldSchema,
    connections,
    utility,
)
from sentence_transformers import SentenceTransformer

from foundation.base.config import config_handler
from foundation.logger.loggering import server_logger as logger
from foundation.models.base_online_platform import BaseApiPlatform
from foundation.rag.vector.base_vector import BaseVectorDB
  11. class MilvusVectorManager(BaseVectorDB):
  12. def __init__(self, base_api_platform :BaseApiPlatform):
  13. """
  14. 初始化 Milvus 连接
  15. """
  16. self.base_api_platform = base_api_platform
  17. self.host = config_handler.get('milvus', 'MILVUS_HOST', 'localhost')
  18. self.port = int(config_handler.get('milvus', 'MILVUS_PORT', '19530'))
  19. self.user = config_handler.get('milvus', 'MILVUS_USER')
  20. self.password = config_handler.get('milvus', 'MILVUS_PASSWORD')
  21. # 初始化文本向量化模型
  22. #self.model = SentenceTransformer('all-MiniLM-L6-v2') # 可以替换为其他模型
  23. # 连接到 Milvus
  24. self.connect()
  25. def connect(self):
  26. """连接到 Milvus 服务器"""
  27. try:
  28. connections.connect(
  29. alias="default",
  30. host=self.host,
  31. port=self.port,
  32. user=self.user,
  33. password=self.password
  34. )
  35. logger.info(f"Connected to Milvus at {self.host}:{self.port}")
  36. except Exception as e:
  37. logger.error(f"Failed to connect to Milvus: {e}")
  38. raise
  39. def create_collection(self, collection_name: str, dimension: int = 768,
  40. description: str = "Vector collection for text embeddings"):
  41. """
  42. 创建向量集合
  43. """
  44. try:
  45. # 检查集合是否已存在
  46. if utility.has_collection(collection_name):
  47. logger.info(f"Collection {collection_name} already exists")
  48. return
  49. # 定义字段
  50. fields = [
  51. FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
  52. FieldSchema(name="text_content", dtype=DataType.VARCHAR, max_length=65535),
  53. FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dimension),
  54. FieldSchema(name="metadata", dtype=DataType.JSON),
  55. FieldSchema(name="created_at", dtype=DataType.INT64)
  56. ]
  57. # 创建集合模式
  58. schema = CollectionSchema(
  59. fields=fields,
  60. description=description
  61. )
  62. # 创建集合
  63. collection = Collection(
  64. name=collection_name,
  65. schema=schema
  66. )
  67. # 创建索引
  68. index_params = {
  69. "index_type": "IVF_FLAT",
  70. "metric_type": "COSINE",
  71. "params": {"nlist": 100}
  72. }
  73. collection.create_index(field_name="embedding", index_params=index_params)
  74. logger.info(f"Collection {collection_name} created successfully!")
  75. except Exception as e:
  76. logger.error(f"Error creating collection: {e}")
  77. raise
  78. def add_document(self , param: Dict[str, Any] , document: Dict[str, Any]):
  79. """
  80. 插入单个文本及其向量
  81. """
  82. try:
  83. collection_name = param.get('collection_name')
  84. text = document.get('content')
  85. metadata = document.get('metadata')
  86. collection = Collection(collection_name)
  87. created_at = None
  88. # 转换文本为向量
  89. embedding = self.text_to_vector(text)
  90. # 准备数据
  91. data = [
  92. [text], # text_content
  93. [embedding], # embedding
  94. [metadata or {}], # metadata
  95. [created_at or int(time.time())] # created_at
  96. ]
  97. # 插入数据
  98. insert_result = collection.insert(data)
  99. collection.flush() # 确保数据被写入
  100. logger.info(f"Text inserted with ID: {insert_result.primary_keys[0]}")
  101. return insert_result.primary_keys[0]
  102. except Exception as e:
  103. logger.error(f"Error inserting text: {e}")
  104. return None
  105. def add_batch_documents(self , param: Dict[str, Any] , documents: List[Dict[str, Any]]):
  106. """
  107. 批量插入文本
  108. texts: [{'text': '...', 'metadata': {...}}, ...]
  109. """
  110. try:
  111. collection_name = param.get('collection_name')
  112. collection = Collection(collection_name)
  113. text_contents = []
  114. embeddings = []
  115. metadatas = []
  116. timestamps = []
  117. for item in documents:
  118. text = item['content']
  119. metadata = item.get('metadata', {})
  120. # 转换文本为向量
  121. embedding = self.text_to_vector(text)
  122. text_contents.append(text)
  123. embeddings.append(embedding)
  124. metadatas.append(metadata)
  125. timestamps.append(int(time.time()))
  126. # 准备批量数据
  127. data = [text_contents, embeddings, metadatas, timestamps]
  128. # 批量插入
  129. insert_result = collection.insert(data)
  130. collection.flush() # 确保数据被写入
  131. logger.info(f"Batch inserted {len(text_contents)} records, IDs: {insert_result.primary_keys}")
  132. return insert_result.primary_keys
  133. except Exception as e:
  134. logger.error(f"Error batch inserting: {e}")
  135. return None
  136. def similarity_search(self, param: Dict[str, Any], query_text: str , min_score=0.5 ,
  137. top_k=5, filters: Dict[str, Any] = None):
  138. """
  139. 搜索相似文本
  140. """
  141. try:
  142. collection_name = param.get('collection_name')
  143. collection = Collection(collection_name)
  144. # 加载集合到内存(如果还没有加载)
  145. collection.load()
  146. # 转换查询文本为向量
  147. query_embedding = self.text_to_vector(query_text)
  148. # 搜索参数
  149. search_params = {
  150. "metric_type": "COSINE",
  151. "params": {"nprobe": 10}
  152. }
  153. # 构建过滤表达式
  154. filter_expr = self._create_filter(filters)
  155. # 执行搜索
  156. results = collection.search(
  157. data=[query_embedding],
  158. anns_field="embedding",
  159. param=search_params,
  160. limit=top_k,
  161. expr=filter_expr,
  162. output_fields=["text_content", "metadata"]
  163. )
  164. # 格式化结果
  165. formatted_results = []
  166. for hits in results:
  167. for hit in hits:
  168. formatted_results.append({
  169. 'id': hit.id,
  170. 'text_content': hit.entity.get('text_content'),
  171. 'metadata': hit.entity.get('metadata'),
  172. 'distance': hit.distance,
  173. 'similarity': 1 - hit.distance # 转换为相似度
  174. })
  175. return formatted_results
  176. except Exception as e:
  177. logger.error(f"Error searching: {e}")
  178. return []
  179. def retriever(self, param: Dict[str, Any], query_text: str,
  180. top_k: int = 5, filters: Dict[str, Any] = None):
  181. """
  182. 带过滤条件的相似搜索
  183. """
  184. try:
  185. collection_name = param.get('collection_name')
  186. collection = Collection(collection_name)
  187. collection.load()
  188. query_embedding = self.text_to_vector(query_text)
  189. # 构建过滤表达式
  190. filter_expr = self._create_filter(filters)
  191. search_params = {
  192. "metric_type": "COSINE",
  193. "params": {"nprobe": 10}
  194. }
  195. results = collection.search(
  196. data=[query_embedding],
  197. anns_field="embedding",
  198. param=search_params,
  199. limit=top_k,
  200. expr=filter_expr,
  201. output_fields=["text_content", "metadata"]
  202. )
  203. formatted_results = []
  204. for hits in results:
  205. for hit in hits:
  206. formatted_results.append({
  207. 'id': hit.id,
  208. 'text_content': hit.entity.get('text_content'),
  209. 'metadata': hit.entity.get('metadata'),
  210. 'distance': hit.distance,
  211. 'similarity': 1 - hit.distance
  212. })
  213. return formatted_results
  214. except Exception as e:
  215. logger.error(f"Error searching with filter: {e}")
  216. return []
  217. def _create_filter(self, filters: Dict[str, Any]) -> str:
  218. """
  219. 创建过滤条件
  220. """
  221. # 构建过滤表达式
  222. filter_expr = ""
  223. if filters:
  224. conditions = []
  225. for key, value in filters.items():
  226. if isinstance(value, str):
  227. conditions.append(f'metadata["{key}"] == "{value}"')
  228. elif isinstance(value, (int, float)):
  229. conditions.append(f'metadata["{key}"] == {value}')
  230. else:
  231. conditions.append(f'metadata["{key}"] == "{json.dumps(value)}"')
  232. filter_expr = " and ".join(conditions)
  233. return filter_expr
  234. def db_test(self):
  235. import time
  236. # 初始化客户端(需提前设置环境变量 SILICONFLOW_API_KEY)
  237. from foundation.models.silicon_flow import SiliconFlowAPI
  238. client = SiliconFlowAPI()
  239. # 初始化 Milvus 管理器
  240. milvus_manager = MilvusVectorManager(base_api_platform=client)
  241. # 创建集合
  242. collection_name = 'text_embeddings'
  243. milvus_manager.create_collection(collection_name, dimension=384)
  244. # 插入单个文本
  245. sample_text = "这是一个关于人工智能的文档。"
  246. milvus_manager.insert_text(
  247. collection_name,
  248. sample_text,
  249. metadata={'category': 'AI', 'source': 'example'}
  250. )
  251. # 批量插入文本
  252. sample_texts = [
  253. {
  254. 'text': '机器学习是人工智能的一个重要分支。',
  255. 'metadata': {'category': 'ML', 'author': 'John'}
  256. },
  257. {
  258. 'text': '深度学习在图像识别领域取得了显著成果。',
  259. 'metadata': {'category': 'Deep Learning', 'author': 'Jane'}
  260. },
  261. {
  262. 'text': '自然语言处理技术在聊天机器人中得到广泛应用。',
  263. 'metadata': {'category': 'NLP', 'author': 'Bob'}
  264. }
  265. ]
  266. param = {"collection_name": collection_name}
  267. milvus_manager.add_batch_documents(param, sample_texts)
  268. # 搜索相似文本
  269. query = "人工智能相关的技术"
  270. similar_docs = milvus_manager.similarity_search(param, query, top_k=3)
  271. logger.info("Similar documents found:")
  272. for doc in similar_docs:
  273. logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['similarity']:.3f}")
  274. # 带过滤条件的搜索
  275. filtered_docs = milvus_manager.search_with_filter(
  276. collection_name,
  277. query,
  278. top_k=3,
  279. filters={'category': 'AI'}
  280. )
  281. logger.info("\nFiltered similar documents:")
  282. for doc in filtered_docs:
  283. logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['similarity']:.3f}")