""" Milvus Service:业务层(直接用 manager.client 调 Milvus 原生方法) """ from __future__ import annotations import sys import os # 添加src目录到Python路径 sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..')) sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../..')) import logging from typing import List, Dict, Any from datetime import datetime from app.base import get_milvus_manager, get_milvus_vectorstore, get_embedding_model logger = logging.getLogger(__name__) class MilvusService: def __init__(self): self.client = get_milvus_manager().client # 获取embedding model self.emdmodel = get_embedding_model() def create_collection(self, name: str, dimension: int = 768, description: str = "", fields: List[Dict] = None) -> None: """ 创建 Milvus 集合 :param fields: 自定义字段列表,每个元素为 {"name": "age", "type": "INT64", ...} """ if self.client.has_collection(name): logger.info(f"Collection {name} already exists.") return # 如果有自定义字段,使用 schema 创建 if fields: from pymilvus import MilvusClient, DataType # 1. 创建 Schema schema = MilvusClient.create_schema( auto_id=True, enable_dynamic_field=True, description=description ) # 2. 添加必须的默认字段 schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True, auto_id=True) schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dimension) # schema.add_field(field_name="sparse", datatype=DataType.SPARSE_FLOAT_VECTOR) # 如果需要混合检索,可能需要 # 3. 添加用户自定义字段 # 映射字符串类型到 pymilvus DataType type_map = { "BOOL": DataType.BOOL, "INT8": DataType.INT8, "INT16": DataType.INT16, "INT32": DataType.INT32, "INT64": DataType.INT64, "FLOAT": DataType.FLOAT, "DOUBLE": DataType.DOUBLE, "VARCHAR": DataType.VARCHAR, "JSON": DataType.JSON, "FLOAT_VECTOR": DataType.FLOAT_VECTOR } for f in fields: dtype = type_map.get(f.get("type", "").upper()) if not dtype: continue # 忽略未知类型 kwargs = { "field_name": f.get("name"), "datatype": dtype, "description": f.get("description", "") } if dtype == DataType.VARCHAR: kwargs["max_length"] = f.get("max_length", 65535) schema.add_field(**kwargs) # 4. 准备索引参数 index_params = self.client.prepare_index_params() # 5. 添加向量索引 index_params.add_index( field_name="vector", index_type="AUTOINDEX", metric_type="COSINE" ) # 6. 为自定义标量字段添加索引 (可选,这里为所有标量字段添加倒排索引以加速过滤) for f in fields: # VARCHAR/INT/BOOL 等支持索引 if f.get("type", "").upper() in ["VARCHAR", "INT64", "INT32", "BOOL"]: index_params.add_index( field_name=f.get("name"), index_type="INVERTED" # 标量字段倒排索引 ) # 7. 创建集合 self.client.create_collection( collection_name=name, schema=schema, index_params=index_params ) else: # 使用简化的 create_collection API self.client.create_collection( collection_name=name, dimension=dimension, description=description, auto_id=True, # 自动生成 ID id_type="int", # ID 类型 metric_type="COSINE" # 默认使用余弦相似度 ) logger.info(f"Created collection {name} with dimension {dimension}") def drop_collection(self, name: str) -> None: """删除 Milvus 集合""" if self.client.has_collection(name): self.client.drop_collection(name) logger.info(f"Dropped collection {name}") def has_collection(self, name: str) -> bool: """检查集合是否存在""" return self.client.has_collection(name) def get_collection_details(self) -> List[Dict[str, Any]]: """ 获取所有 Collections 详细信息 """ details: List[Dict[str, Any]] = [] names = self.client.list_collections() for name in names: desc = self.client.describe_collection(collection_name=name) stats = self.client.get_collection_stats(collection_name=name) load_state = self.client.get_load_state(collection_name=name) # ===== 时间戳转换(按你指定写法,无封装)===== created_time = None updated_time = None if desc.get("created_timestamp") is not None: ts_int = int(desc["created_timestamp"]) physical_ms = ts_int >> 18 created_time = datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S") if desc.get("update_timestamp") is not None: ts_int = int(desc["update_timestamp"]) physical_ms = ts_int >> 18 updated_time = datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S") # ===== 数量:不保底(要求返回结构必须有 row_count)===== entity_count = stats["row_count"] # ===== 状态:不保底(要求返回结构必须有 state)===== status = load_state["state"] details.append( { "name": name, "status": status, "entity_count": entity_count, "description": desc.get("description", ""), "created_time": created_time, "updated_time": updated_time, } ) logger.info(f"成功获取Collections详细信息,共{len(details)}个") return details def set_collection_state(self, name: str, action: str) -> Dict[str, Any]: """ 改变指定 Collection 的加载状态。 参数: - name: 集合名称 - action: 操作,取值 'load' 或 'release' 返回: - 包含集合名称和当前状态的字典,例如: {"name": name, "state": "Loaded"} """ action_norm = (action or "").strip().lower() if action_norm not in {"load", "release"}: raise ValueError("action 必须为 'load' 或 'release'") # 执行加载/释放 if action_norm == "load": self.client.load_collection(collection_name=name) else: self.client.release_collection(collection_name=name) # 返回最新状态 load_state = self.client.get_load_state(collection_name=name) state = load_state.get("state") if isinstance(load_state, dict) else load_state result = {"name": name, "state": state, "action": action_norm} logger.info(f"集合 {name} 状态更新为 {state} (action={action_norm})") return result def delete_collection_if_empty(self, name: str) -> Dict[str, Any]: """仅当集合内容为空时删除集合,否则抛出异常""" stats = self.client.get_collection_stats(collection_name=name) row_count = stats.get("row_count") if isinstance(stats, dict) else None if row_count is None: raise ValueError("无法获取集合行数,禁止删除") if int(row_count) > 0: raise ValueError("集合内容不为空,不能删除") self.client.drop_collection(collection_name=name) logger.info(f"集合 {name} 已删除") return {"name": name, "deleted": True} def get_collection_detail(self, name: str) -> Dict[str, Any]: """获取单个集合的详细信息,包含schema、索引等所有desc字段""" desc = self.client.describe_collection(collection_name=name) stats = self.client.get_collection_stats(collection_name=name) load_state = self.client.get_load_state(collection_name=name) # 时间戳转换 created_time = None updated_time = None if desc.get("created_timestamp") is not None: ts_int = int(desc["created_timestamp"]) physical_ms = ts_int >> 18 created_time = datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S") if desc.get("update_timestamp") is not None: ts_int = int(desc["update_timestamp"]) physical_ms = ts_int >> 18 updated_time = datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S") entity_count = stats.get("row_count", 0) status = load_state.get("state") if isinstance(load_state, dict) else load_state # 提取字段schema fields = [] if "fields" in desc: for field in desc["fields"]: field_info = { "name": field.get("name"), "type": str(field.get("type")), "description": field.get("description", ""), "is_primary": field.get("is_primary", False), "auto_id": field.get("auto_id"), } # 向量维度 if "params" in field and "dim" in field["params"]: field_info["dim"] = field["params"]["dim"] # 字符串长度 if "params" in field and "max_length" in field["params"]: field_info["max_length"] = field["params"]["max_length"] # 其他params if "params" in field: field_info["params"] = field["params"] fields.append(field_info) # 提取索引信息 indices = [] # 尝试从 describe_collection 结果中获取 (兼容旧逻辑) if "indexes" in desc: for idx in desc["indexes"]: index_info = { "field_name": idx.get("field_name"), "index_name": idx.get("index_name"), "index_type": idx.get("index_type"), "metric_type": idx.get("metric_type"), "params": idx.get("params"), } indices.append(index_info) # 如果没有获取到索引信息,尝试主动查询 list_indexes if not indices: try: # 获取索引列表 (通常返回索引名称列表) index_names = self.client.list_indexes(collection_name=name) if index_names: for idx_name in index_names: try: # 获取索引详情 idx_desc = self.client.describe_index(collection_name=name, index_name=idx_name) if idx_desc: indices.append({ "field_name": idx_desc.get("field_name"), "index_name": idx_desc.get("index_name"), "index_type": idx_desc.get("index_type"), "metric_type": idx_desc.get("metric_type"), "params": idx_desc.get("params"), }) except Exception: continue except Exception as e: logger.warning(f"Failed to list/describe indexes for {name}: {e}") detail = { "name": name, "description": desc.get("description", ""), "status": status, "entity_count": entity_count, "created_time": created_time, "updated_time": updated_time, "fields": fields, "enable_dynamic_field": desc.get("enable_dynamic_field", False), "consistency_level": desc.get("consistency_level"), "num_shards": desc.get("num_shards"), "num_partitions": desc.get("num_partitions"), "indices": indices, "properties": desc.get("properties"), "aliases": desc.get("aliases", []), } logger.info(f"成功获取集合 {name} 的详细信息") return detail def update_collection_description(self, name: str, description: str) -> Dict[str, Any]: """使用 alter_collection_properties 更新集合描述""" description = description or "" # 1. 更新集合 description(唯一修改点) self.client.alter_collection_properties( collection_name=name, properties={"collection.description": description}, ) # 2. 重新获取集合信息 desc = self.client.describe_collection(collection_name=name) print(desc) stats = self.client.get_collection_stats(collection_name=name) load_state = self.client.get_load_state(collection_name=name) # 3. 时间戳转换(Milvus TSO -> 物理时间) def ts_to_str(ts): if ts is None: return None ts_int = int(ts) physical_ms = ts_int >> 18 return datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S") created_time = ts_to_str(desc.get("created_timestamp")) updated_time = ts_to_str(desc.get("update_timestamp")) entity_count = stats.get("row_count") if isinstance(stats, dict) else None status = load_state.get("state") if isinstance(load_state, dict) else load_state return { "name": name, "status": status, "entity_count": entity_count, "description": desc.get("description", ""), "created_time": created_time, "updated_time": updated_time, } def hybrid_search(self, collection_name: str, query_text: str, top_k: int = 3, ranker_type: str = "weighted", dense_weight: float = 0.7, sparse_weight: float = 0.3): """ 混合搜索(参考 test_hybrid_v2.6.py 的实现) Args: param: 包含collection_name的参数字典 query_text: 查询文本 top_k: 返回结果数量 ranker_type: 重排序类型 "weighted" 或 "rrf" dense_weight: 密集向量权重(当ranker_type="weighted"时使用) sparse_weight: 稀疏向量权重(当ranker_type="weighted"时使用) Returns: List[Dict]: 搜索结果列表 """ try: collection_name = collection_name # 获取 vectorstore 实例(包含 Milvus 和 BM25BuiltInFunction) vectorstore = get_milvus_vectorstore( collection_name=collection_name, consistency_level="Strong" ) # 执行混合搜索 (完全按照 test_hybrid_v2.6.py 的逻辑) if ranker_type == "weighted": results = vectorstore.similarity_search( query=query_text, k=top_k, ranker_type="weighted", ranker_params={"weights": [dense_weight, sparse_weight]} ) else: # rrf results = vectorstore.similarity_search( query=query_text, k=top_k, ranker_type="rrf", ranker_params={"k": 60} ) # 格式化结果,保持与其他搜索方法一致 formatted_results = [] for doc in results: formatted_results.append({ 'id': doc.metadata.get('pk', 0), 'text_content': doc.page_content, 'metadata': doc.metadata, 'distance': 0.0, 'similarity': 1.0 }) logger.info(f"Hybrid search returned {len(formatted_results)} results") return formatted_results except Exception as e: logger.error(f"Error in hybrid search: {e}") # 回退到传统的向量搜索 logger.info("Falling back to traditional vector search") # 可选:单例 milvus_service = MilvusService() if __name__ == "__main__": # 推荐这样跑: # uv run python -m src.app.services.milvus_service import json service = MilvusService() # 测试混合搜索 hybrid_search print("=" * 50) print("测试混合检索 (Hybrid Search)") print("=" * 50) try: # 示例参数,需要根据实际情况修改 collection_name = "first_bfp_collection_status" query_text = "《公路水运工程临时用电技术规程》(JTT1499-2024)状态为现行" # 修改为实际查询内容 # 测试 weighted 模式 print("\n1. 测试 Weighted 重排序模式:") print(f" 集合: {collection_name}") print(f" 查询: {query_text}") print(f" 密集权重: 0.7, 稀疏权重: 0.3") results_weighted = service.hybrid_search( collection_name=collection_name, query_text=query_text, top_k=5, ranker_type="weighted", dense_weight=0.7, sparse_weight=0.3 ) print(f"\n 结果数量: {len(results_weighted)}") for i, result in enumerate(results_weighted, 1): print(f" [{i}] ID: {result.get('id')}, Text: {result.get('text_content')[:50]}...") # 测试 RRF 模式 print("\n2. 测试 RRF (Reciprocal Rank Fusion) 重排序模式:") print(f" 集合: {collection_name}") print(f" 查询: {query_text}") results_rrf = service.hybrid_search( collection_name=collection_name, query_text=query_text, top_k=5, ranker_type="rrf" ) print(f"\n 结果数量: {len(results_rrf)}") for i, result in enumerate(results_rrf, 1): print(f" [{i}] ID: {result.get('id')}, Text: {result.get('text_content')[:50]}...") print("\n✓ 混合检索测试完成") except Exception as e: print(f"\n✗ 混合检索测试失败: {e}") import traceback traceback.print_exc() # 也可以查看集合详情 print("\n" + "=" * 50) print("获取所有集合信息:") print("=" * 50) data = service.get_collection_details() for item in data: print(json.dumps(item, ensure_ascii=False, indent=2))