# milvus_vector.py
  1. import time
  2. from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility, Function
  3. from pymilvus.client.types import FunctionType
  4. from pymilvus import AnnSearchRequest, RRFRanker, WeightedRanker
  5. # from sentence_transformers import SentenceTransformer
  6. import numpy as np
  7. from typing import List, Dict, Any, Optional
  8. import json
  9. # 导入 LangChain Milvus 混合搜索相关包
  10. from langchain_milvus import Milvus, BM25BuiltInFunction
  11. from langchain_core.documents import Document
  12. from langchain_core.embeddings import Embeddings
  13. from foundation.infrastructure.config.config import config_handler
  14. from foundation.database.base.vector.base_vector import BaseVectorDB
  15. # 延迟导入logger和model_handler以避免循环依赖
  16. logger = None
  17. model_handler = None
  18. def _get_logger():
  19. """延迟导入logger以避免循环依赖"""
  20. global logger
  21. if logger is None:
  22. try:
  23. from foundation.observability.logger.loggering import server_logger
  24. logger = server_logger
  25. except ImportError:
  26. # 如果导入失败,创建一个简单的logger替代品
  27. import logging
  28. logger = logging.getLogger(__name__)
  29. return logger
  30. def _get_model_handler():
  31. """延迟导入model_handler以避免循环依赖"""
  32. global model_handler
  33. if model_handler is None:
  34. try:
  35. from foundation.ai.models.model_handler import model_handler as mh
  36. model_handler = mh
  37. except ImportError:
  38. # 如果导入失败,返回None
  39. model_handler = None
  40. return model_handler
  41. class MilvusVectorManager(BaseVectorDB):
  42. def __init__(self):
  43. """
  44. 初始化 Milvus 连接
  45. """
  46. self.host = config_handler.get('milvus', 'MILVUS_HOST', 'localhost')
  47. self.port = int(config_handler.get('milvus', 'MILVUS_PORT', '19530'))
  48. self.milvus_db = config_handler.get('milvus', 'MILVUS_DB', 'default')
  49. self.user = config_handler.get('milvus', 'MILVUS_USER')
  50. self.password = config_handler.get('milvus', 'MILVUS_PASSWORD')
  51. # 初始化文本向量化模型
  52. mh = _get_model_handler()
  53. if mh:
  54. self.emdmodel = mh.get_embedding_model()
  55. else:
  56. raise ImportError("无法导入model_handler,无法初始化嵌入模型")
  57. # 缓存连接参数
  58. self.connection_args = {
  59. "uri": f"http://{self.host}:{self.port}",
  60. "user": self.user,
  61. "db_name": "lq_db"
  62. }
  63. if self.password:
  64. self.connection_args["password"] = self.password
  65. # 连接到 Milvus
  66. self.connect()
  67. # 预创建常用的vectorstore连接,避免运行时竞争
  68. self._vectorstore_cache = {}
  69. self._create_common_connections()
  70. def _create_common_connections(self):
  71. """预创建常用的vectorstore连接"""
  72. common_collections = [
  73. "first_bfp_collection_entity",
  74. "first_bfp_collection_test"
  75. ]
  76. for collection_name in common_collections:
  77. try:
  78. _get_logger().info(f"预创建vectorstore连接: {collection_name}")
  79. self._vectorstore_cache[collection_name] = Milvus(
  80. embedding_function=self.emdmodel,
  81. collection_name=collection_name,
  82. connection_args=self.connection_args,
  83. consistency_level="Strong",
  84. builtin_function=BM25BuiltInFunction(),
  85. vector_field=["dense", "sparse"]
  86. )
  87. _get_logger().info(f"成功预创建连接: {collection_name}")
  88. except Exception as e:
  89. _get_logger().error(f"预创建连接失败 {collection_name}: {e}")
  90. def text_to_vector(self, text: str) -> List[float]:
  91. """
  92. 将文本转换为向量(重写基类方法,直接使用嵌入模型)
  93. """
  94. try:
  95. # 使用已有的嵌入模型
  96. embedding = self.emdmodel.embed_query(text)
  97. return embedding.tolist() if hasattr(embedding, 'tolist') else list(embedding)
  98. except Exception as e:
  99. _get_logger().error(f"Error converting text to vector: {e}")
  100. raise
  101. def connect(self):
  102. """连接到 Milvus 服务器
  103. ,
  104. password=self.password
  105. alias="default",
  106. """
  107. try:
  108. connections.connect(
  109. alias="default",
  110. host=self.host,
  111. port=self.port,
  112. user=self.user,
  113. db_name="lq_db"
  114. )
  115. _get_logger().info(f"Connected to Milvus at {self.host}:{self.port}")
  116. except Exception as e:
  117. _get_logger().error(f"Failed to connect to Milvus: {e}")
  118. raise
  119. def create_collection(self, collection_name: str, dimension: int = 768,
  120. description: str = "Vector collection for text embeddings"):
  121. """
  122. 创建向量集合
  123. """
  124. try:
  125. # 检查集合是否已存在
  126. if utility.has_collection(collection_name):
  127. _get_logger().info(f"Collection {collection_name} already exists")
  128. utility.drop_collection(collection_name)
  129. _get_logger().info(f"Collection '{collection_name}' dropped successfully")
  130. # 定义字段
  131. fields = [
  132. FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
  133. FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
  134. FieldSchema(name="text_content", dtype=DataType.VARCHAR, max_length=65535),
  135. FieldSchema(name="metadata", dtype=DataType.JSON),
  136. FieldSchema(name="created_at", dtype=DataType.INT64)
  137. ]
  138. # 创建集合模式
  139. schema = CollectionSchema(
  140. fields=fields,
  141. description=description
  142. )
  143. # 创建集合
  144. collection = Collection(
  145. name=collection_name,
  146. schema=schema
  147. )
  148. # 创建索引
  149. index_params = {
  150. "index_type": "IVF_FLAT",
  151. "metric_type": "COSINE",
  152. "params": {"nlist": 100}
  153. }
  154. collection.create_index(field_name="vector", index_params=index_params)
  155. _get_logger().info(f"Collection {collection_name} created successfully!")
  156. except Exception as e:
  157. _get_logger().error(f"Error creating collection: {e}")
  158. raise
  159. def add_document(self , param: Dict[str, Any] , document: Dict[str, Any]):
  160. """
  161. 插入单个文本及其向量
  162. """
  163. try:
  164. collection_name = param.get('collection_name')
  165. text = document.get('content')
  166. metadata = document.get('metadata')
  167. collection = Collection(collection_name)
  168. created_at = None
  169. # 转换文本为向量
  170. embedding = self.text_to_vector(text)
  171. #_get_logger().info(f"Text converted to embedding:{isinstance(embedding, list)} ,{len(embedding)}")
  172. #_get_logger().info(f"Text converted to embedding:{embedding}")
  173. # 准备数据
  174. data = [
  175. [embedding], # embedding
  176. [text], # text_content
  177. [metadata or {}], # metadata
  178. [created_at or int(time.time())] # created_at
  179. ]
  180. _get_logger().info(f"Preparing to insert text_contents:{len(data[0])} ,{len(data[1])},{len(data[2])},{len(data[3])}")
  181. # 插入数据
  182. insert_result = collection.insert(data)
  183. collection.flush() # 确保数据被写入
  184. _get_logger().info(f"Text inserted with ID: {insert_result.primary_keys[0]}")
  185. return insert_result.primary_keys[0]
  186. except Exception as e:
  187. _get_logger().error(f"Error inserting text: {e}")
  188. return None
  189. def add_batch_documents(self , param: Dict[str, Any] , documents: List[Dict[str, Any]]):
  190. """
  191. 批量插入文本
  192. texts: [{'text': '...', 'metadata': {...}}, ...]
  193. """
  194. try:
  195. collection_name = param.get('collection_name')
  196. collection = Collection(collection_name)
  197. text_contents = []
  198. embeddings = []
  199. metadatas = []
  200. timestamps = []
  201. for item in documents:
  202. text = item['content']
  203. metadata = item.get('metadata', {})
  204. # 转换文本为向量
  205. embedding = self.text_to_vector(text)
  206. text_contents.append(text)
  207. embeddings.append(embedding)
  208. metadatas.append(metadata)
  209. timestamps.append(int(time.time()))
  210. # 准备批量数据
  211. data = [embeddings, text_contents, metadatas, timestamps]
  212. #_get_logger().info(f"Preparing to insert text_contents:{len(text_contents)} ,{len(embeddings)},{len(metadatas)},{len(timestamps)}")
  213. # 批量插入
  214. insert_result = collection.insert(data)
  215. collection.flush() # 确保数据被写入
  216. _get_logger().info(f"Batch inserted {len(text_contents)} records, IDs: {insert_result.primary_keys}")
  217. return insert_result.primary_keys
  218. except Exception as e:
  219. _get_logger().error(f"Error batch inserting: {e}")
  220. return None
  221. def similarity_search(self, param: Dict[str, Any], query_text: str , min_score=0.5 ,
  222. top_k=5, filters: Dict[str, Any] = None):
  223. """
  224. 搜索相似文本
  225. """
  226. try:
  227. collection_name = param.get('collection_name')
  228. collection = Collection(collection_name)
  229. # 加载集合到内存(如果还没有加载)
  230. collection.load()
  231. # 转换查询文本为向量
  232. query_embedding = self.text_to_vector(query_text)
  233. # 搜索参数
  234. search_params = {
  235. "metric_type": "COSINE",
  236. "params": {"nprobe": 10}
  237. }
  238. # 构建过滤表达式
  239. filter_expr = self._create_filter(filters)
  240. # 执行搜索
  241. results = collection.search(
  242. data=[query_embedding],
  243. anns_field="vector",
  244. param=search_params,
  245. limit=top_k,
  246. expr=filter_expr,
  247. output_fields=["text", "metadata"]
  248. )
  249. # 格式化结果
  250. formatted_results = []
  251. for hits in results:
  252. for hit in hits:
  253. formatted_results.append({
  254. 'id': hit.id,
  255. 'text_content': hit.entity.get('text'),
  256. 'text': hit.entity.get('text'), # 添加 text 字段以兼容现有代码
  257. 'metadata': hit.entity.get('metadata'),
  258. 'distance': hit.distance,
  259. 'similarity': 1 - hit.distance # 转换为相似度
  260. })
  261. return formatted_results
  262. except Exception as e:
  263. _get_logger().error(f"Error searching: {e}")
  264. return []
  265. def retriever(self, param: Dict[str, Any], query_text: str,
  266. top_k: int = 5, filters: Dict[str, Any] = None):
  267. """
  268. 带过滤条件的相似搜索
  269. """
  270. try:
  271. collection_name = param.get('collection_name')
  272. collection = Collection(collection_name)
  273. collection.load()
  274. query_embedding = self.text_to_vector(query_text)
  275. # 构建过滤表达式
  276. filter_expr = self._create_filter(filters)
  277. search_params = {
  278. "metric_type": "COSINE",
  279. "params": {"nprobe": 10}
  280. }
  281. results = collection.search(
  282. data=[query_embedding],
  283. anns_field="vector",
  284. param=search_params,
  285. limit=top_k,
  286. expr=filter_expr,
  287. output_fields=["text", "metadata"]
  288. )
  289. formatted_results = []
  290. for hits in results:
  291. for hit in hits:
  292. formatted_results.append({
  293. 'id': hit.id,
  294. 'text_content': hit.entity.get('text_content'),
  295. 'metadata': hit.entity.get('metadata'),
  296. 'distance': hit.distance,
  297. 'similarity': 1 - hit.distance
  298. })
  299. return formatted_results
  300. except Exception as e:
  301. _get_logger().error(f"Error searching with filter: {e}")
  302. return []
  303. def _create_filter(self, filters: Dict[str, Any]) -> str:
  304. """
  305. 创建过滤条件
  306. """
  307. # 构建过滤表达式
  308. filter_expr = ""
  309. if filters:
  310. conditions = []
  311. for key, value in filters.items():
  312. if isinstance(value, str):
  313. conditions.append(f'metadata["{key}"] == "{value}"')
  314. elif isinstance(value, (int, float)):
  315. conditions.append(f'metadata["{key}"] == {value}')
  316. else:
  317. conditions.append(f'metadata["{key}"] == "{json.dumps(value)}"')
  318. filter_expr = " and ".join(conditions)
  319. return filter_expr
  320. def create_hybrid_collection(self, collection_name: str, documents: List[Dict[str, Any]]):
  321. """
  322. 创建支持混合搜索的集合
  323. Args:
  324. collection_name: 集合名称
  325. documents: 文档列表,格式: [{'content': '...', 'metadata': {...}}, ...]
  326. """
  327. try:
  328. # 构建连接参数 (参考 test_hybrid_v2.6.py)
  329. connection_args = {
  330. "uri": f"http://{self.host}:{self.port}",
  331. "user": self.user,
  332. "db_name": "lq_db"
  333. }
  334. if self.password:
  335. connection_args["password"] = self.password
  336. langchain_docs = []
  337. for doc in documents:
  338. content = doc.get('content', '')
  339. metadata = doc.get('metadata', {})
  340. processed_metadata = self._process_metadata(doc)
  341. langchain_doc = Document(page_content=content, metadata=processed_metadata)
  342. langchain_docs.append(langchain_doc)
  343. # 创建混合搜索向量存储 (完全按照 test_hybrid_v2.6.py 的逻辑)
  344. vectorstore = Milvus.from_documents(
  345. documents=langchain_docs,
  346. embedding=self.emdmodel,
  347. builtin_function=BM25BuiltInFunction(),
  348. vector_field=["dense", "sparse"],
  349. connection_args=connection_args,
  350. collection_name=collection_name,
  351. consistency_level="Strong",
  352. drop_old=True,
  353. )
  354. _get_logger().info(f"Created hybrid collection: {collection_name} with {len(documents)} documents")
  355. return vectorstore
  356. except Exception as e:
  357. _get_logger().error(f"Error creating hybrid collection: {e}")
  358. _get_logger().info("Falling back to traditional vector search")
  359. return None
  360. def hybrid_search(self, param: Dict[str, Any], query_text: str,
  361. top_k: int , ranker_type: str = "weighted",
  362. dense_weight: float = 0.7, sparse_weight: float = 0.3):
  363. """
  364. 混合搜索(参考 test_hybrid_v2.6.py 的实现)
  365. Args:
  366. param: 包含collection_name的参数字典
  367. query_text: 查询文本
  368. top_k: 返回结果数量
  369. ranker_type: 重排序类型 "weighted" 或 "rrf"
  370. dense_weight: 密集向量权重(当ranker_type="weighted"时使用)
  371. sparse_weight: 稀疏向量权重(当ranker_type="weighted"时使用)
  372. Returns:
  373. List[Dict]: 搜索结果列表
  374. """
  375. try:
  376. collection_name = param.get('collection_name')
  377. logger.info(f"开始 hybrid_search, collection_name: {collection_name}")
  378. # 使用预创建的连接,避免运行时竞争
  379. if collection_name in self._vectorstore_cache:
  380. vectorstore = self._vectorstore_cache[collection_name]
  381. else:
  382. # 如果缓存中没有,创建新连接(降级方案)
  383. _get_logger().warning(f"缓存中未找到连接: {collection_name},创建新连接")
  384. vectorstore = Milvus(
  385. embedding_function=self.emdmodel,
  386. collection_name=collection_name,
  387. connection_args=self.connection_args,
  388. consistency_level="Strong",
  389. builtin_function=BM25BuiltInFunction(),
  390. vector_field=["dense", "sparse"]
  391. )
  392. # 缓存新创建的连接
  393. self._vectorstore_cache[collection_name] = vectorstore
  394. _get_logger().info(f"混合召回topk: {top_k}")
  395. # 执行混合搜索,使用 similarity_search_with_score 获取评分
  396. if ranker_type == "weighted":
  397. results_with_scores = vectorstore.similarity_search_with_score(
  398. query=query_text,
  399. k=top_k,
  400. ranker_type="weighted",
  401. ranker_params={"weights": [dense_weight, sparse_weight]}
  402. )
  403. else: # rrf
  404. results_with_scores = vectorstore.similarity_search_with_score(
  405. query=query_text,
  406. k=top_k,
  407. ranker_type="rrf",
  408. ranker_params={"k": 60}
  409. )
  410. # 格式化结果,保持与其他搜索方法一致
  411. formatted_results = []
  412. for doc, score in results_with_scores:
  413. # score 值越小表示相似度越高,所以 similarity = 1 / (1 + score)
  414. # 或者使用其他转换方式,这里使用简单的转换
  415. similarity = 1 / (1 + score) if score >= 0 else 0
  416. formatted_results.append({
  417. 'id': doc.metadata.get('pk', 0),
  418. 'text_content': doc.page_content,
  419. 'metadata': doc.metadata,
  420. 'distance': float(score), # 使用真实的距离/评分
  421. 'similarity': float(similarity) # 转换为相似度
  422. })
  423. # # 记录每个结果的评分信息
  424. # metadata = doc.metadata.get('metadata', {})
  425. # title = 'N/A'
  426. # if isinstance(metadata, str):
  427. # try:
  428. # import json
  429. # inner_metadata = json.loads(metadata)
  430. # title = inner_metadata.get('title', 'N/A')
  431. # except:
  432. # pass
  433. # else:
  434. # title = metadata.get('title', 'N/A')
  435. # _get_logger().info(f"混合搜索评分: 标题='{title}', 距离={score:.4f}, 相似度={similarity:.4f}")
  436. # _get_logger().info(f"Hybrid search returned {len(formatted_results)} results")
  437. return formatted_results
  438. except Exception as e:
  439. _get_logger().error(f"Error in hybrid search: {e}")
  440. # 回退到传统的向量搜索
  441. _get_logger().info("Falling back to traditional vector search")
  442. return self.similarity_search(param, query_text, top_k=top_k)
  443. def _process_metadata(self,metadata):
  444. """处理 metadata:将 list 类型的 hierarchy 转换为 Milvus 支持的 string 类型"""
  445. processed_metadata = metadata.copy()
  446. if "hierarchy" in processed_metadata and isinstance(processed_metadata["hierarchy"], list):
  447. processed_metadata["hierarchy"] = " > ".join(processed_metadata["hierarchy"])
  448. for key, value in processed_metadata.items():
  449. if value is None:
  450. processed_metadata[key] = ""
  451. elif isinstance(value, dict):
  452. processed_metadata[key] = json.dumps(value, ensure_ascii=False)
  453. return processed_metadata