milvus_vector.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556
  1. import time
  2. from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility, Function
  3. from pymilvus.client.types import FunctionType
  4. from pymilvus import AnnSearchRequest, RRFRanker, WeightedRanker
  5. # from sentence_transformers import SentenceTransformer
  6. import numpy as np
  7. from typing import List, Dict, Any, Optional
  8. import json
  9. # 导入 LangChain Milvus 混合搜索相关包
  10. from langchain_milvus import Milvus, BM25BuiltInFunction
  11. from langchain_core.documents import Document
  12. from langchain_core.embeddings import Embeddings
  13. from foundation.infrastructure.config.config import config_handler
  14. from foundation.database.base.vector.base_vector import BaseVectorDB
  15. # 延迟导入logger和model_handler以避免循环依赖
  16. logger = None
  17. model_handler = None
  def _get_logger():
      """Return the shared logger, importing it on first use.

      The import happens at call time rather than module import time to avoid
      a circular dependency. If the project logger cannot be imported, a
      standard-library logger named after this module is used instead. The
      result is stored in the module-level ``logger`` and reused afterwards.
      """
      global logger
      if logger is not None:
          return logger
      try:
          from foundation.observability.logger.loggering import review_logger
      except ImportError:
          # Project logger unavailable: fall back to a plain stdlib logger.
          import logging
          logger = logging.getLogger(__name__)
      else:
          logger = review_logger
      return logger
  def _get_model_handler():
      """Return the shared model handler, importing it on first use.

      The import happens at call time to avoid a circular dependency. When the
      import fails, ``None`` is returned and the import is attempted again on
      the next call.
      """
      global model_handler
      if model_handler is not None:
          return model_handler
      try:
          from foundation.ai.models.model_handler import model_handler as imported_handler
      except ImportError:
          # Import failed: leave the handler unset so callers receive None.
          model_handler = None
      else:
          model_handler = imported_handler
      return model_handler
  41. class MilvusVectorManager(BaseVectorDB):
  def __init__(self):
      """Read Milvus settings, load the embedding model, and open connections.

      Settings are read from the ``milvus`` section of ``config_handler``.
      After the pymilvus connection is established, the vector-store and
      collection caches are created and the commonly used vector store is
      built ahead of time.

      Raises:
          ImportError: If the model handler cannot be imported, so no
              embedding model is available.
      """
      read_setting = config_handler.get
      self.host = read_setting('milvus', 'MILVUS_HOST', 'localhost')
      self.port = int(read_setting('milvus', 'MILVUS_PORT', '19530'))
      self.milvus_db = read_setting('milvus', 'MILVUS_DB', 'default')
      self.user = read_setting('milvus', 'MILVUS_USER')
      self.password = read_setting('milvus', 'MILVUS_PASSWORD')

      # Text embedding model used for every text-to-vector conversion.
      handler = _get_model_handler()
      if not handler:
          raise ImportError("无法导入model_handler,无法初始化嵌入模型")
      self.emdmodel = handler.get_embedding_model()

      # Connection arguments reused by the LangChain Milvus vector stores.
      langchain_args = {
          "uri": f"http://{self.host}:{self.port}",
          "user": self.user,
          "db_name": self.milvus_db,
      }
      if self.password:
          langchain_args["password"] = self.password
      self.connection_args = langchain_args

      # Open the pymilvus connection.
      self.connect()

      # Build commonly used vector stores up front to avoid races at runtime.
      self._vectorstore_cache = {}
      self._collection_cache = {}  # cached pymilvus Collection objects
      self._create_common_connections()
  def _create_common_connections(self):
      """Build and cache vector stores for the commonly used collections.

      The only collection listed is the one configured as
      ``rag_collections.ENTITY_COLLECTION``. A failure for a collection is
      logged and does not stop the remaining ones. The ``pymilvus`` logger is
      raised to ERROR while the stores are built and restored afterwards.
      """
      import logging

      entity_collection = config_handler.get(
          'rag_collections', 'ENTITY_COLLECTION', 'first_bfp_collection_entity'
      )
      common_collections = [entity_collection]

      # Silence AsyncMilvusClient warnings while the stores are created.
      pymilvus_logger = logging.getLogger('pymilvus')
      saved_level = pymilvus_logger.level
      pymilvus_logger.setLevel(logging.ERROR)
      try:
          for collection_name in common_collections:
              try:
                  _get_logger().info(f"预创建vectorstore连接: {collection_name}")
                  store = Milvus(
                      embedding_function=self.emdmodel,
                      collection_name=collection_name,
                      connection_args=self.connection_args,
                      consistency_level="Strong",
                      builtin_function=BM25BuiltInFunction(),
                      vector_field=["dense", "sparse"],
                  )
                  self._vectorstore_cache[collection_name] = store
                  _get_logger().info(f"成功预创建连接: {collection_name}")
              except Exception as e:
                  _get_logger().error(f"预创建连接失败 {collection_name}: {e}")
      finally:
          logging.getLogger('pymilvus').setLevel(saved_level)
  def _get_collection(self, collection_name: str) -> Collection:
      """Return the cached ``Collection`` for a name, creating it on first use.

      Args:
          collection_name: Name of the Milvus collection.

      Returns:
          The ``Collection`` object stored in ``self._collection_cache``.
      """
      cache = self._collection_cache
      if collection_name in cache:
          return cache[collection_name]
      cache[collection_name] = Collection(collection_name)
      return cache[collection_name]
  102. def text_to_vector(self, text: str) -> List[float]:
  103. """
  104. 将文本转换为向量(重写基类方法,直接使用嵌入模型)
  105. """
  106. try:
  107. # 使用已有的嵌入模型
  108. embedding = self.emdmodel.embed_query(text)
  109. return embedding.tolist() if hasattr(embedding, 'tolist') else list(embedding)
  110. except Exception as e:
  111. _get_logger().error(f"Error converting text to vector: {e}")
  112. raise
  def connect(self):
      """Open the ``default`` pymilvus connection using the stored settings.

      Raises:
          Exception: Any connection error is logged and re-raised.
      """
      connection_settings = {
          "host": self.host,
          "port": self.port,
          "user": self.user,
          "password": self.password,
          "db_name": self.milvus_db,
      }
      try:
          connections.connect(alias="default", **connection_settings)
          _get_logger().info(f"Connected to Milvus at {self.host}:{self.port}")
      except Exception as e:
          _get_logger().error(f"Failed to connect to Milvus: {e}")
          raise
  def create_collection(self, collection_name: str, dimension: int = 768,
                        description: str = "Vector collection for text embeddings",
                        allow_drop: bool = False):
      """Create a vector collection and its index.

      Args:
          collection_name: Name of the collection.
          dimension: Dimension of the ``vector`` field.
          description: Description stored in the collection schema.
          allow_drop: Whether an existing collection may be dropped and
              recreated. With the default ``False`` an existing collection is
              left untouched and the method returns immediately.

      Raises:
          Exception: Any error during creation is logged and re-raised.
      """
      try:
          already_exists = utility.has_collection(collection_name)
          if already_exists and not allow_drop:
              _get_logger().info(f"Collection {collection_name} already exists, skip creation (allow_drop=False)")
              return
          if already_exists:
              _get_logger().warning(f"Collection {collection_name} already exists, dropping (allow_drop=True)")
              utility.drop_collection(collection_name)
              # The cached Collection object refers to the dropped collection.
              self._collection_cache.pop(collection_name, None)
              _get_logger().info(f"Collection '{collection_name}' dropped successfully")

          # Schema: auto-generated primary key, vector, raw text, JSON
          # metadata, and a creation timestamp.
          schema = CollectionSchema(
              fields=[
                  FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
                  FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
                  FieldSchema(name="text_content", dtype=DataType.VARCHAR, max_length=65535),
                  FieldSchema(name="metadata", dtype=DataType.JSON),
                  FieldSchema(name="created_at", dtype=DataType.INT64),
              ],
              description=description
          )
          new_collection = Collection(
              name=collection_name,
              schema=schema
          )

          # Index the vector field.
          new_collection.create_index(
              field_name="vector",
              index_params={
                  "index_type": "IVF_FLAT",
                  "metric_type": "COSINE",
                  "params": {"nlist": 100},
              },
          )
          _get_logger().info(f"Collection {collection_name} created successfully!")
      except Exception as e:
          _get_logger().error(f"Error creating collection: {e}")
          raise
  def add_document(self , param: Dict[str, Any] , document: Dict[str, Any]):
      """Embed one document and insert it into a collection.

      Args:
          param: Dict whose ``collection_name`` entry names the target
              collection.
          document: Dict with the text under ``content`` and optional
              ``metadata``.

      Returns:
          The primary key of the inserted row, or ``None`` when any step
          fails (the error is logged, not raised).
      """
      try:
          collection_name = param.get('collection_name')
          text = document.get('content')
          metadata = document.get('metadata')
          collection = self._get_collection(collection_name)

          embedding = self.text_to_vector(text)

          # One single-element column per non-auto field, in schema order:
          # vector, text_content, metadata, created_at.
          columns = [
              [embedding],
              [text],
              [metadata or {}],
              [int(time.time())],
          ]
          _get_logger().info(f"Preparing to insert text_contents:{len(columns[0])} ,{len(columns[1])},{len(columns[2])},{len(columns[3])}")

          insert_result = collection.insert(columns)
          collection.flush()  # make sure the row is persisted

          new_id = insert_result.primary_keys[0]
          _get_logger().info(f"Text inserted with ID: {new_id}")
          return new_id
      except Exception as e:
          _get_logger().error(f"Error inserting text: {e}")
          return None
  def add_batch_documents(self , param: Dict[str, Any] , documents: List[Dict[str, Any]]):
      """Embed several documents and insert them in one call.

      Args:
          param: Dict whose ``collection_name`` entry names the target
              collection.
          documents: Items shaped like
              ``{'content': '...', 'metadata': {...}}``; ``content`` is
              required, ``metadata`` defaults to ``{}``.

      Returns:
          The primary keys of the inserted rows, or ``None`` when any step
          fails (the error is logged, not raised).
      """
      try:
          collection = self._get_collection(param.get('collection_name'))

          vectors = []
          texts = []
          metadata_rows = []
          created_stamps = []
          for entry in documents:
              body = entry['content']
              entry_metadata = entry.get('metadata', {})
              vector = self.text_to_vector(body)

              texts.append(body)
              vectors.append(vector)
              metadata_rows.append(entry_metadata)
              created_stamps.append(int(time.time()))

          # Column order follows the schema: vector, text_content, metadata,
          # created_at.
          insert_result = collection.insert([vectors, texts, metadata_rows, created_stamps])
          collection.flush()  # make sure the rows are persisted

          _get_logger().info(f"Batch inserted {len(texts)} records, IDs: {insert_result.primary_keys}")
          return insert_result.primary_keys
      except Exception as e:
          _get_logger().error(f"Error batch inserting: {e}")
          return None
  def similarity_search(self, param: Dict[str, Any], query_text: str , min_score=0.5 ,
                        top_k=5, filters: Dict[str, Any] = None):
      """Search a collection for the texts closest to ``query_text``.

      Args:
          param: Dict whose ``collection_name`` entry names the collection.
          query_text: Text to embed and search with.
          min_score: Accepted but not used by this method.
          top_k: Maximum number of hits to return.
          filters: Optional metadata equality filters, turned into an
              expression by ``_create_filter``.

      Returns:
          A list of dicts with ``id``, ``text_content``, ``text`` (same value
          as ``text_content``), ``metadata``, ``distance`` and ``similarity``;
          an empty list when any step fails (the error is logged).
      """
      try:
          collection = self._get_collection(param.get('collection_name'))
          # Load the collection into memory (if it is not loaded yet).
          collection.load()

          query_vector = self.text_to_vector(query_text)
          expression = self._create_filter(filters)

          results = collection.search(
              data=[query_vector],
              anns_field="vector",
              param={"metric_type": "COSINE", "params": {"nprobe": 10}},
              limit=top_k,
              expr=expression,
              output_fields=["text_content", "metadata"]
          )

          formatted_results = []
          for hits in results:
              for hit in hits:
                  content = hit.entity.get('text_content')
                  formatted_results.append({
                      'id': hit.id,
                      'text_content': content,
                      'text': content,
                      'metadata': hit.entity.get('metadata'),
                      'distance': hit.distance,
                      # NOTE(review): with the COSINE metric Milvus may already
                      # report a similarity here - confirm that 1 - distance
                      # is the intended conversion.
                      'similarity': 1 - hit.distance
                  })
          return formatted_results
      except Exception as e:
          _get_logger().error(f"Error searching: {e}")
          return []
  def retriever(self, param: Dict[str, Any], query_text: str,
                top_k: int = 5, filters: Dict[str, Any] = None):
      """Run a similarity search with optional metadata filters.

      Args:
          param: Dict whose ``collection_name`` entry names the collection.
          query_text: Text to embed and search with.
          top_k: Maximum number of hits to return.
          filters: Optional metadata equality filters, turned into an
              expression by ``_create_filter``.

      Returns:
          A list of dicts with ``id``, ``text_content``, ``metadata``,
          ``distance`` and ``similarity``; an empty list when any step fails
          (the error is logged).
      """
      try:
          collection = self._get_collection(param.get('collection_name'))
          collection.load()

          query_vector = self.text_to_vector(query_text)
          expression = self._create_filter(filters)

          results = collection.search(
              data=[query_vector],
              anns_field="vector",
              param={"metric_type": "COSINE", "params": {"nprobe": 10}},
              limit=top_k,
              expr=expression,
              output_fields=["text_content", "metadata"]
          )

          # NOTE(review): with the COSINE metric Milvus may already report a
          # similarity as ``distance`` - confirm the 1 - distance conversion.
          return [
              {
                  'id': hit.id,
                  'text_content': hit.entity.get('text_content'),
                  'metadata': hit.entity.get('metadata'),
                  'distance': hit.distance,
                  'similarity': 1 - hit.distance
              }
              for hits in results
              for hit in hits
          ]
      except Exception as e:
          _get_logger().error(f"Error searching with filter: {e}")
          return []
  322. def _create_filter(self, filters: Dict[str, Any]) -> str:
  323. """
  324. 创建过滤条件
  325. """
  326. # 构建过滤表达式
  327. filter_expr = ""
  328. if filters:
  329. conditions = []
  330. for key, value in filters.items():
  331. if isinstance(value, str):
  332. conditions.append(f'metadata["{key}"] == "{value}"')
  333. elif isinstance(value, (int, float)):
  334. conditions.append(f'metadata["{key}"] == {value}')
  335. else:
  336. conditions.append(f'metadata["{key}"] == "{json.dumps(value)}"')
  337. filter_expr = " and ".join(conditions)
  338. return filter_expr
  def create_hybrid_collection(self, collection_name: str, documents: List[Dict[str, Any]],
                               drop_old: bool = False):
      """Create a collection that supports hybrid (dense + BM25 sparse) search.

      Args:
          collection_name: Name of the collection.
          documents: Items shaped like
              ``{'content': '...', 'metadata': {...}}``.
          drop_old: Forwarded to ``Milvus.from_documents``.

      Returns:
          The LangChain ``Milvus`` vector store, or ``None`` when creation
          fails (the error is logged, not raised).
      """
      try:
          # Connection arguments for the LangChain Milvus client.
          connection_args = {
              "uri": f"http://{self.host}:{self.port}",
              "user": self.user,
              "db_name": self.milvus_db,
          }
          if self.password:
              connection_args["password"] = self.password

          langchain_docs = []
          for raw_doc in documents:
              page_content = raw_doc.get('content', '')
              cleaned_metadata = self._process_metadata(raw_doc.get('metadata', {}))
              langchain_docs.append(
                  Document(page_content=page_content, metadata=cleaned_metadata)
              )

          # Build the hybrid vector store from the documents.
          vectorstore = Milvus.from_documents(
              documents=langchain_docs,
              embedding=self.emdmodel,
              builtin_function=BM25BuiltInFunction(),
              vector_field=["dense", "sparse"],
              connection_args=connection_args,
              collection_name=collection_name,
              consistency_level="Strong",
              drop_old=drop_old,
          )
          _get_logger().info(f"Created hybrid collection: {collection_name} with {len(documents)} documents")
          return vectorstore
      except Exception as e:
          _get_logger().error(f"Error creating hybrid collection: {e}")
          _get_logger().info("Falling back to traditional vector search")
          return None
  def hybrid_search(self, param: Dict[str, Any], query_text: str,
                    top_k: int , ranker_type: str = "weighted",
                    dense_weight: float = 0.7, sparse_weight: float = 0.3):
      """Run a hybrid (dense + sparse) search through the LangChain store.

      Args:
          param: Dict whose ``collection_name`` entry names the collection.
          query_text: Query text.
          top_k: Number of results to return.
          ranker_type: ``"weighted"`` selects the weighted ranker; any other
              value selects ``"rrf"`` with ``k=60``.
          dense_weight: Dense-vector weight (weighted ranker only).
          sparse_weight: Sparse-vector weight (weighted ranker only).

      Returns:
          List[Dict]: dicts with ``id``, ``text_content``, ``metadata``,
          ``distance`` and ``similarity``. When anything fails, the error is
          logged and the result of ``similarity_search`` is returned instead.
      """
      try:
          collection_name = param.get('collection_name')
          _get_logger().info(f"开始 hybrid_search, collection_name: {collection_name}")

          # Prefer the pre-built store to avoid races at runtime.
          if collection_name in self._vectorstore_cache:
              vectorstore = self._vectorstore_cache[collection_name]
          else:
              # Not cached: build a store now (degraded path) and cache it.
              _get_logger().warning(f"缓存中未找到连接: {collection_name},创建新连接")
              # Silence AsyncMilvusClient warnings while the store is created.
              import logging
              pymilvus_logger = logging.getLogger('pymilvus')
              saved_level = pymilvus_logger.level
              pymilvus_logger.setLevel(logging.ERROR)
              try:
                  vectorstore = Milvus(
                      embedding_function=self.emdmodel,
                      collection_name=collection_name,
                      connection_args=self.connection_args,
                      consistency_level="Strong",
                      builtin_function=BM25BuiltInFunction(),
                      vector_field=["dense", "sparse"]
                  )
                  self._vectorstore_cache[collection_name] = vectorstore
              finally:
                  logging.getLogger('pymilvus').setLevel(saved_level)

          _get_logger().info(f"混合召回topk: {top_k}")

          # Pick the ranker, then issue a single scored search.
          if ranker_type == "weighted":
              chosen_ranker = "weighted"
              chosen_params = {"weights": [dense_weight, sparse_weight]}
          else:
              chosen_ranker = "rrf"
              chosen_params = {"k": 60}
          results_with_scores = vectorstore.similarity_search_with_score(
              query=query_text,
              k=top_k,
              ranker_type=chosen_ranker,
              ranker_params=chosen_params
          )

          # Shape the results like the other search methods.
          formatted_results = []
          for doc, score in results_with_scores:
              # Treats a smaller score as more similar: 1 / (1 + score).
              # NOTE(review): ranker scores may be higher-is-better - confirm.
              if score >= 0:
                  similarity = 1 / (1 + score)
              else:
                  similarity = 0
              formatted_results.append({
                  'id': doc.metadata.get('pk', 0),
                  'text_content': doc.page_content,
                  'metadata': doc.metadata,
                  'distance': float(score),
                  'similarity': float(similarity)
              })
          return formatted_results
      except Exception as e:
          _get_logger().error(f"Error in hybrid search: {e}")
          # Fall back to the plain vector search.
          _get_logger().info("Falling back to traditional vector search")
          return self.similarity_search(param, query_text, top_k=top_k)
  456. def _process_metadata(self,metadata):
  457. """处理 metadata:将 list 类型的 hierarchy 转换为 Milvus 支持的 string 类型"""
  458. processed_metadata = metadata.copy()
  459. if "hierarchy" in processed_metadata and isinstance(processed_metadata["hierarchy"], list):
  460. processed_metadata["hierarchy"] = " > ".join(processed_metadata["hierarchy"])
  461. for key, value in processed_metadata.items():
  462. if value is None:
  463. processed_metadata[key] = ""
  464. elif isinstance(value, dict):
  465. processed_metadata[key] = json.dumps(value, ensure_ascii=False)
  466. return processed_metadata