milvus_service.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. """
  2. Milvus Service:业务层(直接用 manager.client 调 Milvus 原生方法)
  3. """
  4. from __future__ import annotations
  5. import sys
  6. import os
  7. # 添加src目录到Python路径
  8. sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))
  9. sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../..'))
  10. import logging
  11. from typing import List, Dict, Any
  12. from datetime import datetime
  13. from app.base import get_milvus_manager, get_milvus_vectorstore, get_embedding_model
  14. logger = logging.getLogger(__name__)
  15. class MilvusService:
  16. def __init__(self):
  17. self.client = get_milvus_manager().client
  18. # 获取embedding model
  19. self.emdmodel = get_embedding_model()
  20. # 默认向量维度 (Qwen3-Embedding-8B default)
  21. self.DENSE_DIM = 4096
  22. def create_collection(self, name: str, dimension: int = None, description: str = "", fields: List[Dict] = None) -> None:
  23. """
  24. 创建 Milvus 集合
  25. :param dimension: 向量维度,如果为None则使用默认值
  26. :param fields: 自定义字段列表,每个元素为 {"name": "age", "type": "INT64", ...}
  27. """
  28. # 使用默认维度
  29. if dimension is None:
  30. dimension = self.DENSE_DIM
  31. if self.client.has_collection(name):
  32. logger.info(f"Collection {name} already exists.")
  33. return
  34. # 如果有自定义字段,使用 schema 创建
  35. if fields:
  36. from pymilvus import MilvusClient, DataType, Function, FunctionType
  37. # 1. 创建 Schema
  38. schema = MilvusClient.create_schema(
  39. auto_id=True,
  40. enable_dynamic_field=True,
  41. description=description
  42. )
  43. # 检查字段中是否定义了主键
  44. has_primary = any(f.get("is_primary") for f in fields)
  45. if not has_primary:
  46. # 如果没有定义主键,添加默认主键
  47. schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True, auto_id=True)
  48. # 检查是否有默认向量列,如果没有则添加 (兼容旧逻辑,但如果fields里有vector则不添加)
  49. has_vector = any(f.get("type") == "FLOAT_VECTOR" for f in fields)
  50. if not has_vector:
  51. schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dimension)
  52. # 3. 添加用户自定义字段
  53. type_map = {
  54. "BOOL": DataType.BOOL,
  55. "INT8": DataType.INT8,
  56. "INT16": DataType.INT16,
  57. "INT32": DataType.INT32,
  58. "INT64": DataType.INT64,
  59. "FLOAT": DataType.FLOAT,
  60. "DOUBLE": DataType.DOUBLE,
  61. "VARCHAR": DataType.VARCHAR,
  62. "JSON": DataType.JSON,
  63. "FLOAT_VECTOR": DataType.FLOAT_VECTOR,
  64. "SPARSE_FLOAT_VECTOR": DataType.SPARSE_FLOAT_VECTOR,
  65. "BM25": DataType.SPARSE_FLOAT_VECTOR # BM25 特殊处理,映射为稀疏向量
  66. }
  67. bm25_field = None
  68. text_field_name = "text" # 默认文本字段名
  69. for f in fields:
  70. field_type_str = f.get("type", "").upper()
  71. dtype = type_map.get(field_type_str)
  72. if not dtype:
  73. continue
  74. # 记录文本字段名,供BM25使用
  75. if f.get("name") in ["text", "content", "chunk"]:
  76. text_field_name = f.get("name")
  77. kwargs = {
  78. "field_name": f.get("name"),
  79. "datatype": dtype,
  80. "description": f.get("description", "")
  81. }
  82. if f.get("is_primary"):
  83. kwargs["is_primary"] = True
  84. kwargs["auto_id"] = True # 假设主键都是自增
  85. if dtype == DataType.VARCHAR:
  86. kwargs["max_length"] = f.get("max_length", 65535)
  87. # 关键修复:如果要被 BM25 引用,必须启用 analyzer
  88. if f.get("name") in ["text", "content", "chunk"]:
  89. kwargs["enable_analyzer"] = True
  90. if dtype == DataType.FLOAT_VECTOR:
  91. kwargs["dim"] = dimension # 使用传入的 dimension
  92. schema.add_field(**kwargs)
  93. # 如果是 BM25 类型,记录下来以便后续添加 Function
  94. if field_type_str == "BM25":
  95. bm25_field = f.get("name")
  96. # 处理 BM25 Function
  97. if bm25_field:
  98. try:
  99. schema.add_function(Function(
  100. name="bm25_fn",
  101. input_field_names=[text_field_name],
  102. output_field_names=[bm25_field],
  103. function_type=FunctionType.BM25
  104. ))
  105. logger.info(f"Added BM25 function mapping {text_field_name} -> {bm25_field}")
  106. except Exception as e:
  107. logger.error(f"Failed to add BM25 function: {e}")
  108. # 4. 准备索引参数
  109. index_params = self.client.prepare_index_params()
  110. # 5. 为所有向量字段添加索引
  111. for f in fields:
  112. ftype = f.get("type", "").upper()
  113. if ftype == "FLOAT_VECTOR":
  114. index_params.add_index(
  115. field_name=f.get("name"),
  116. index_type="AUTOINDEX",
  117. metric_type="COSINE"
  118. )
  119. elif ftype == "BM25" or ftype == "SPARSE_FLOAT_VECTOR":
  120. index_params.add_index(
  121. field_name=f.get("name"),
  122. index_type="SPARSE_INVERTED_INDEX", # 稀疏向量索引
  123. metric_type="BM25"
  124. )
  125. # 6. 为自定义标量字段添加索引
  126. for f in fields:
  127. if f.get("type", "").upper() in ["VARCHAR", "INT64", "INT32", "BOOL"] and not f.get("is_primary"):
  128. # 排除主键,主键自动索引
  129. index_params.add_index(
  130. field_name=f.get("name"),
  131. index_type="INVERTED"
  132. )
  133. # 7. 创建集合
  134. self.client.create_collection(
  135. collection_name=name,
  136. schema=schema,
  137. index_params=index_params
  138. )
  139. else:
  140. # 使用简化的 create_collection API
  141. self.client.create_collection(
  142. collection_name=name,
  143. dimension=dimension,
  144. description=description,
  145. auto_id=True, # 自动生成 ID
  146. id_type="int", # ID 类型
  147. metric_type="COSINE" # 默认使用余弦相似度
  148. )
  149. logger.info(f"Created collection {name} with dimension {dimension}")
  150. def drop_collection(self, name: str) -> None:
  151. """删除 Milvus 集合"""
  152. if self.client.has_collection(name):
  153. self.client.drop_collection(name)
  154. logger.info(f"Dropped collection {name}")
  155. def has_collection(self, name: str) -> bool:
  156. """检查集合是否存在"""
  157. return self.client.has_collection(name)
  158. def get_collection_details(self) -> List[Dict[str, Any]]:
  159. """
  160. 获取所有 Collections 详细信息
  161. """
  162. details: List[Dict[str, Any]] = []
  163. names = self.client.list_collections()
  164. for name in names:
  165. desc = self.client.describe_collection(collection_name=name)
  166. stats = self.client.get_collection_stats(collection_name=name)
  167. load_state = self.client.get_load_state(collection_name=name)
  168. # ===== 时间戳转换(按你指定写法,无封装)=====
  169. created_time = None
  170. updated_time = None
  171. if desc.get("created_timestamp") is not None:
  172. ts_int = int(desc["created_timestamp"])
  173. physical_ms = ts_int >> 18
  174. created_time = datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S")
  175. if desc.get("update_timestamp") is not None:
  176. ts_int = int(desc["update_timestamp"])
  177. physical_ms = ts_int >> 18
  178. updated_time = datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S")
  179. # ===== 数量:不保底(要求返回结构必须有 row_count)=====
  180. entity_count = stats["row_count"]
  181. # ===== 状态:不保底(要求返回结构必须有 state)=====
  182. status = load_state["state"]
  183. details.append(
  184. {
  185. "name": name,
  186. "status": status,
  187. "entity_count": entity_count,
  188. "description": desc.get("description", ""),
  189. "created_time": created_time,
  190. "updated_time": updated_time,
  191. }
  192. )
  193. logger.info(f"成功获取Collections详细信息,共{len(details)}个")
  194. return details
  195. def set_collection_state(self, name: str, action: str) -> Dict[str, Any]:
  196. """
  197. 改变指定 Collection 的加载状态。
  198. 参数:
  199. - name: 集合名称
  200. - action: 操作,取值 'load' 或 'release'
  201. 返回:
  202. - 包含集合名称和当前状态的字典,例如: {"name": name, "state": "Loaded"}
  203. """
  204. action_norm = (action or "").strip().lower()
  205. if action_norm not in {"load", "release"}:
  206. raise ValueError("action 必须为 'load' 或 'release'")
  207. # 执行加载/释放
  208. if action_norm == "load":
  209. self.client.load_collection(collection_name=name)
  210. else:
  211. self.client.release_collection(collection_name=name)
  212. # 返回最新状态
  213. load_state = self.client.get_load_state(collection_name=name)
  214. state = load_state.get("state") if isinstance(load_state, dict) else load_state
  215. result = {"name": name, "state": state, "action": action_norm}
  216. logger.info(f"集合 {name} 状态更新为 {state} (action={action_norm})")
  217. return result
  218. def delete_collection_if_empty(self, name: str) -> Dict[str, Any]:
  219. """仅当集合内容为空时删除集合,否则抛出异常"""
  220. stats = self.client.get_collection_stats(collection_name=name)
  221. row_count = stats.get("row_count") if isinstance(stats, dict) else None
  222. if row_count is None:
  223. raise ValueError("无法获取集合行数,禁止删除")
  224. if int(row_count) > 0:
  225. raise ValueError("集合内容不为空,不能删除")
  226. self.client.drop_collection(collection_name=name)
  227. logger.info(f"集合 {name} 已删除")
  228. return {"name": name, "deleted": True}
  229. def get_collection_detail(self, name: str) -> Dict[str, Any]:
  230. """获取单个集合的详细信息,包含schema、索引等所有desc字段"""
  231. desc = self.client.describe_collection(collection_name=name)
  232. stats = self.client.get_collection_stats(collection_name=name)
  233. load_state = self.client.get_load_state(collection_name=name)
  234. # 时间戳转换
  235. created_time = None
  236. updated_time = None
  237. if desc.get("created_timestamp") is not None:
  238. ts_int = int(desc["created_timestamp"])
  239. physical_ms = ts_int >> 18
  240. created_time = datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S")
  241. if desc.get("update_timestamp") is not None:
  242. ts_int = int(desc["update_timestamp"])
  243. physical_ms = ts_int >> 18
  244. updated_time = datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S")
  245. entity_count = stats.get("row_count", 0)
  246. status = load_state.get("state") if isinstance(load_state, dict) else load_state
  247. # 提取字段schema
  248. fields = []
  249. if "fields" in desc:
  250. for field in desc["fields"]:
  251. field_info = {
  252. "name": field.get("name"),
  253. "type": str(field.get("type")),
  254. "description": field.get("description", ""),
  255. "is_primary": field.get("is_primary", False),
  256. "auto_id": field.get("auto_id"),
  257. }
  258. # 向量维度
  259. if "params" in field and "dim" in field["params"]:
  260. field_info["dim"] = field["params"]["dim"]
  261. # 字符串长度
  262. if "params" in field and "max_length" in field["params"]:
  263. field_info["max_length"] = field["params"]["max_length"]
  264. # 其他params
  265. if "params" in field:
  266. field_info["params"] = field["params"]
  267. fields.append(field_info)
  268. # 提取索引信息
  269. indices = []
  270. # 尝试从 describe_collection 结果中获取 (兼容旧逻辑)
  271. if "indexes" in desc:
  272. for idx in desc["indexes"]:
  273. index_info = {
  274. "field_name": idx.get("field_name"),
  275. "index_name": idx.get("index_name"),
  276. "index_type": idx.get("index_type"),
  277. "metric_type": idx.get("metric_type"),
  278. "params": idx.get("params"),
  279. }
  280. indices.append(index_info)
  281. # 如果没有获取到索引信息,尝试主动查询 list_indexes
  282. if not indices:
  283. try:
  284. # 获取索引列表 (通常返回索引名称列表)
  285. index_names = self.client.list_indexes(collection_name=name)
  286. if index_names:
  287. for idx_name in index_names:
  288. try:
  289. # 获取索引详情
  290. idx_desc = self.client.describe_index(collection_name=name, index_name=idx_name)
  291. if idx_desc:
  292. indices.append({
  293. "field_name": idx_desc.get("field_name"),
  294. "index_name": idx_desc.get("index_name"),
  295. "index_type": idx_desc.get("index_type"),
  296. "metric_type": idx_desc.get("metric_type"),
  297. "params": idx_desc.get("params"),
  298. })
  299. except Exception:
  300. continue
  301. except Exception as e:
  302. logger.warning(f"Failed to list/describe indexes for {name}: {e}")
  303. detail = {
  304. "name": name,
  305. "description": desc.get("description", ""),
  306. "status": status,
  307. "entity_count": entity_count,
  308. "created_time": created_time,
  309. "updated_time": updated_time,
  310. "fields": fields,
  311. "enable_dynamic_field": desc.get("enable_dynamic_field", False),
  312. "consistency_level": desc.get("consistency_level"),
  313. "num_shards": desc.get("num_shards"),
  314. "num_partitions": desc.get("num_partitions"),
  315. "indices": indices,
  316. "properties": desc.get("properties"),
  317. "aliases": desc.get("aliases", []),
  318. }
  319. logger.info(f"成功获取集合 {name} 的详细信息")
  320. return detail
  321. def update_collection_description(self, name: str, description: str) -> Dict[str, Any]:
  322. """使用 alter_collection_properties 更新集合描述"""
  323. description = description or ""
  324. # 1. 更新集合 description(唯一修改点)
  325. self.client.alter_collection_properties(
  326. collection_name=name,
  327. properties={"collection.description": description},
  328. )
  329. # 2. 重新获取集合信息
  330. desc = self.client.describe_collection(collection_name=name)
  331. print(desc)
  332. stats = self.client.get_collection_stats(collection_name=name)
  333. load_state = self.client.get_load_state(collection_name=name)
  334. # 3. 时间戳转换(Milvus TSO -> 物理时间)
  335. def ts_to_str(ts):
  336. if ts is None:
  337. return None
  338. ts_int = int(ts)
  339. physical_ms = ts_int >> 18
  340. return datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S")
  341. created_time = ts_to_str(desc.get("created_timestamp"))
  342. updated_time = ts_to_str(desc.get("update_timestamp"))
  343. entity_count = stats.get("row_count") if isinstance(stats, dict) else None
  344. status = load_state.get("state") if isinstance(load_state, dict) else load_state
  345. return {
  346. "name": name,
  347. "status": status,
  348. "entity_count": entity_count,
  349. "description": desc.get("description", ""),
  350. "created_time": created_time,
  351. "updated_time": updated_time,
  352. }
  353. def hybrid_search(self, collection_name: str, query_text: str,
  354. top_k: int = 3, ranker_type: str = "weighted",
  355. dense_weight: float = 0.7, sparse_weight: float = 0.3,
  356. expr: str = None):
  357. """
  358. 混合搜索(参考 test_hybrid_v2.6.py 的实现)
  359. Args:
  360. param: 包含collection_name的参数字典
  361. query_text: 查询文本
  362. top_k: 返回结果数量
  363. ranker_type: 重排序类型 "weighted" 或 "rrf"
  364. dense_weight: 密集向量权重(当ranker_type="weighted"时使用)
  365. sparse_weight: 稀疏向量权重(当ranker_type="weighted"时使用)
  366. expr: 过滤表达式 (Metadata Filtering)
  367. Returns:
  368. List[Dict]: 搜索结果列表
  369. """
  370. try:
  371. collection_name = collection_name
  372. # 确保集合已加载
  373. self.client.load_collection(collection_name)
  374. # 获取 vectorstore 实例(包含 Milvus 和 BM25BuiltInFunction)
  375. vectorstore = get_milvus_vectorstore(
  376. collection_name=collection_name,
  377. consistency_level="Strong"
  378. )
  379. # 执行混合搜索 (完全按照 test_hybrid_v2.6.py 的逻辑)
  380. # 注意:LangChain Milvus vectorstore 的 similarity_search 支持 expr 参数用于过滤
  381. if ranker_type == "weighted":
  382. results = vectorstore.similarity_search(
  383. query=query_text,
  384. k=top_k,
  385. expr=expr,
  386. ranker_type="weighted",
  387. ranker_params={"weights": [dense_weight, sparse_weight]}
  388. )
  389. else: # rrf
  390. results = vectorstore.similarity_search(
  391. query=query_text,
  392. k=top_k,
  393. expr=expr,
  394. ranker_type="rrf",
  395. ranker_params={"k": 60}
  396. )
  397. # 格式化结果,保持与其他搜索方法一致
  398. formatted_results = []
  399. for doc in results:
  400. formatted_results.append({
  401. 'id': doc.metadata.get('pk', 0),
  402. 'text_content': doc.page_content,
  403. 'metadata': doc.metadata,
  404. 'distance': 0.0,
  405. 'similarity': 1.0
  406. })
  407. logger.info(f"Hybrid search returned {len(formatted_results)} results")
  408. return formatted_results
  409. except Exception as e:
  410. logger.error(f"Error in hybrid search: {e}")
  411. # 回退到传统的向量搜索
  412. logger.info("Falling back to traditional vector search")
  413. # 可选:单例
  414. milvus_service = MilvusService()
  415. if __name__ == "__main__":
  416. # 推荐这样跑:
  417. # uv run python -m src.app.services.milvus_service
  418. import json
  419. service = MilvusService()
  420. # 测试混合搜索 hybrid_search
  421. print("=" * 50)
  422. print("测试混合检索 (Hybrid Search)")
  423. print("=" * 50)
  424. try:
  425. # 示例参数,需要根据实际情况修改
  426. collection_name = "first_bfp_collection_status"
  427. query_text = "《公路水运工程临时用电技术规程》(JTT1499-2024)状态为现行" # 修改为实际查询内容
  428. # 测试 weighted 模式
  429. print("\n1. 测试 Weighted 重排序模式:")
  430. print(f" 集合: {collection_name}")
  431. print(f" 查询: {query_text}")
  432. print(f" 密集权重: 0.7, 稀疏权重: 0.3")
  433. results_weighted = service.hybrid_search(
  434. collection_name=collection_name,
  435. query_text=query_text,
  436. top_k=5,
  437. ranker_type="weighted",
  438. dense_weight=0.7,
  439. sparse_weight=0.3
  440. )
  441. print(f"\n 结果数量: {len(results_weighted)}")
  442. for i, result in enumerate(results_weighted, 1):
  443. print(f" [{i}] ID: {result.get('id')}, Text: {result.get('text_content')[:50]}...")
  444. # 测试 RRF 模式
  445. print("\n2. 测试 RRF (Reciprocal Rank Fusion) 重排序模式:")
  446. print(f" 集合: {collection_name}")
  447. print(f" 查询: {query_text}")
  448. results_rrf = service.hybrid_search(
  449. collection_name=collection_name,
  450. query_text=query_text,
  451. top_k=5,
  452. ranker_type="rrf"
  453. )
  454. print(f"\n 结果数量: {len(results_rrf)}")
  455. for i, result in enumerate(results_rrf, 1):
  456. print(f" [{i}] ID: {result.get('id')}, Text: {result.get('text_content')[:50]}...")
  457. print("\n✓ 混合检索测试完成")
  458. except Exception as e:
  459. print(f"\n✗ 混合检索测试失败: {e}")
  460. import traceback
  461. traceback.print_exc()
  462. # 也可以查看集合详情
  463. print("\n" + "=" * 50)
  464. print("获取所有集合信息:")
  465. print("=" * 50)
  466. data = service.get_collection_details()
  467. for item in data:
  468. print(json.dumps(item, ensure_ascii=False, indent=2))