# milvus_vector.py
  1. import time
  2. from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility, Function
  3. from pymilvus.client.types import FunctionType
  4. from pymilvus import AnnSearchRequest, RRFRanker, WeightedRanker
  5. # from sentence_transformers import SentenceTransformer
  6. import numpy as np
  7. from typing import List, Dict, Any, Optional
  8. import json
  9. # 导入 LangChain Milvus 混合搜索相关包
  10. from langchain_milvus import Milvus, BM25BuiltInFunction
  11. from langchain_core.documents import Document
  12. from langchain_core.embeddings import Embeddings
  13. from foundation.infrastructure.config.config import config_handler
  14. from foundation.database.base.vector.base_vector import BaseVectorDB
  15. # 延迟导入logger和model_handler以避免循环依赖
  16. logger = None
  17. model_handler = None
  18. def _get_logger():
  19. """延迟导入logger以避免循环依赖"""
  20. global logger
  21. if logger is None:
  22. try:
  23. from foundation.observability.logger.loggering import server_logger
  24. logger = server_logger
  25. except ImportError:
  26. # 如果导入失败,创建一个简单的logger替代品
  27. import logging
  28. logger = logging.getLogger(__name__)
  29. return logger
  30. def _get_model_handler():
  31. """延迟导入model_handler以避免循环依赖"""
  32. global model_handler
  33. if model_handler is None:
  34. try:
  35. from foundation.ai.models.model_handler import model_handler as mh
  36. model_handler = mh
  37. except ImportError:
  38. # 如果导入失败,返回None
  39. model_handler = None
  40. return model_handler
  41. class MilvusVectorManager(BaseVectorDB):
  42. def __init__(self):
  43. """
  44. 初始化 Milvus 连接
  45. """
  46. self.host = config_handler.get('milvus', 'MILVUS_HOST', 'localhost')
  47. self.port = int(config_handler.get('milvus', 'MILVUS_PORT', '19530'))
  48. self.milvus_db = config_handler.get('milvus', 'MILVUS_DB', 'default')
  49. self.user = config_handler.get('milvus', 'MILVUS_USER')
  50. self.password = config_handler.get('milvus', 'MILVUS_PASSWORD')
  51. # 初始化文本向量化模型
  52. mh = _get_model_handler()
  53. if mh:
  54. self.emdmodel = mh._get_lq_qwen3_8b_emd()
  55. else:
  56. raise ImportError("无法导入model_handler,无法初始化嵌入模型")
  57. # 连接到 Milvus
  58. self.connect()
  59. def text_to_vector(self, text: str) -> List[float]:
  60. """
  61. 将文本转换为向量(重写基类方法,直接使用嵌入模型)
  62. """
  63. try:
  64. # 使用已有的嵌入模型
  65. embedding = self.emdmodel.embed_query(text)
  66. return embedding.tolist() if hasattr(embedding, 'tolist') else list(embedding)
  67. except Exception as e:
  68. _get_logger().error(f"Error converting text to vector: {e}")
  69. raise
  70. def connect(self):
  71. """连接到 Milvus 服务器
  72. ,
  73. password=self.password
  74. alias="default",
  75. """
  76. try:
  77. connections.connect(
  78. alias="default",
  79. host=self.host,
  80. port=self.port,
  81. user=self.user,
  82. db_name="lq_db"
  83. )
  84. _get_logger().info(f"Connected to Milvus at {self.host}:{self.port}")
  85. except Exception as e:
  86. _get_logger().error(f"Failed to connect to Milvus: {e}")
  87. raise
  88. def create_collection(self, collection_name: str, dimension: int = 768,
  89. description: str = "Vector collection for text embeddings"):
  90. """
  91. 创建向量集合
  92. """
  93. try:
  94. # 检查集合是否已存在
  95. if utility.has_collection(collection_name):
  96. _get_logger().info(f"Collection {collection_name} already exists")
  97. utility.drop_collection(collection_name)
  98. _get_logger().info(f"Collection '{collection_name}' dropped successfully")
  99. # 定义字段
  100. fields = [
  101. FieldSchema(name="id", dtype=DataType.INT64, is_primary=True, auto_id=True),
  102. FieldSchema(name="vector", dtype=DataType.FLOAT_VECTOR, dim=dimension),
  103. FieldSchema(name="text_content", dtype=DataType.VARCHAR, max_length=65535),
  104. FieldSchema(name="metadata", dtype=DataType.JSON),
  105. FieldSchema(name="created_at", dtype=DataType.INT64)
  106. ]
  107. # 创建集合模式
  108. schema = CollectionSchema(
  109. fields=fields,
  110. description=description
  111. )
  112. # 创建集合
  113. collection = Collection(
  114. name=collection_name,
  115. schema=schema
  116. )
  117. # 创建索引
  118. index_params = {
  119. "index_type": "IVF_FLAT",
  120. "metric_type": "COSINE",
  121. "params": {"nlist": 100}
  122. }
  123. collection.create_index(field_name="vector", index_params=index_params)
  124. _get_logger().info(f"Collection {collection_name} created successfully!")
  125. except Exception as e:
  126. _get_logger().error(f"Error creating collection: {e}")
  127. raise
  128. def add_document(self , param: Dict[str, Any] , document: Dict[str, Any]):
  129. """
  130. 插入单个文本及其向量
  131. """
  132. try:
  133. collection_name = param.get('collection_name')
  134. text = document.get('content')
  135. metadata = document.get('metadata')
  136. collection = Collection(collection_name)
  137. created_at = None
  138. # 转换文本为向量
  139. embedding = self.text_to_vector(text)
  140. #_get_logger().info(f"Text converted to embedding:{isinstance(embedding, list)} ,{len(embedding)}")
  141. #_get_logger().info(f"Text converted to embedding:{embedding}")
  142. # 准备数据
  143. data = [
  144. [embedding], # embedding
  145. [text], # text_content
  146. [metadata or {}], # metadata
  147. [created_at or int(time.time())] # created_at
  148. ]
  149. _get_logger().info(f"Preparing to insert text_contents:{len(data[0])} ,{len(data[1])},{len(data[2])},{len(data[3])}")
  150. # 插入数据
  151. insert_result = collection.insert(data)
  152. collection.flush() # 确保数据被写入
  153. _get_logger().info(f"Text inserted with ID: {insert_result.primary_keys[0]}")
  154. return insert_result.primary_keys[0]
  155. except Exception as e:
  156. _get_logger().error(f"Error inserting text: {e}")
  157. return None
  158. def add_batch_documents(self , param: Dict[str, Any] , documents: List[Dict[str, Any]]):
  159. """
  160. 批量插入文本
  161. texts: [{'text': '...', 'metadata': {...}}, ...]
  162. """
  163. try:
  164. collection_name = param.get('collection_name')
  165. collection = Collection(collection_name)
  166. text_contents = []
  167. embeddings = []
  168. metadatas = []
  169. timestamps = []
  170. for item in documents:
  171. text = item['content']
  172. metadata = item.get('metadata', {})
  173. # 转换文本为向量
  174. embedding = self.text_to_vector(text)
  175. text_contents.append(text)
  176. embeddings.append(embedding)
  177. metadatas.append(metadata)
  178. timestamps.append(int(time.time()))
  179. # 准备批量数据
  180. data = [embeddings, text_contents, metadatas, timestamps]
  181. #_get_logger().info(f"Preparing to insert text_contents:{len(text_contents)} ,{len(embeddings)},{len(metadatas)},{len(timestamps)}")
  182. # 批量插入
  183. insert_result = collection.insert(data)
  184. collection.flush() # 确保数据被写入
  185. _get_logger().info(f"Batch inserted {len(text_contents)} records, IDs: {insert_result.primary_keys}")
  186. return insert_result.primary_keys
  187. except Exception as e:
  188. _get_logger().error(f"Error batch inserting: {e}")
  189. return None
  190. def similarity_search(self, param: Dict[str, Any], query_text: str , min_score=0.5 ,
  191. top_k=5, filters: Dict[str, Any] = None):
  192. """
  193. 搜索相似文本
  194. """
  195. try:
  196. collection_name = param.get('collection_name')
  197. collection = Collection(collection_name)
  198. # 加载集合到内存(如果还没有加载)
  199. collection.load()
  200. # 转换查询文本为向量
  201. query_embedding = self.text_to_vector(query_text)
  202. # 搜索参数
  203. search_params = {
  204. "metric_type": "COSINE",
  205. "params": {"nprobe": 10}
  206. }
  207. # 构建过滤表达式
  208. filter_expr = self._create_filter(filters)
  209. # 执行搜索
  210. results = collection.search(
  211. data=[query_embedding],
  212. anns_field="vector",
  213. param=search_params,
  214. limit=top_k,
  215. expr=filter_expr,
  216. output_fields=["text_content", "metadata"]
  217. )
  218. # 格式化结果
  219. formatted_results = []
  220. for hits in results:
  221. for hit in hits:
  222. formatted_results.append({
  223. 'id': hit.id,
  224. 'text_content': hit.entity.get('text_content'),
  225. 'metadata': hit.entity.get('metadata'),
  226. 'distance': hit.distance,
  227. 'similarity': 1 - hit.distance # 转换为相似度
  228. })
  229. return formatted_results
  230. except Exception as e:
  231. _get_logger().error(f"Error searching: {e}")
  232. return []
  233. def retriever(self, param: Dict[str, Any], query_text: str,
  234. top_k: int = 5, filters: Dict[str, Any] = None):
  235. """
  236. 带过滤条件的相似搜索
  237. """
  238. try:
  239. collection_name = param.get('collection_name')
  240. collection = Collection(collection_name)
  241. collection.load()
  242. query_embedding = self.text_to_vector(query_text)
  243. # 构建过滤表达式
  244. filter_expr = self._create_filter(filters)
  245. search_params = {
  246. "metric_type": "COSINE",
  247. "params": {"nprobe": 10}
  248. }
  249. results = collection.search(
  250. data=[query_embedding],
  251. anns_field="vector",
  252. param=search_params,
  253. limit=top_k,
  254. expr=filter_expr,
  255. output_fields=["text_content", "metadata"]
  256. )
  257. formatted_results = []
  258. for hits in results:
  259. for hit in hits:
  260. formatted_results.append({
  261. 'id': hit.id,
  262. 'text_content': hit.entity.get('text_content'),
  263. 'metadata': hit.entity.get('metadata'),
  264. 'distance': hit.distance,
  265. 'similarity': 1 - hit.distance
  266. })
  267. return formatted_results
  268. except Exception as e:
  269. _get_logger().error(f"Error searching with filter: {e}")
  270. return []
  271. def _create_filter(self, filters: Dict[str, Any]) -> str:
  272. """
  273. 创建过滤条件
  274. """
  275. # 构建过滤表达式
  276. filter_expr = ""
  277. if filters:
  278. conditions = []
  279. for key, value in filters.items():
  280. if isinstance(value, str):
  281. conditions.append(f'metadata["{key}"] == "{value}"')
  282. elif isinstance(value, (int, float)):
  283. conditions.append(f'metadata["{key}"] == {value}')
  284. else:
  285. conditions.append(f'metadata["{key}"] == "{json.dumps(value)}"')
  286. filter_expr = " and ".join(conditions)
  287. return filter_expr
  288. def create_hybrid_collection(self, collection_name: str, documents: List[Dict[str, Any]]):
  289. """
  290. 创建支持混合搜索的集合
  291. Args:
  292. collection_name: 集合名称
  293. documents: 文档列表,格式: [{'content': '...', 'metadata': {...}}, ...]
  294. """
  295. try:
  296. # 构建连接参数
  297. connection_args = {
  298. "uri": f"http://{self.host}:{self.port}",
  299. "user": self.user,
  300. "db_name": "lq_db"
  301. }
  302. if self.password:
  303. connection_args["password"] = self.password
  304. # 转换为 LangChain Document 格式
  305. langchain_docs = []
  306. for doc in documents:
  307. content = doc.get('content', '')
  308. metadata = doc.get('metadata', {})
  309. langchain_doc = Document(page_content=content, metadata=metadata)
  310. langchain_docs.append(langchain_doc)
  311. # 创建混合搜索向量存储
  312. vectorstore = Milvus.from_documents(
  313. documents=langchain_docs,
  314. embedding=self.emdmodel,
  315. builtin_function=BM25BuiltInFunction(),
  316. vector_field=["dense", "sparse"],
  317. connection_args=connection_args,
  318. collection_name=collection_name,
  319. consistency_level="Strong",
  320. drop_old=True,
  321. )
  322. _get_logger().info(f"Created hybrid collection: {collection_name} with {len(documents)} documents")
  323. return vectorstore
  324. except Exception as e:
  325. _get_logger().error(f"Error creating hybrid collection: {e}")
  326. _get_logger().info("Falling back to traditional vector search")
  327. return None
  328. def hybrid_search(self, param: Dict[str, Any], query_text: str,
  329. top_k: int = 5, ranker_type: str = "weighted",
  330. dense_weight: float = 0.7, sparse_weight: float = 0.3):
  331. """
  332. 混合搜索(参考 test_hybrid_v2.6.py 的实现)
  333. Args:
  334. param: 包含collection_name的参数字典
  335. query_text: 查询文本
  336. top_k: 返回结果数量
  337. ranker_type: 重排序类型 "weighted" 或 "rrf"
  338. dense_weight: 密集向量权重(当ranker_type="weighted"时使用)
  339. sparse_weight: 稀疏向量权重(当ranker_type="weighted"时使用)
  340. Returns:
  341. List[Dict]: 搜索结果列表
  342. """
  343. try:
  344. collection_name = param.get('collection_name')
  345. # 连接到现有集合
  346. connection_args = {
  347. "uri": f"http://{self.host}:{self.port}",
  348. "user": self.user,
  349. "db_name": "lq_db"
  350. }
  351. if self.password:
  352. connection_args["password"] = self.password
  353. vectorstore = Milvus(
  354. embedding_function=self.emdmodel,
  355. collection_name=collection_name,
  356. connection_args=connection_args,
  357. consistency_level="Strong",
  358. builtin_function=BM25BuiltInFunction(),
  359. vector_field=["dense", "sparse"]
  360. )
  361. # 执行混合搜索
  362. if ranker_type == "weighted":
  363. results = vectorstore.similarity_search(
  364. query=query_text,
  365. k=top_k,
  366. ranker_type="weighted",
  367. ranker_params={"weights": [dense_weight, sparse_weight]}
  368. )
  369. else: # rrf
  370. results = vectorstore.similarity_search(
  371. query=query_text,
  372. k=top_k,
  373. ranker_type="rrf",
  374. ranker_params={"k": 60}
  375. )
  376. # 格式化结果,保持与其他搜索方法一致
  377. formatted_results = []
  378. for doc in results:
  379. formatted_results.append({
  380. 'id': doc.metadata.get('pk', 0),
  381. 'text_content': doc.page_content,
  382. 'metadata': doc.metadata,
  383. 'distance': 0.0,
  384. 'similarity': 1.0
  385. })
  386. _get_logger().info(f"Hybrid search returned {len(formatted_results)} results")
  387. return formatted_results
  388. except Exception as e:
  389. _get_logger().error(f"Error in hybrid search: {e}")
  390. # 回退到传统的向量搜索
  391. _get_logger().info("Falling back to traditional vector search")
  392. return self.similarity_search(param, query_text, top_k=top_k)