milvus_vector.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557
  1. import time
  2. from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility, Function
  3. from pymilvus.client.types import FunctionType
  4. from pymilvus import AnnSearchRequest, RRFRanker, WeightedRanker
  5. # from sentence_transformers import SentenceTransformer
  6. import numpy as np
  7. from typing import List, Dict, Any, Optional
  8. import json
  9. # 导入 LangChain Milvus 混合搜索相关包
  10. from langchain_milvus import Milvus, BM25BuiltInFunction
  11. from langchain_core.documents import Document
  12. from langchain_core.embeddings import Embeddings
  13. from foundation.infrastructure.config.config import config_handler
  14. from foundation.database.base.vector.base_vector import BaseVectorDB
  15. # 延迟导入logger和model_handler以避免循环依赖
  16. logger = None
  17. model_handler = None
  18. def _get_logger():
  19. """延迟导入logger以避免循环依赖"""
  20. global logger
  21. if logger is None:
  22. try:
  23. from foundation.observability.logger.loggering import server_logger
  24. logger = server_logger
  25. except ImportError:
  26. # 如果导入失败,创建一个简单的logger替代品
  27. import logging
  28. logger = logging.getLogger(__name__)
  29. return logger
  30. def _get_model_handler():
  31. """延迟导入model_handler以避免循环依赖"""
  32. global model_handler
  33. if model_handler is None:
  34. try:
  35. from foundation.ai.models.model_handler import model_handler as mh
  36. model_handler = mh
  37. except ImportError:
  38. # 如果导入失败,返回None
  39. model_handler = None
  40. return model_handler
class MilvusVectorManager(BaseVectorDB):
    """Vector store backed by Milvus, with LangChain hybrid (dense + BM25) search support."""

    def __init__(self):
        """
        Initialize the Milvus connection.

        Reads host/port/db/user/password from the ``milvus`` config section,
        resolves the embedding model via the lazily imported model handler,
        connects to the server, and pre-creates commonly used vectorstore
        connections to avoid contention at request time.

        Raises:
            ImportError: if the model handler (and thus the embedding model)
                cannot be imported.
        """
        self.host = config_handler.get('milvus', 'MILVUS_HOST', 'localhost')
        self.port = int(config_handler.get('milvus', 'MILVUS_PORT', '19530'))
        # NOTE(review): milvus_db is read from config but the connections below
        # hard-code db_name="lq_db" — confirm which one is intended.
        self.milvus_db = config_handler.get('milvus', 'MILVUS_DB', 'default')
        self.user = config_handler.get('milvus', 'MILVUS_USER')
        self.password = config_handler.get('milvus', 'MILVUS_PASSWORD')
        # Initialize the text embedding model via the lazy handler.
        mh = _get_model_handler()
        if mh:
            self.emdmodel = mh.get_embedding_model()
        else:
            raise ImportError("无法导入model_handler,无法初始化嵌入模型")
        # Cache connection parameters reused by LangChain Milvus vectorstores.
        self.connection_args = {
            "uri": f"http://{self.host}:{self.port}",
            "user": self.user,
            "db_name": "lq_db"
        }
        if self.password:
            self.connection_args["password"] = self.password
        # Connect to Milvus.
        self.connect()
        # Pre-create commonly used vectorstore connections to avoid runtime
        # contention; keyed by collection name.
        self._vectorstore_cache = {}
        self._create_common_connections()
  70. def _create_common_connections(self):
  71. """预创建常用的vectorstore连接"""
  72. common_collections = [
  73. "first_bfp_collection_entity",
  74. "first_bfp_collection_test"
  75. ]
  76. # 抑制 AsyncMilvusClient 的警告日志
  77. import logging
  78. original_level = logging.getLogger('pymilvus').level
  79. logging.getLogger('pymilvus').setLevel(logging.ERROR)
  80. try:
  81. for collection_name in common_collections:
  82. try:
  83. _get_logger().info(f"预创建vectorstore连接: {collection_name}")
  84. self._vectorstore_cache[collection_name] = Milvus(
  85. embedding_function=self.emdmodel,
  86. collection_name=collection_name,
  87. connection_args=self.connection_args,
  88. consistency_level="Strong",
  89. builtin_function=BM25BuiltInFunction(),
  90. vector_field=["dense", "sparse"]
  91. )
  92. _get_logger().info(f"成功预创建连接: {collection_name}")
  93. except Exception as e:
  94. _get_logger().error(f"预创建连接失败 {collection_name}: {e}")
  95. finally:
  96. logging.getLogger('pymilvus').setLevel(original_level)
  97. def text_to_vector(self, text: str) -> List[float]:
  98. """
  99. 将文本转换为向量(重写基类方法,直接使用嵌入模型)
  100. """
  101. try:
  102. # 使用已有的嵌入模型
  103. embedding = self.emdmodel.embed_query(text)
  104. return embedding.tolist() if hasattr(embedding, 'tolist') else list(embedding)
  105. except Exception as e:
  106. _get_logger().error(f"Error converting text to vector: {e}")
  107. raise
    def connect(self):
        """Connect to the Milvus server using the configured host/port/user.

        NOTE(review): the password is not passed to ``connections.connect`` and
        ``db_name`` is hard-coded to "lq_db" (ignoring ``self.milvus_db``) —
        confirm both are intended.

        Raises:
            Exception: re-raised if the connection attempt fails.
        """
        try:
            connections.connect(
                alias="default",
                host=self.host,
                port=self.port,
                user=self.user,
                db_name="lq_db"
            )
            _get_logger().info(f"Connected to Milvus at {self.host}:{self.port}")
        except Exception as e:
            _get_logger().error(f"Failed to connect to Milvus: {e}")
            raise
  126. def create_collection(self, collection_name: str, dimension: int = 768,
  127. description: str = "Vector collection for text embeddings"):
  128. """
  129. 创建向量集合
  130. """
  131. try:
  132. # 检查集合是否已存在
  133. if utility.has_collection(collection_name):
  134. _get_logger().info(f"Collection {collection_name} already exists")
  135. utility.drop_collection(collection_name)
  136. _get_logger().info(f"Collection '{collection_name}' dropped successfully")
  137. # 定义字段
  138. fields = [
  139. FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
  140. FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
  141. FieldSchema(name="text_content", dtype=DataType.VARCHAR, max_length=65535),
  142. FieldSchema(name="metadata", dtype=DataType.JSON),
  143. FieldSchema(name="created_at", dtype=DataType.INT64)
  144. ]
  145. # 创建集合模式
  146. schema = CollectionSchema(
  147. fields=fields,
  148. description=description
  149. )
  150. # 创建集合
  151. collection = Collection(
  152. name=collection_name,
  153. schema=schema
  154. )
  155. # 创建索引
  156. index_params = {
  157. "index_type": "IVF_FLAT",
  158. "metric_type": "COSINE",
  159. "params": {"nlist": 100}
  160. }
  161. collection.create_index(field_name="vector", index_params=index_params)
  162. _get_logger().info(f"Collection {collection_name} created successfully!")
  163. except Exception as e:
  164. _get_logger().error(f"Error creating collection: {e}")
  165. raise
  166. def add_document(self , param: Dict[str, Any] , document: Dict[str, Any]):
  167. """
  168. 插入单个文本及其向量
  169. """
  170. try:
  171. collection_name = param.get('collection_name')
  172. text = document.get('content')
  173. metadata = document.get('metadata')
  174. collection = Collection(collection_name)
  175. created_at = None
  176. # 转换文本为向量
  177. embedding = self.text_to_vector(text)
  178. #_get_logger().info(f"Text converted to embedding:{isinstance(embedding, list)} ,{len(embedding)}")
  179. #_get_logger().info(f"Text converted to embedding:{embedding}")
  180. # 准备数据
  181. data = [
  182. [embedding], # embedding
  183. [text], # text_content
  184. [metadata or {}], # metadata
  185. [created_at or int(time.time())] # created_at
  186. ]
  187. _get_logger().info(f"Preparing to insert text_contents:{len(data[0])} ,{len(data[1])},{len(data[2])},{len(data[3])}")
  188. # 插入数据
  189. insert_result = collection.insert(data)
  190. collection.flush() # 确保数据被写入
  191. _get_logger().info(f"Text inserted with ID: {insert_result.primary_keys[0]}")
  192. return insert_result.primary_keys[0]
  193. except Exception as e:
  194. _get_logger().error(f"Error inserting text: {e}")
  195. return None
  196. def add_batch_documents(self , param: Dict[str, Any] , documents: List[Dict[str, Any]]):
  197. """
  198. 批量插入文本
  199. texts: [{'text': '...', 'metadata': {...}}, ...]
  200. """
  201. try:
  202. collection_name = param.get('collection_name')
  203. collection = Collection(collection_name)
  204. text_contents = []
  205. embeddings = []
  206. metadatas = []
  207. timestamps = []
  208. for item in documents:
  209. text = item['content']
  210. metadata = item.get('metadata', {})
  211. # 转换文本为向量
  212. embedding = self.text_to_vector(text)
  213. text_contents.append(text)
  214. embeddings.append(embedding)
  215. metadatas.append(metadata)
  216. timestamps.append(int(time.time()))
  217. # 准备批量数据
  218. data = [embeddings, text_contents, metadatas, timestamps]
  219. #_get_logger().info(f"Preparing to insert text_contents:{len(text_contents)} ,{len(embeddings)},{len(metadatas)},{len(timestamps)}")
  220. # 批量插入
  221. insert_result = collection.insert(data)
  222. collection.flush() # 确保数据被写入
  223. _get_logger().info(f"Batch inserted {len(text_contents)} records, IDs: {insert_result.primary_keys}")
  224. return insert_result.primary_keys
  225. except Exception as e:
  226. _get_logger().error(f"Error batch inserting: {e}")
  227. return None
  228. def similarity_search(self, param: Dict[str, Any], query_text: str , min_score=0.5 ,
  229. top_k=5, filters: Dict[str, Any] = None):
  230. """
  231. 搜索相似文本
  232. """
  233. try:
  234. collection_name = param.get('collection_name')
  235. collection = Collection(collection_name)
  236. # 加载集合到内存(如果还没有加载)
  237. collection.load()
  238. # 转换查询文本为向量
  239. query_embedding = self.text_to_vector(query_text)
  240. # 搜索参数
  241. search_params = {
  242. "metric_type": "COSINE",
  243. "params": {"nprobe": 10}
  244. }
  245. # 构建过滤表达式
  246. filter_expr = self._create_filter(filters)
  247. # 执行搜索
  248. results = collection.search(
  249. data=[query_embedding],
  250. anns_field="vector",
  251. param=search_params,
  252. limit=top_k,
  253. expr=filter_expr,
  254. output_fields=["text", "metadata"]
  255. )
  256. # 格式化结果
  257. formatted_results = []
  258. for hits in results:
  259. for hit in hits:
  260. formatted_results.append({
  261. 'id': hit.id,
  262. 'text_content': hit.entity.get('text'),
  263. 'text': hit.entity.get('text'), # 添加 text 字段以兼容现有代码
  264. 'metadata': hit.entity.get('metadata'),
  265. 'distance': hit.distance,
  266. 'similarity': 1 - hit.distance # 转换为相似度
  267. })
  268. return formatted_results
  269. except Exception as e:
  270. _get_logger().error(f"Error searching: {e}")
  271. return []
  272. def retriever(self, param: Dict[str, Any], query_text: str,
  273. top_k: int = 5, filters: Dict[str, Any] = None):
  274. """
  275. 带过滤条件的相似搜索
  276. """
  277. try:
  278. collection_name = param.get('collection_name')
  279. collection = Collection(collection_name)
  280. collection.load()
  281. query_embedding = self.text_to_vector(query_text)
  282. # 构建过滤表达式
  283. filter_expr = self._create_filter(filters)
  284. search_params = {
  285. "metric_type": "COSINE",
  286. "params": {"nprobe": 10}
  287. }
  288. results = collection.search(
  289. data=[query_embedding],
  290. anns_field="vector",
  291. param=search_params,
  292. limit=top_k,
  293. expr=filter_expr,
  294. output_fields=["text", "metadata"]
  295. )
  296. formatted_results = []
  297. for hits in results:
  298. for hit in hits:
  299. formatted_results.append({
  300. 'id': hit.id,
  301. 'text_content': hit.entity.get('text_content'),
  302. 'metadata': hit.entity.get('metadata'),
  303. 'distance': hit.distance,
  304. 'similarity': 1 - hit.distance
  305. })
  306. return formatted_results
  307. except Exception as e:
  308. _get_logger().error(f"Error searching with filter: {e}")
  309. return []
  310. def _create_filter(self, filters: Dict[str, Any]) -> str:
  311. """
  312. 创建过滤条件
  313. """
  314. # 构建过滤表达式
  315. filter_expr = ""
  316. if filters:
  317. conditions = []
  318. for key, value in filters.items():
  319. if isinstance(value, str):
  320. conditions.append(f'metadata["{key}"] == "{value}"')
  321. elif isinstance(value, (int, float)):
  322. conditions.append(f'metadata["{key}"] == {value}')
  323. else:
  324. conditions.append(f'metadata["{key}"] == "{json.dumps(value)}"')
  325. filter_expr = " and ".join(conditions)
  326. return filter_expr
  327. def create_hybrid_collection(self, collection_name: str, documents: List[Dict[str, Any]]):
  328. """
  329. 创建支持混合搜索的集合
  330. Args:
  331. collection_name: 集合名称
  332. documents: 文档列表,格式: [{'content': '...', 'metadata': {...}}, ...]
  333. """
  334. try:
  335. # 构建连接参数 (参考 test_hybrid_v2.6.py)
  336. connection_args = {
  337. "uri": f"http://{self.host}:{self.port}",
  338. "user": self.user,
  339. "db_name": "lq_db"
  340. }
  341. if self.password:
  342. connection_args["password"] = self.password
  343. langchain_docs = []
  344. for doc in documents:
  345. content = doc.get('content', '')
  346. metadata = doc.get('metadata', {})
  347. processed_metadata = self._process_metadata(doc)
  348. langchain_doc = Document(page_content=content, metadata=processed_metadata)
  349. langchain_docs.append(langchain_doc)
  350. # 创建混合搜索向量存储 (完全按照 test_hybrid_v2.6.py 的逻辑)
  351. vectorstore = Milvus.from_documents(
  352. documents=langchain_docs,
  353. embedding=self.emdmodel,
  354. builtin_function=BM25BuiltInFunction(),
  355. vector_field=["dense", "sparse"],
  356. connection_args=connection_args,
  357. collection_name=collection_name,
  358. consistency_level="Strong",
  359. drop_old=True,
  360. )
  361. _get_logger().info(f"Created hybrid collection: {collection_name} with {len(documents)} documents")
  362. return vectorstore
  363. except Exception as e:
  364. _get_logger().error(f"Error creating hybrid collection: {e}")
  365. _get_logger().info("Falling back to traditional vector search")
  366. return None
  367. def hybrid_search(self, param: Dict[str, Any], query_text: str,
  368. top_k: int , ranker_type: str = "weighted",
  369. dense_weight: float = 0.7, sparse_weight: float = 0.3):
  370. """
  371. 混合搜索(参考 test_hybrid_v2.6.py 的实现)
  372. Args:
  373. param: 包含collection_name的参数字典
  374. query_text: 查询文本
  375. top_k: 返回结果数量
  376. ranker_type: 重排序类型 "weighted" 或 "rrf"
  377. dense_weight: 密集向量权重(当ranker_type="weighted"时使用)
  378. sparse_weight: 稀疏向量权重(当ranker_type="weighted"时使用)
  379. Returns:
  380. List[Dict]: 搜索结果列表
  381. """
  382. try:
  383. collection_name = param.get('collection_name')
  384. logger.info(f"开始 hybrid_search, collection_name: {collection_name}")
  385. # 使用预创建的连接,避免运行时竞争
  386. if collection_name in self._vectorstore_cache:
  387. vectorstore = self._vectorstore_cache[collection_name]
  388. else:
  389. # 如果缓存中没有,创建新连接(降级方案)
  390. _get_logger().warning(f"缓存中未找到连接: {collection_name},创建新连接")
  391. # 抑制 AsyncMilvusClient 的警告日志
  392. import logging
  393. original_level = logging.getLogger('pymilvus').level
  394. logging.getLogger('pymilvus').setLevel(logging.ERROR)
  395. try:
  396. vectorstore = Milvus(
  397. embedding_function=self.emdmodel,
  398. collection_name=collection_name,
  399. connection_args=self.connection_args,
  400. consistency_level="Strong",
  401. builtin_function=BM25BuiltInFunction(),
  402. vector_field=["dense", "sparse"]
  403. )
  404. # 缓存新创建的连接
  405. self._vectorstore_cache[collection_name] = vectorstore
  406. finally:
  407. logging.getLogger('pymilvus').setLevel(original_level)
  408. _get_logger().info(f"混合召回topk: {top_k}")
  409. # 执行混合搜索,使用 similarity_search_with_score 获取评分
  410. if ranker_type == "weighted":
  411. results_with_scores = vectorstore.similarity_search_with_score(
  412. query=query_text,
  413. k=top_k,
  414. ranker_type="weighted",
  415. ranker_params={"weights": [dense_weight, sparse_weight]}
  416. )
  417. else: # rrf
  418. results_with_scores = vectorstore.similarity_search_with_score(
  419. query=query_text,
  420. k=top_k,
  421. ranker_type="rrf",
  422. ranker_params={"k": 60}
  423. )
  424. # 格式化结果,保持与其他搜索方法一致
  425. formatted_results = []
  426. for doc, score in results_with_scores:
  427. # score 值越小表示相似度越高,所以 similarity = 1 / (1 + score)
  428. # 或者使用其他转换方式,这里使用简单的转换
  429. similarity = 1 / (1 + score) if score >= 0 else 0
  430. formatted_results.append({
  431. 'id': doc.metadata.get('pk', 0),
  432. 'text_content': doc.page_content,
  433. 'metadata': doc.metadata,
  434. 'distance': float(score), # 使用真实的距离/评分
  435. 'similarity': float(similarity) # 转换为相似度
  436. })
  437. # # 记录每个结果的评分信息
  438. # metadata = doc.metadata.get('metadata', {})
  439. # title = 'N/A'
  440. # if isinstance(metadata, str):
  441. # try:
  442. # import json
  443. # inner_metadata = json.loads(metadata)
  444. # title = inner_metadata.get('title', 'N/A')
  445. # except:
  446. # pass
  447. # else:
  448. # title = metadata.get('title', 'N/A')
  449. # _get_logger().info(f"混合搜索评分: 标题='{title}', 距离={score:.4f}, 相似度={similarity:.4f}")
  450. # _get_logger().info(f"Hybrid search returned {len(formatted_results)} results")
  451. return formatted_results
  452. except Exception as e:
  453. _get_logger().error(f"Error in hybrid search: {e}")
  454. # 回退到传统的向量搜索
  455. _get_logger().info("Falling back to traditional vector search")
  456. return self.similarity_search(param, query_text, top_k=top_k)
  457. def _process_metadata(self,metadata):
  458. """处理 metadata:将 list 类型的 hierarchy 转换为 Milvus 支持的 string 类型"""
  459. processed_metadata = metadata.copy()
  460. if "hierarchy" in processed_metadata and isinstance(processed_metadata["hierarchy"], list):
  461. processed_metadata["hierarchy"] = " > ".join(processed_metadata["hierarchy"])
  462. for key, value in processed_metadata.items():
  463. if value is None:
  464. processed_metadata[key] = ""
  465. elif isinstance(value, dict):
  466. processed_metadata[key] = json.dumps(value, ensure_ascii=False)
  467. return processed_metadata