milvus_service.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495
  1. """
  2. Milvus Service:业务层(直接用 manager.client 调 Milvus 原生方法)
  3. """
  4. from __future__ import annotations
  5. import sys
  6. import os
  7. # 添加src目录到Python路径
  8. sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))
  9. sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../..'))
  10. import logging
  11. from typing import List, Dict, Any
  12. from datetime import datetime
  13. from app.base import get_milvus_manager, get_milvus_vectorstore, get_embedding_model
  14. logger = logging.getLogger(__name__)
  15. class MilvusService:
  16. def __init__(self):
  17. self.client = get_milvus_manager().client
  18. # 获取embedding model
  19. self.emdmodel = get_embedding_model()
  20. def create_collection(self, name: str, dimension: int = 768, description: str = "", fields: List[Dict] = None) -> None:
  21. """
  22. 创建 Milvus 集合
  23. :param fields: 自定义字段列表,每个元素为 {"name": "age", "type": "INT64", ...}
  24. """
  25. if self.client.has_collection(name):
  26. logger.info(f"Collection {name} already exists.")
  27. return
  28. # 如果有自定义字段,使用 schema 创建
  29. if fields:
  30. from pymilvus import MilvusClient, DataType
  31. # 1. 创建 Schema
  32. schema = MilvusClient.create_schema(
  33. auto_id=True,
  34. enable_dynamic_field=True,
  35. description=description
  36. )
  37. # 2. 添加必须的默认字段
  38. schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True, auto_id=True)
  39. schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dimension)
  40. # schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR) # 如果需要混合检索,可能需要
  41. # 3. 添加用户自定义字段
  42. # 映射字符串类型到 pymilvus DataType
  43. type_map = {
  44. "BOOL": DataType.BOOL,
  45. "INT8": DataType.INT8,
  46. "INT16": DataType.INT16,
  47. "INT32": DataType.INT32,
  48. "INT64": DataType.INT64,
  49. "FLOAT": DataType.FLOAT,
  50. "DOUBLE": DataType.DOUBLE,
  51. "VARCHAR": DataType.VARCHAR,
  52. "JSON": DataType.JSON,
  53. "FLOAT_VECTOR": DataType.FLOAT_VECTOR
  54. }
  55. for f in fields:
  56. dtype = type_map.get(f.get("type", "").upper())
  57. if not dtype:
  58. continue # 忽略未知类型
  59. kwargs = {
  60. "field_name": f.get("name"),
  61. "datatype": dtype,
  62. "description": f.get("description", "")
  63. }
  64. if dtype == DataType.VARCHAR:
  65. kwargs["max_length"] = f.get("max_length", 65535)
  66. schema.add_field(**kwargs)
  67. # 4. 准备索引参数
  68. index_params = self.client.prepare_index_params()
  69. # 5. 添加向量索引
  70. index_params.add_index(
  71. field_name="vector",
  72. index_type="AUTOINDEX",
  73. metric_type="COSINE"
  74. )
  75. # 6. 为自定义标量字段添加索引 (可选,这里为所有标量字段添加倒排索引以加速过滤)
  76. for f in fields:
  77. # VARCHAR/INT/BOOL 等支持索引
  78. if f.get("type", "").upper() in ["VARCHAR", "INT64", "INT32", "BOOL"]:
  79. index_params.add_index(
  80. field_name=f.get("name"),
  81. index_type="INVERTED" # 标量字段倒排索引
  82. )
  83. # 7. 创建集合
  84. self.client.create_collection(
  85. collection_name=name,
  86. schema=schema,
  87. index_params=index_params
  88. )
  89. else:
  90. # 使用简化的 create_collection API
  91. self.client.create_collection(
  92. collection_name=name,
  93. dimension=dimension,
  94. description=description,
  95. auto_id=True, # 自动生成 ID
  96. id_type="int", # ID 类型
  97. metric_type="COSINE" # 默认使用余弦相似度
  98. )
  99. logger.info(f"Created collection {name} with dimension {dimension}")
  100. def drop_collection(self, name: str) -> None:
  101. """删除 Milvus 集合"""
  102. if self.client.has_collection(name):
  103. self.client.drop_collection(name)
  104. logger.info(f"Dropped collection {name}")
  105. def has_collection(self, name: str) -> bool:
  106. """检查集合是否存在"""
  107. return self.client.has_collection(name)
  108. def get_collection_details(self) -> List[Dict[str, Any]]:
  109. """
  110. 获取所有 Collections 详细信息
  111. """
  112. details: List[Dict[str, Any]] = []
  113. names = self.client.list_collections()
  114. for name in names:
  115. desc = self.client.describe_collection(collection_name=name)
  116. stats = self.client.get_collection_stats(collection_name=name)
  117. load_state = self.client.get_load_state(collection_name=name)
  118. # ===== 时间戳转换(按你指定写法,无封装)=====
  119. created_time = None
  120. updated_time = None
  121. if desc.get("created_timestamp") is not None:
  122. ts_int = int(desc["created_timestamp"])
  123. physical_ms = ts_int >> 18
  124. created_time = datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S")
  125. if desc.get("update_timestamp") is not None:
  126. ts_int = int(desc["update_timestamp"])
  127. physical_ms = ts_int >> 18
  128. updated_time = datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S")
  129. # ===== 数量:不保底(要求返回结构必须有 row_count)=====
  130. entity_count = stats["row_count"]
  131. # ===== 状态:不保底(要求返回结构必须有 state)=====
  132. status = load_state["state"]
  133. details.append(
  134. {
  135. "name": name,
  136. "status": status,
  137. "entity_count": entity_count,
  138. "description": desc.get("description", ""),
  139. "created_time": created_time,
  140. "updated_time": updated_time,
  141. }
  142. )
  143. logger.info(f"成功获取Collections详细信息,共{len(details)}个")
  144. return details
  145. def set_collection_state(self, name: str, action: str) -> Dict[str, Any]:
  146. """
  147. 改变指定 Collection 的加载状态。
  148. 参数:
  149. - name: 集合名称
  150. - action: 操作,取值 'load' 或 'release'
  151. 返回:
  152. - 包含集合名称和当前状态的字典,例如: {"name": name, "state": "Loaded"}
  153. """
  154. action_norm = (action or "").strip().lower()
  155. if action_norm not in {"load", "release"}:
  156. raise ValueError("action 必须为 'load' 或 'release'")
  157. # 执行加载/释放
  158. if action_norm == "load":
  159. self.client.load_collection(collection_name=name)
  160. else:
  161. self.client.release_collection(collection_name=name)
  162. # 返回最新状态
  163. load_state = self.client.get_load_state(collection_name=name)
  164. state = load_state.get("state") if isinstance(load_state, dict) else load_state
  165. result = {"name": name, "state": state, "action": action_norm}
  166. logger.info(f"集合 {name} 状态更新为 {state} (action={action_norm})")
  167. return result
  168. def delete_collection_if_empty(self, name: str) -> Dict[str, Any]:
  169. """仅当集合内容为空时删除集合,否则抛出异常"""
  170. stats = self.client.get_collection_stats(collection_name=name)
  171. row_count = stats.get("row_count") if isinstance(stats, dict) else None
  172. if row_count is None:
  173. raise ValueError("无法获取集合行数,禁止删除")
  174. if int(row_count) > 0:
  175. raise ValueError("集合内容不为空,不能删除")
  176. self.client.drop_collection(collection_name=name)
  177. logger.info(f"集合 {name} 已删除")
  178. return {"name": name, "deleted": True}
  179. def get_collection_detail(self, name: str) -> Dict[str, Any]:
  180. """获取单个集合的详细信息,包含schema、索引等所有desc字段"""
  181. desc = self.client.describe_collection(collection_name=name)
  182. stats = self.client.get_collection_stats(collection_name=name)
  183. load_state = self.client.get_load_state(collection_name=name)
  184. # 时间戳转换
  185. created_time = None
  186. updated_time = None
  187. if desc.get("created_timestamp") is not None:
  188. ts_int = int(desc["created_timestamp"])
  189. physical_ms = ts_int >> 18
  190. created_time = datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S")
  191. if desc.get("update_timestamp") is not None:
  192. ts_int = int(desc["update_timestamp"])
  193. physical_ms = ts_int >> 18
  194. updated_time = datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S")
  195. entity_count = stats.get("row_count", 0)
  196. status = load_state.get("state") if isinstance(load_state, dict) else load_state
  197. # 提取字段schema
  198. fields = []
  199. if "fields" in desc:
  200. for field in desc["fields"]:
  201. field_info = {
  202. "name": field.get("name"),
  203. "type": str(field.get("type")),
  204. "description": field.get("description", ""),
  205. "is_primary": field.get("is_primary", False),
  206. "auto_id": field.get("auto_id"),
  207. }
  208. # 向量维度
  209. if "params" in field and "dim" in field["params"]:
  210. field_info["dim"] = field["params"]["dim"]
  211. # 字符串长度
  212. if "params" in field and "max_length" in field["params"]:
  213. field_info["max_length"] = field["params"]["max_length"]
  214. # 其他params
  215. if "params" in field:
  216. field_info["params"] = field["params"]
  217. fields.append(field_info)
  218. # 提取索引信息
  219. indices = []
  220. # 尝试从 describe_collection 结果中获取 (兼容旧逻辑)
  221. if "indexes" in desc:
  222. for idx in desc["indexes"]:
  223. index_info = {
  224. "field_name": idx.get("field_name"),
  225. "index_name": idx.get("index_name"),
  226. "index_type": idx.get("index_type"),
  227. "metric_type": idx.get("metric_type"),
  228. "params": idx.get("params"),
  229. }
  230. indices.append(index_info)
  231. # 如果没有获取到索引信息,尝试主动查询 list_indexes
  232. if not indices:
  233. try:
  234. # 获取索引列表 (通常返回索引名称列表)
  235. index_names = self.client.list_indexes(collection_name=name)
  236. if index_names:
  237. for idx_name in index_names:
  238. try:
  239. # 获取索引详情
  240. idx_desc = self.client.describe_index(collection_name=name, index_name=idx_name)
  241. if idx_desc:
  242. indices.append({
  243. "field_name": idx_desc.get("field_name"),
  244. "index_name": idx_desc.get("index_name"),
  245. "index_type": idx_desc.get("index_type"),
  246. "metric_type": idx_desc.get("metric_type"),
  247. "params": idx_desc.get("params"),
  248. })
  249. except Exception:
  250. continue
  251. except Exception as e:
  252. logger.warning(f"Failed to list/describe indexes for {name}: {e}")
  253. detail = {
  254. "name": name,
  255. "description": desc.get("description", ""),
  256. "status": status,
  257. "entity_count": entity_count,
  258. "created_time": created_time,
  259. "updated_time": updated_time,
  260. "fields": fields,
  261. "enable_dynamic_field": desc.get("enable_dynamic_field", False),
  262. "consistency_level": desc.get("consistency_level"),
  263. "num_shards": desc.get("num_shards"),
  264. "num_partitions": desc.get("num_partitions"),
  265. "indices": indices,
  266. "properties": desc.get("properties"),
  267. "aliases": desc.get("aliases", []),
  268. }
  269. logger.info(f"成功获取集合 {name} 的详细信息")
  270. return detail
  271. def update_collection_description(self, name: str, description: str) -> Dict[str, Any]:
  272. """使用 alter_collection_properties 更新集合描述"""
  273. description = description or ""
  274. # 1. 更新集合 description(唯一修改点)
  275. self.client.alter_collection_properties(
  276. collection_name=name,
  277. properties={"collection.description": description},
  278. )
  279. # 2. 重新获取集合信息
  280. desc = self.client.describe_collection(collection_name=name)
  281. print(desc)
  282. stats = self.client.get_collection_stats(collection_name=name)
  283. load_state = self.client.get_load_state(collection_name=name)
  284. # 3. 时间戳转换(Milvus TSO -> 物理时间)
  285. def ts_to_str(ts):
  286. if ts is None:
  287. return None
  288. ts_int = int(ts)
  289. physical_ms = ts_int >> 18
  290. return datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S")
  291. created_time = ts_to_str(desc.get("created_timestamp"))
  292. updated_time = ts_to_str(desc.get("update_timestamp"))
  293. entity_count = stats.get("row_count") if isinstance(stats, dict) else None
  294. status = load_state.get("state") if isinstance(load_state, dict) else load_state
  295. return {
  296. "name": name,
  297. "status": status,
  298. "entity_count": entity_count,
  299. "description": desc.get("description", ""),
  300. "created_time": created_time,
  301. "updated_time": updated_time,
  302. }
  303. def hybrid_search(self, collection_name: str, query_text: str,
  304. top_k: int = 3, ranker_type: str = "weighted",
  305. dense_weight: float = 0.7, sparse_weight: float = 0.3):
  306. """
  307. 混合搜索(参考 test_hybrid_v2.6.py 的实现)
  308. Args:
  309. param: 包含collection_name的参数字典
  310. query_text: 查询文本
  311. top_k: 返回结果数量
  312. ranker_type: 重排序类型 "weighted" 或 "rrf"
  313. dense_weight: 密集向量权重(当ranker_type="weighted"时使用)
  314. sparse_weight: 稀疏向量权重(当ranker_type="weighted"时使用)
  315. Returns:
  316. List[Dict]: 搜索结果列表
  317. """
  318. try:
  319. collection_name = collection_name
  320. # 获取 vectorstore 实例(包含 Milvus 和 BM25BuiltInFunction)
  321. vectorstore = get_milvus_vectorstore(
  322. collection_name=collection_name,
  323. consistency_level="Strong"
  324. )
  325. # 执行混合搜索 (完全按照 test_hybrid_v2.6.py 的逻辑)
  326. if ranker_type == "weighted":
  327. results = vectorstore.similarity_search(
  328. query=query_text,
  329. k=top_k,
  330. ranker_type="weighted",
  331. ranker_params={"weights": [dense_weight, sparse_weight]}
  332. )
  333. else: # rrf
  334. results = vectorstore.similarity_search(
  335. query=query_text,
  336. k=top_k,
  337. ranker_type="rrf",
  338. ranker_params={"k": 60}
  339. )
  340. # 格式化结果,保持与其他搜索方法一致
  341. formatted_results = []
  342. for doc in results:
  343. formatted_results.append({
  344. 'id': doc.metadata.get('pk', 0),
  345. 'text_content': doc.page_content,
  346. 'metadata': doc.metadata,
  347. 'distance': 0.0,
  348. 'similarity': 1.0
  349. })
  350. logger.info(f"Hybrid search returned {len(formatted_results)} results")
  351. return formatted_results
  352. except Exception as e:
  353. logger.error(f"Error in hybrid search: {e}")
  354. # 回退到传统的向量搜索
  355. logger.info("Falling back to traditional vector search")
  356. # 可选:单例
  357. milvus_service = MilvusService()
  358. if __name__ == "__main__":
  359. # 推荐这样跑:
  360. # uv run python -m src.app.services.milvus_service
  361. import json
  362. service = MilvusService()
  363. # 测试混合搜索 hybrid_search
  364. print("=" * 50)
  365. print("测试混合检索 (Hybrid Search)")
  366. print("=" * 50)
  367. try:
  368. # 示例参数,需要根据实际情况修改
  369. collection_name = "first_bfp_collection_status"
  370. query_text = "《公路水运工程临时用电技术规程》(JTT1499-2024)状态为现行" # 修改为实际查询内容
  371. # 测试 weighted 模式
  372. print("\n1. 测试 Weighted 重排序模式:")
  373. print(f" 集合: {collection_name}")
  374. print(f" 查询: {query_text}")
  375. print(f" 密集权重: 0.7, 稀疏权重: 0.3")
  376. results_weighted = service.hybrid_search(
  377. collection_name=collection_name,
  378. query_text=query_text,
  379. top_k=5,
  380. ranker_type="weighted",
  381. dense_weight=0.7,
  382. sparse_weight=0.3
  383. )
  384. print(f"\n 结果数量: {len(results_weighted)}")
  385. for i, result in enumerate(results_weighted, 1):
  386. print(f" [{i}] ID: {result.get('id')}, Text: {result.get('text_content')[:50]}...")
  387. # 测试 RRF 模式
  388. print("\n2. 测试 RRF (Reciprocal Rank Fusion) 重排序模式:")
  389. print(f" 集合: {collection_name}")
  390. print(f" 查询: {query_text}")
  391. results_rrf = service.hybrid_search(
  392. collection_name=collection_name,
  393. query_text=query_text,
  394. top_k=5,
  395. ranker_type="rrf"
  396. )
  397. print(f"\n 结果数量: {len(results_rrf)}")
  398. for i, result in enumerate(results_rrf, 1):
  399. print(f" [{i}] ID: {result.get('id')}, Text: {result.get('text_content')[:50]}...")
  400. print("\n✓ 混合检索测试完成")
  401. except Exception as e:
  402. print(f"\n✗ 混合检索测试失败: {e}")
  403. import traceback
  404. traceback.print_exc()
  405. # 也可以查看集合详情
  406. print("\n" + "=" * 50)
  407. print("获取所有集合信息:")
  408. print("=" * 50)
  409. data = service.get_collection_details()
  410. for item in data:
  411. print(json.dumps(item, ensure_ascii=False, indent=2))