milvus_vector.py
  1. import time
  2. from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility
  3. # from sentence_transformers import SentenceTransformer
  4. import numpy as np
  5. from typing import List, Dict, Any, Optional
  6. import json
  7. from foundation.base.config import config_handler
  8. from foundation.logger.loggering import server_logger as logger
  9. from foundation.rag.vector.base_vector import BaseVectorDB
  10. from foundation.models.base_online_platform import BaseApiPlatform
class MilvusVectorManager(BaseVectorDB):
    """Milvus-backed vector store: manages collections, inserts embedded
    documents, and runs similarity searches over them."""

    def __init__(self, base_api_platform: BaseApiPlatform):
        """
        Initialize the Milvus connection.

        Args:
            base_api_platform: API client used for text embedding
                (consumed by text_to_vector).
        """
        self.base_api_platform = base_api_platform
        # Connection settings come from the [milvus] config section;
        # host/port/db have defaults, user/password do not.
        self.host = config_handler.get('milvus', 'MILVUS_HOST', 'localhost')
        self.port = int(config_handler.get('milvus', 'MILVUS_PORT', '19530'))
        self.milvus_db = config_handler.get('milvus', 'MILVUS_DB', 'default')
        self.user = config_handler.get('milvus', 'MILVUS_USER')
        self.password = config_handler.get('milvus', 'MILVUS_PASSWORD')
        # Embedding is delegated to base_api_platform; a local
        # SentenceTransformer('all-MiniLM-L6-v2') was a commented-out alternative.
        # Connect to Milvus eagerly so failures surface at construction time.
        self.connect()
  26. def connect(self):
  27. """连接到 Milvus 服务器
  28. ,
  29. password=self.password
  30. alias="default",
  31. """
  32. try:
  33. connections.connect(
  34. alias="default",
  35. host=self.host,
  36. port=self.port,
  37. user=self.user,
  38. db_name="lq_db"
  39. )
  40. logger.info(f"Connected to Milvus at {self.host}:{self.port}")
  41. except Exception as e:
  42. logger.error(f"Failed to connect to Milvus: {e}")
  43. raise
    def create_collection(self, collection_name: str, dimension: int = 768,
                          description: str = "Vector collection for text embeddings"):
        """
        Create (or re-create) a vector collection and its index.

        WARNING: if a collection with this name already exists it is
        DROPPED and rebuilt, destroying all stored data. Callers must
        treat this method as destructive.

        Args:
            collection_name: name of the Milvus collection.
            dimension: dimensionality of the FLOAT_VECTOR field; must match
                the embeddings produced by text_to_vector.
            description: human-readable description stored on the schema.

        Raises:
            Exception: re-raised after logging if any Milvus call fails.
        """
        try:
            # Drop-and-recreate semantics: an existing collection is removed first.
            if utility.has_collection(collection_name):
                logger.info(f"Collection {collection_name} already exists")
                utility.drop_collection(collection_name)
                logger.info(f"Collection '{collection_name}' dropped successfully")
            # Schema: auto-generated INT64 primary key, embedding vector,
            # raw text, JSON metadata, and an epoch-seconds timestamp.
            fields = [
                FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
                FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dimension),
                FieldSchema(name="text_content", dtype=DataType.VARCHAR, max_length=65535),
                FieldSchema(name="metadata", dtype=DataType.JSON),
                FieldSchema(name="created_at", dtype=DataType.INT64)
            ]
            # Build the collection schema.
            schema = CollectionSchema(
                fields=fields,
                description=description
            )
            # Create the collection itself.
            collection = Collection(
                name=collection_name,
                schema=schema
            )
            # Index on the vector field; metric_type COSINE must stay in sync
            # with the search params used by similarity_search/retriever.
            index_params = {
                "index_type": "IVF_FLAT",
                "metric_type": "COSINE",
                "params": {"nlist": 100}
            }
            collection.create_index(field_name="embedding", index_params=index_params)
            logger.info(f"Collection {collection_name} created successfully!")
        except Exception as e:
            logger.error(f"Error creating collection: {e}")
            raise
  84. def add_document(self , param: Dict[str, Any] , document: Dict[str, Any]):
  85. """
  86. 插入单个文本及其向量
  87. """
  88. try:
  89. collection_name = param.get('collection_name')
  90. text = document.get('content')
  91. metadata = document.get('metadata')
  92. collection = Collection(collection_name)
  93. created_at = None
  94. # 转换文本为向量
  95. embedding = self.text_to_vector(text)
  96. #logger.info(f"Text converted to embedding:{isinstance(embedding, list)} ,{len(embedding)}")
  97. #logger.info(f"Text converted to embedding:{embedding}")
  98. # 准备数据
  99. data = [
  100. [embedding], # embedding
  101. [text], # text_content
  102. [metadata or {}], # metadata
  103. [created_at or int(time.time())] # created_at
  104. ]
  105. logger.info(f"Preparing to insert text_contents:{len(data[0])} ,{len(data[1])},{len(data[2])},{len(data[3])}")
  106. # 插入数据
  107. insert_result = collection.insert(data)
  108. collection.flush() # 确保数据被写入
  109. logger.info(f"Text inserted with ID: {insert_result.primary_keys[0]}")
  110. return insert_result.primary_keys[0]
  111. except Exception as e:
  112. logger.error(f"Error inserting text: {e}")
  113. return None
  114. def add_batch_documents(self , param: Dict[str, Any] , documents: List[Dict[str, Any]]):
  115. """
  116. 批量插入文本
  117. texts: [{'text': '...', 'metadata': {...}}, ...]
  118. """
  119. try:
  120. collection_name = param.get('collection_name')
  121. collection = Collection(collection_name)
  122. text_contents = []
  123. embeddings = []
  124. metadatas = []
  125. timestamps = []
  126. for item in documents:
  127. text = item['content']
  128. metadata = item.get('metadata', {})
  129. # 转换文本为向量
  130. embedding = self.text_to_vector(text)
  131. text_contents.append(text)
  132. embeddings.append(embedding)
  133. metadatas.append(metadata)
  134. timestamps.append(int(time.time()))
  135. # 准备批量数据
  136. data = [embeddings, text_contents, metadatas, timestamps]
  137. #logger.info(f"Preparing to insert text_contents:{len(text_contents)} ,{len(embeddings)},{len(metadatas)},{len(timestamps)}")
  138. # 批量插入
  139. insert_result = collection.insert(data)
  140. collection.flush() # 确保数据被写入
  141. logger.info(f"Batch inserted {len(text_contents)} records, IDs: {insert_result.primary_keys}")
  142. return insert_result.primary_keys
  143. except Exception as e:
  144. logger.error(f"Error batch inserting: {e}")
  145. return None
  146. def similarity_search(self, param: Dict[str, Any], query_text: str , min_score=0.5 ,
  147. top_k=5, filters: Dict[str, Any] = None):
  148. """
  149. 搜索相似文本
  150. """
  151. try:
  152. collection_name = param.get('collection_name')
  153. collection = Collection(collection_name)
  154. # 加载集合到内存(如果还没有加载)
  155. collection.load()
  156. # 转换查询文本为向量
  157. query_embedding = self.text_to_vector(query_text)
  158. # 搜索参数
  159. search_params = {
  160. "metric_type": "COSINE",
  161. "params": {"nprobe": 10}
  162. }
  163. # 构建过滤表达式
  164. filter_expr = self._create_filter(filters)
  165. # 执行搜索
  166. results = collection.search(
  167. data=[query_embedding],
  168. anns_field="vector",
  169. param=search_params,
  170. limit=top_k,
  171. expr=filter_expr,
  172. output_fields=["text_content", "metadata"]
  173. )
  174. # 格式化结果
  175. formatted_results = []
  176. for hits in results:
  177. for hit in hits:
  178. formatted_results.append({
  179. 'id': hit.id,
  180. 'text_content': hit.entity.get('text_content'),
  181. 'metadata': hit.entity.get('metadata'),
  182. 'distance': hit.distance,
  183. 'similarity': 1 - hit.distance # 转换为相似度
  184. })
  185. return formatted_results
  186. except Exception as e:
  187. logger.error(f"Error searching: {e}")
  188. return []
  189. def retriever(self, param: Dict[str, Any], query_text: str,
  190. top_k: int = 5, filters: Dict[str, Any] = None):
  191. """
  192. 带过滤条件的相似搜索
  193. """
  194. try:
  195. collection_name = param.get('collection_name')
  196. collection = Collection(collection_name)
  197. collection.load()
  198. query_embedding = self.text_to_vector(query_text)
  199. # 构建过滤表达式
  200. filter_expr = self._create_filter(filters)
  201. search_params = {
  202. "metric_type": "COSINE",
  203. "params": {"nprobe": 10}
  204. }
  205. results = collection.search(
  206. data=[query_embedding],
  207. anns_field="vector",
  208. param=search_params,
  209. limit=top_k,
  210. expr=filter_expr,
  211. output_fields=["text_content", "metadata"]
  212. )
  213. formatted_results = []
  214. for hits in results:
  215. for hit in hits:
  216. formatted_results.append({
  217. 'id': hit.id,
  218. 'text_content': hit.entity.get('text_content'),
  219. 'metadata': hit.entity.get('metadata'),
  220. 'distance': hit.distance,
  221. 'similarity': 1 - hit.distance
  222. })
  223. return formatted_results
  224. except Exception as e:
  225. logger.error(f"Error searching with filter: {e}")
  226. return []
  227. def _create_filter(self, filters: Dict[str, Any]) -> str:
  228. """
  229. 创建过滤条件
  230. """
  231. # 构建过滤表达式
  232. filter_expr = ""
  233. if filters:
  234. conditions = []
  235. for key, value in filters.items():
  236. if isinstance(value, str):
  237. conditions.append(f'metadata["{key}"] == "{value}"')
  238. elif isinstance(value, (int, float)):
  239. conditions.append(f'metadata["{key}"] == {value}')
  240. else:
  241. conditions.append(f'metadata["{key}"] == "{json.dumps(value)}"')
  242. filter_expr = " and ".join(conditions)
  243. return filter_expr
  244. def db_test(self , query_text):
  245. query = query_text
  246. import time
  247. # 初始化客户端(需提前设置环境变量 SILICONFLOW_API_KEY)
  248. from foundation.models.silicon_flow import SiliconFlowAPI
  249. client = SiliconFlowAPI()
  250. # 初始化 Milvus 管理器
  251. milvus_manager = MilvusVectorManager(base_api_platform=client)
  252. # 创建集合
  253. collection_name = 'text_embeddings'
  254. milvus_manager.create_collection(collection_name, dimension=768)
  255. param = {"collection_name": collection_name}
  256. # 插入单个文本
  257. sample_text = "这是一个关于人工智能的文档。"
  258. milvus_manager.add_document(
  259. param,
  260. {"content":sample_text , "metadata": {'category': 'AI', 'source': 'example'}}
  261. )
  262. # 批量插入文本
  263. sample_texts = [
  264. {
  265. 'content': '机器学习是人工智能的一个重要分支。',
  266. 'metadata': {'category': 'ML', 'author': 'John'}
  267. },
  268. {
  269. 'content': '深度学习在图像识别领域取得了显著成果。',
  270. 'metadata': {'category': 'Deep Learning', 'author': 'Jane'}
  271. },
  272. {
  273. 'content': '自然语言处理技术在聊天机器人中得到广泛应用。',
  274. 'metadata': {'category': 'NLP', 'author': 'Bob'}
  275. }
  276. ,
  277. {
  278. 'content': 'AI发展速度快,但需要更多的计算资源。',
  279. 'metadata': {'category': 'AI', 'author': 'Bob'}
  280. }
  281. ]
  282. milvus_manager.add_batch_documents(param=param, documents=sample_texts)
  283. # 搜索相似文本
  284. query = "人工智能相关的技术"
  285. similar_docs = milvus_manager.similarity_search(param, query, top_k=5)
  286. logger.info(f"Similar documents found-{len(similar_docs)}:")
  287. for doc in similar_docs:
  288. logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['similarity']:.3f}")
  289. logger.info(f"{'=' *20}")
  290. # 带过滤条件的搜索
  291. filtered_docs = milvus_manager.retriever(
  292. param,
  293. query,
  294. top_k=5,
  295. filters={'category': 'AI'}
  296. )
  297. logger.info(f"\nFiltered similar documents-{len(filtered_docs)}:")
  298. for doc in filtered_docs:
  299. logger.info(f"ID: {doc['id']}, Text: {doc['text_content'][:50]}..., Similarity: {doc['similarity']:.3f}")