""" Milvus Service:业务层(直接用 manager.client 调 Milvus 原生方法) """ from __future__ import annotations import sys import os import logging import re import hashlib import time from typing import List, Dict, Any, Optional from datetime import datetime from app.base import get_milvus_manager, get_milvus_vectorstore, get_embedding_model from app.core.config import config_handler from langchain_core.documents import Document logger = logging.getLogger(__name__) # 默认集合名称 PARENT_COLLECTION_NAME = config_handler.get("admin_app", "PARENT_COLLECTION_NAME", "test_27_parent") CHILD_COLLECTION_NAME = config_handler.get("admin_app", "CHILD_COLLECTION_NAME", "test_27_child") class MilvusService: def __init__(self): self.client = get_milvus_manager().client # 获取embedding model self.emdmodel = get_embedding_model() # 默认向量维度 (Qwen3-Embedding-8B default) self.DENSE_DIM = 4096 def ensure_collections(self): """确保系统默认集合已创建""" collections = [PARENT_COLLECTION_NAME, CHILD_COLLECTION_NAME] for name in collections: self.ensure_collection_exists(name) async def insert_knowledge(self, content: str, doc_info: Dict[str, Any]): """将 Markdown 内容切分并入库 (支持父子段分表)""" try: doc_id = doc_info.get("doc_id") doc_name = doc_info.get("doc_name") doc_version = doc_info.get("doc_version", int(time.time())) tags = doc_info.get("tags", "") user_id = doc_info.get("user_id", "system") kb_method = doc_info.get("kb_method") target_collection = doc_info.get("collection_name") or PARENT_COLLECTION_NAME from langchain_text_splitters import RecursiveCharacterTextSplitter if kb_method == "parent_child": # --- 方案 A: 父子段分表入库 --- parent_col = f"{target_collection}_parent" child_col = f"{target_collection}_child" # 1. 切分父段 (较大块) parent_splitter = RecursiveCharacterTextSplitter( chunk_size=1000, chunk_overlap=100 ) parent_chunks = parent_splitter.split_text(content) parent_docs = [] child_docs = [] for p_idx, p_content in enumerate(parent_chunks): # 生成唯一的 parent_id p_id = hashlib.sha1(f"{doc_id}_p_{p_idx}".encode()).hexdigest() # 准备父段文档 (Metadata 不包含向量,仅用于检索回显) p_metadata = self._prepare_metadata(doc_info, p_id, p_idx, p_id) parent_docs.append(Document(page_content=p_content, metadata=p_metadata)) # 2. 在每个父段内部切分子段 (较小块) child_splitter = RecursiveCharacterTextSplitter( chunk_size=300, chunk_overlap=30 ) child_chunks = child_splitter.split_text(p_content) for c_idx, c_content in enumerate(child_chunks): # 子段的 parent_id 指向父段的 p_id c_metadata = self._prepare_metadata(doc_info, p_id, c_idx, p_id) child_docs.append(Document(page_content=c_content, metadata=c_metadata)) # 确保两个集合都存在 self.ensure_collection_exists(parent_col) self.ensure_collection_exists(child_col) # 分别入库 if parent_docs: get_milvus_vectorstore(parent_col).add_documents(parent_docs) if child_docs: get_milvus_vectorstore(child_col).add_documents(child_docs) logger.info(f"Successfully inserted parent-child chunks for {doc_name}: {len(parent_docs)} parents -> {len(child_docs)} children") else: # --- 常规单表入库逻辑 --- chunks = [] if kb_method == "length": splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50) chunks = splitter.split_text(content) elif kb_method == "symbol": splitter = RecursiveCharacterTextSplitter( separators=["\n\n", "\n", "。", ";", "!", "?", "!", "?", ";"], chunk_size=500, chunk_overlap=0 ) chunks = splitter.split_text(content) else: chunks = [p.strip() for p in re.split(r"\n\s*\n+", content) if p.strip()] if not chunks: logger.warning(f"Document {doc_name} has no content chunks.") return documents = [] for idx, chunk in enumerate(chunks): p_id = hashlib.sha1(f"{doc_id}_{idx}".encode()).hexdigest() metadata = self._prepare_metadata(doc_info, p_id, idx, p_id) documents.append(Document(page_content=chunk, metadata=metadata)) self.ensure_collection_exists(target_collection) get_milvus_vectorstore(target_collection).add_documents(documents) logger.info(f"Successfully inserted {len(documents)} chunks for {doc_name} into {target_collection}") except Exception as e: logger.error(f"Error inserting knowledge into Milvus: {e}") raise def _prepare_metadata(self, doc_info: Dict[str, Any], p_id: str, index: int, parent_ref_id: str) -> Dict[str, Any]: """统一准备元数据""" doc_id = doc_info.get("doc_id") doc_name = doc_info.get("doc_name") doc_version = doc_info.get("doc_version", int(time.time())) tags = doc_info.get("tags", "") user_id = doc_info.get("user_id", "system") return { "document_id": doc_id, "parent_id": parent_ref_id, "index": index, "tag_list": tags, "permission": {}, "is_deleted": 0, "created_by": user_id, "created_time": int(time.time() * 1000), "updated_by": user_id, "updated_time": int(time.time() * 1000), "metadata": { "doc_name": doc_name, "doc_version": doc_version, "outline_path": "" } } def ensure_collection_exists(self, name: str): """确保指定名称的集合存在,不存在则按默认 Schema 创建""" from pymilvus import DataType, Function, FunctionType # 1. 如果不存在,则创建集合 if not self.client.has_collection(name): logger.info(f"Creating collection: {name}") schema = self.client.create_schema(auto_id=True, enable_dynamic_field=False) schema.add_field("pk", DataType.INT64, is_primary=True, auto_id=True) schema.add_field("text", DataType.VARCHAR, max_length=65535, enable_analyzer=True) schema.add_field("dense", DataType.FLOAT_VECTOR, dim=self.DENSE_DIM) schema.add_field("sparse", DataType.SPARSE_FLOAT_VECTOR) schema.add_field("document_id", DataType.VARCHAR, max_length=256) schema.add_field("parent_id", DataType.VARCHAR, max_length=256) schema.add_field("index", DataType.INT64) schema.add_field("tag_list", DataType.VARCHAR, max_length=2048) schema.add_field("permission", DataType.JSON, nullable=True) schema.add_field("metadata", DataType.JSON, nullable=True) schema.add_field("is_deleted", DataType.INT64, default_value=0) schema.add_field("created_by", DataType.VARCHAR, max_length=256, nullable=True) schema.add_field("created_time", DataType.INT64) schema.add_field("updated_by", DataType.VARCHAR, max_length=256, nullable=True) schema.add_field("updated_time", DataType.INT64) schema.add_function( Function( name="bm25_fn", input_field_names=["text"], output_field_names=["sparse"], function_type=FunctionType.BM25, ) ) self.client.create_collection(collection_name=name, schema=schema) # 2. 检查并补全索引 # 获取集合的描述信息以检查字段是否存在 desc = self.client.describe_collection(collection_name=name) fields_in_collection = [f.get("name") for f in desc.get("fields", [])] existing_indexes = self.client.list_indexes(collection_name=name) index_params = self.client.prepare_index_params() needs_index = False # 只有当字段存在且没有索引时才添加 if "dense" in fields_in_collection and "dense_idx" not in existing_indexes: index_params.add_index( field_name="dense", index_name="dense_idx", index_type="AUTOINDEX", metric_type="COSINE", ) needs_index = True if "sparse" in fields_in_collection and "bm25_idx" not in existing_indexes: index_params.add_index( field_name="sparse", index_name="bm25_idx", index_type="SPARSE_INVERTED_INDEX", metric_type="BM25", params={"inverted_index_algo": "DAAT_MAXSCORE"}, ) needs_index = True if "permission" in fields_in_collection and "permission" not in existing_indexes: index_params.add_index( field_name="permission", index_type="INVERTED", params={"json_cast_type": "VARCHAR"} ) needs_index = True if "metadata" in fields_in_collection and "metadata" not in existing_indexes: index_params.add_index( field_name="metadata", index_type="INVERTED", params={"json_cast_type": "VARCHAR"} ) needs_index = True if needs_index: logger.info(f"Creating missing indexes for collection: {name}") try: self.client.create_index(collection_name=name, index_params=index_params) except Exception as e: logger.error(f"Failed to create index for {name}: {e}") self.client.load_collection(collection_name=name) def create_collection(self, name: str, dimension: int = None, description: str = "", fields: List[Dict] = None) -> None: """ 创建 Milvus 集合 :param dimension: 向量维度,如果为None则使用默认值 :param fields: 自定义字段列表,每个元素为 {"name": "age", "type": "INT64", ...} """ # 使用默认维度 if dimension is None: dimension = self.DENSE_DIM if self.client.has_collection(name): logger.info(f"Collection {name} already exists.") return # 如果有自定义字段,使用 schema 创建 if fields: from pymilvus import MilvusClient, DataType, Function, FunctionType # 1. 创建 Schema schema = MilvusClient.create_schema( auto_id=True, enable_dynamic_field=True, description=description ) # 检查字段中是否定义了主键 has_primary = any(f.get("is_primary") for f in fields) if not has_primary: # 如果没有定义主键,添加默认主键 schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True, auto_id=True) # 检查是否有默认向量列,如果没有则添加 (兼容旧逻辑,但如果fields里有dense则不添加) has_vector = any(f.get("type") == "FLOAT_VECTOR" for f in fields) if not has_vector: schema.add_field(field_name="dense", datatype=DataType.FLOAT_VECTOR, dim=dimension) # 3. 添加用户自定义字段 type_map = { "BOOL": DataType.BOOL, "INT8": DataType.INT8, "INT16": DataType.INT16, "INT32": DataType.INT32, "INT64": DataType.INT64, "FLOAT": DataType.FLOAT, "DOUBLE": DataType.DOUBLE, "VARCHAR": DataType.VARCHAR, "JSON": DataType.JSON, "FLOAT_VECTOR": DataType.FLOAT_VECTOR, "SPARSE_FLOAT_VECTOR": DataType.SPARSE_FLOAT_VECTOR, "BM25": DataType.SPARSE_FLOAT_VECTOR # BM25 特殊处理,映射为稀疏向量 } bm25_field = None text_field_name = "text" # 默认文本字段名 for f in fields: field_type_str = f.get("type", "").upper() dtype = type_map.get(field_type_str) if not dtype: continue # 记录文本字段名,供BM25使用 if f.get("name") in ["text", "content", "chunk"]: text_field_name = f.get("name") kwargs = { "field_name": f.get("name"), "datatype": dtype, "description": f.get("description", "") } if f.get("is_primary"): kwargs["is_primary"] = True kwargs["auto_id"] = True # 假设主键都是自增 if dtype == DataType.VARCHAR: kwargs["max_length"] = f.get("max_length", 65535) # 关键修复:如果要被 BM25 引用,必须启用 analyzer if f.get("name") in ["text", "content", "chunk"]: kwargs["enable_analyzer"] = True if dtype == DataType.FLOAT_VECTOR: kwargs["dim"] = dimension # 使用传入的 dimension schema.add_field(**kwargs) # 如果是 BM25 类型,记录下来以便后续添加 Function if field_type_str == "BM25": bm25_field = f.get("name") # 处理 BM25 Function if bm25_field: try: schema.add_function(Function( name="bm25_fn", input_field_names=[text_field_name], output_field_names=[bm25_field], function_type=FunctionType.BM25 )) logger.info(f"Added BM25 function mapping {text_field_name} -> {bm25_field}") except Exception as e: logger.error(f"Failed to add BM25 function: {e}") # 4. 准备索引参数 index_params = self.client.prepare_index_params() # 5. 为所有向量字段添加索引 for f in fields: ftype = f.get("type", "").upper() if ftype == "FLOAT_VECTOR": index_params.add_index( field_name=f.get("name"), index_type="AUTOINDEX", metric_type="IP" # [Modified] 更改为 IP (内积),通常对规范化向量效果更好,与 COSINE 类似但更简单 ) elif ftype == "BM25" or ftype == "SPARSE_FLOAT_VECTOR": index_params.add_index( field_name=f.get("name"), index_type="SPARSE_INVERTED_INDEX", # 稀疏向量索引 metric_type="BM25" ) # 6. 为自定义标量字段添加索引 for f in fields: ftype = f.get("type", "").upper() if ftype in ["VARCHAR", "INT64", "INT32", "BOOL"] and not f.get("is_primary"): # 排除主键,主键自动索引 index_params.add_index( field_name=f.get("name"), index_type="INVERTED" ) elif ftype == "JSON": # Milvus 2.4+ JSON 索引必须指定 json_cast_type # 这里为 JSON 字段添加默认索引,以便支持查询 index_params.add_index( field_name=f.get("name"), index_type="INVERTED", params={"json_cast_type": "VARCHAR"} ) # 7. 创建集合 self.client.create_collection( collection_name=name, schema=schema, index_params=index_params ) else: # 使用简化的 create_collection API self.client.create_collection( collection_name=name, dimension=dimension, description=description, auto_id=True, # 自动生成 ID id_type="int", # ID 类型 metric_type="IP" # [Modified] 默认使用 IP ) logger.info(f"Created collection {name} with dimension {dimension}") def drop_collection(self, name: str) -> None: """删除 Milvus 集合""" if self.client.has_collection(name): self.client.drop_collection(name) logger.info(f"Dropped collection {name}") def has_collection(self, name: str) -> bool: """检查集合是否存在""" return self.client.has_collection(name) def get_collection_details(self) -> List[Dict[str, Any]]: """ 获取所有 Collections 详细信息 """ details: List[Dict[str, Any]] = [] names = self.client.list_collections() for name in names: desc = self.client.describe_collection(collection_name=name) stats = self.client.get_collection_stats(collection_name=name) load_state = self.client.get_load_state(collection_name=name) # ===== 时间戳转换(按你指定写法,无封装)===== created_time = None updated_time = None if desc.get("created_timestamp") is not None: ts_int = int(desc["created_timestamp"]) physical_ms = ts_int >> 18 created_time = datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S") if desc.get("update_timestamp") is not None: ts_int = int(desc["update_timestamp"]) physical_ms = ts_int >> 18 updated_time = datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S") # ===== 数量:不保底(要求返回结构必须有 row_count)===== entity_count = stats["row_count"] # ===== 状态:不保底(要求返回结构必须有 state)===== status = load_state["state"] details.append( { "name": name, "status": status, "entity_count": entity_count, "description": desc.get("description", ""), "created_time": created_time, "updated_time": updated_time, } ) logger.info(f"成功获取Collections详细信息,共{len(details)}个") return details def get_collection_state(self, name: str) -> str: """获取集合加载状态""" try: load_state = self.client.get_load_state(collection_name=name) state = load_state.get("state") if isinstance(load_state, dict) else load_state return state except Exception as e: logger.error(f"Failed to get collection state for {name}: {e}") return "Unknown" def set_collection_state(self, name: str, action: str) -> Dict[str, Any]: """ 改变指定 Collection 的加载状态。 参数: - name: 集合名称 - action: 操作,取值 'load' 或 'release' 返回: - 包含集合名称和当前状态的字典,例如: {"name": name, "state": "Loaded"} """ action_norm = (action or "").strip().lower() if action_norm not in {"load", "release"}: raise ValueError("action 必须为 'load' 或 'release'") # 执行加载/释放 if action_norm == "load": self.client.load_collection(collection_name=name) else: self.client.release_collection(collection_name=name) # 返回最新状态 load_state = self.client.get_load_state(collection_name=name) state = load_state.get("state") if isinstance(load_state, dict) else load_state result = {"name": name, "state": state, "action": action_norm} logger.info(f"集合 {name} 状态更新为 {state} (action={action_norm})") return result def delete_collection_if_empty(self, name: str) -> Dict[str, Any]: """仅当集合内容为空时删除集合,否则抛出异常""" stats = self.client.get_collection_stats(collection_name=name) row_count = stats.get("row_count") if isinstance(stats, dict) else None if row_count is None: raise ValueError("无法获取集合行数,禁止删除") if int(row_count) > 0: raise ValueError("集合内容不为空,不能删除") self.client.drop_collection(collection_name=name) logger.info(f"集合 {name} 已删除") return {"name": name, "deleted": True} def get_collection_detail(self, name: str) -> Dict[str, Any]: """获取单个集合的详细信息,包含schema、索引等所有desc字段""" desc = self.client.describe_collection(collection_name=name) stats = self.client.get_collection_stats(collection_name=name) load_state = self.client.get_load_state(collection_name=name) # 时间戳转换 created_time = None updated_time = None if desc.get("created_timestamp") is not None: ts_int = int(desc["created_timestamp"]) physical_ms = ts_int >> 18 created_time = datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S") if desc.get("update_timestamp") is not None: ts_int = int(desc["update_timestamp"]) physical_ms = ts_int >> 18 updated_time = datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S") entity_count = stats.get("row_count", 0) status = load_state.get("state") if isinstance(load_state, dict) else load_state # 提取字段schema fields = [] if "fields" in desc: for field in desc["fields"]: field_info = { "name": field.get("name"), "type": str(field.get("type")), "description": field.get("description", ""), "is_primary": field.get("is_primary", False), "auto_id": field.get("auto_id"), } # 向量维度 if "params" in field and "dim" in field["params"]: field_info["dim"] = field["params"]["dim"] # 字符串长度 if "params" in field and "max_length" in field["params"]: field_info["max_length"] = field["params"]["max_length"] # 其他params if "params" in field: field_info["params"] = field["params"] fields.append(field_info) # 提取索引信息 indices = [] # 尝试从 describe_collection 结果中获取 (兼容旧逻辑) if "indexes" in desc: for idx in desc["indexes"]: index_info = { "field_name": idx.get("field_name"), "index_name": idx.get("index_name"), "index_type": idx.get("index_type"), "metric_type": idx.get("metric_type"), "params": idx.get("params"), } indices.append(index_info) # 如果没有获取到索引信息,尝试主动查询 list_indexes if not indices: try: # 获取索引列表 (通常返回索引名称列表) index_names = self.client.list_indexes(collection_name=name) if index_names: for idx_name in index_names: try: # 获取索引详情 idx_desc = self.client.describe_index(collection_name=name, index_name=idx_name) if idx_desc: indices.append({ "field_name": idx_desc.get("field_name"), "index_name": idx_desc.get("index_name"), "index_type": idx_desc.get("index_type"), "metric_type": idx_desc.get("metric_type"), "params": idx_desc.get("params"), }) except Exception: continue except Exception as e: logger.warning(f"Failed to list/describe indexes for {name}: {e}") detail = { "name": name, "description": desc.get("description", ""), "status": status, "entity_count": entity_count, "created_time": created_time, "updated_time": updated_time, "fields": fields, "enable_dynamic_field": desc.get("enable_dynamic_field", False), "consistency_level": desc.get("consistency_level"), "num_shards": desc.get("num_shards"), "num_partitions": desc.get("num_partitions"), "indices": indices, "properties": desc.get("properties"), "aliases": desc.get("aliases", []), } logger.info(f"成功获取集合 {name} 的详细信息") return detail def update_collection_description(self, name: str, description: str) -> Dict[str, Any]: """使用 alter_collection_properties 更新集合描述""" description = description or "" # 1. 更新集合 description(唯一修改点) self.client.alter_collection_properties( collection_name=name, properties={"collection.description": description}, ) # 2. 重新获取集合信息 desc = self.client.describe_collection(collection_name=name) print(desc) stats = self.client.get_collection_stats(collection_name=name) load_state = self.client.get_load_state(collection_name=name) # 3. 时间戳转换(Milvus TSO -> 物理时间) def ts_to_str(ts): if ts is None: return None ts_int = int(ts) physical_ms = ts_int >> 18 return datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S") created_time = ts_to_str(desc.get("created_timestamp")) updated_time = ts_to_str(desc.get("update_timestamp")) entity_count = stats.get("row_count") if isinstance(stats, dict) else None status = load_state.get("state") if isinstance(load_state, dict) else load_state return { "name": name, "status": status, "entity_count": entity_count, "description": desc.get("description", ""), "created_time": created_time, "updated_time": updated_time, } def hybrid_search(self, collection_name: str, query_text: str, top_k: int = 3, ranker_type: str = "weighted", dense_weight: float = 0.7, sparse_weight: float = 0.3, expr: str = None): """ 混合搜索(参考 test_hybrid_v2.6.py 的实现) Args: param: 包含collection_name的参数字典 query_text: 查询文本 top_k: 返回结果数量 ranker_type: 重排序类型 "weighted" 或 "rrf" dense_weight: 密集向量权重(当ranker_type="weighted"时使用) sparse_weight: 稀疏向量权重(当ranker_type="weighted"时使用) expr: 过滤表达式 (Metadata Filtering) Returns: List[Dict]: 搜索结果列表 """ try: collection_name = collection_name # 确保集合已加载 self.client.load_collection(collection_name) # 获取 vectorstore 实例(包含 Milvus 和 BM25BuiltInFunction) vectorstore = get_milvus_vectorstore( collection_name=collection_name, consistency_level="Strong" ) # 执行混合搜索 (完全按照 test_hybrid_v2.6.py 的逻辑) # 注意:LangChain Milvus vectorstore 的 similarity_search 支持 expr 参数用于过滤 if ranker_type == "weighted": results = vectorstore.similarity_search( query=query_text, k=top_k, expr=expr, ranker_type="weighted", ranker_params={"weights": [dense_weight, sparse_weight]} ) else: # rrf results = vectorstore.similarity_search( query=query_text, k=top_k, expr=expr, ranker_type="rrf", ranker_params={"k": 60} ) # 格式化结果,保持与其他搜索方法一致 formatted_results = [] for doc in results: formatted_results.append({ 'id': doc.metadata.get('pk', 0), 'text_content': doc.page_content, 'metadata': doc.metadata, 'distance': 0.0, 'similarity': 1.0 }) logger.info(f"Hybrid search returned {len(formatted_results)} results") return formatted_results except Exception as e: logger.error(f"Error in hybrid search: {e}") # 回退到传统的向量搜索 logger.info("Falling back to traditional vector search") return [] # 可选:单例 milvus_service = MilvusService() if __name__ == "__main__": # 推荐这样跑: # uv run python -m src.app.services.milvus_service import json service = MilvusService() # 测试混合搜索 hybrid_search print("=" * 50) print("测试混合检索 (Hybrid Search)") print("=" * 50) try: # 示例参数,需要根据实际情况修改 collection_name = "first_bfp_collection_status" query_text = "《公路水运工程临时用电技术规程》(JTT1499-2024)状态为现行" # 修改为实际查询内容 # 测试 weighted 模式 print("\n1. 测试 Weighted 重排序模式:") print(f" 集合: {collection_name}") print(f" 查询: {query_text}") print(f" 密集权重: 0.7, 稀疏权重: 0.3") results_weighted = service.hybrid_search( collection_name=collection_name, query_text=query_text, top_k=5, ranker_type="weighted", dense_weight=0.7, sparse_weight=0.3 ) print(f"\n 结果数量: {len(results_weighted)}") for i, result in enumerate(results_weighted, 1): print(f" [{i}] ID: {result.get('id')}, Text: {result.get('text_content')[:50]}...") # 测试 RRF 模式 print("\n2. 测试 RRF (Reciprocal Rank Fusion) 重排序模式:") print(f" 集合: {collection_name}") print(f" 查询: {query_text}") results_rrf = service.hybrid_search( collection_name=collection_name, query_text=query_text, top_k=5, ranker_type="rrf" ) print(f"\n 结果数量: {len(results_rrf)}") for i, result in enumerate(results_rrf, 1): print(f" [{i}] ID: {result.get('id')}, Text: {result.get('text_content')[:50]}...") print("\n✓ 混合检索测试完成") except Exception as e: print(f"\n✗ 混合检索测试失败: {e}") import traceback traceback.print_exc() # 也可以查看集合详情 print("\n" + "=" * 50) print("获取所有集合信息:") print("=" * 50) data = service.get_collection_details() for item in data: print(json.dumps(item, ensure_ascii=False, indent=2))