|
|
@@ -1,305 +1,553 @@
|
|
|
+"""
|
|
|
+Milvus Service:业务层(直接用 manager.client 调 Milvus 原生方法)
|
|
|
+"""
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+import sys
|
|
|
+import os
|
|
|
+
|
|
|
+# 添加src目录到Python路径
|
|
|
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../..'))
|
|
|
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '../../..'))
|
|
|
|
|
|
-import time
|
|
|
-import re
|
|
|
-import hashlib
|
|
|
import logging
|
|
|
-import json
|
|
|
-from typing import List, Dict, Any, Tuple, Optional
|
|
|
-from langchain_core.documents import Document
|
|
|
-from langchain_openai import OpenAIEmbeddings
|
|
|
-from pymilvus import MilvusClient, DataType, Function, FunctionType
|
|
|
+from typing import List, Dict, Any
|
|
|
+from datetime import datetime
|
|
|
|
|
|
-from app.core.config import config_handler
|
|
|
+from app.base import get_milvus_manager, get_milvus_vectorstore, get_embedding_model
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
+
|
|
|
class MilvusService:
    """Business layer for Milvus: drives native client calls through the shared manager."""

    def __init__(self):
        # Reuse the process-wide Milvus client owned by the manager singleton.
        self.client = get_milvus_manager().client
        # Embedding model handle for callers that need dense vectors.
        self.emdmodel = get_embedding_model()
        # Default dense-vector dimension (Qwen3-Embedding-8B default).
        self.DENSE_DIM = 4096
|
|
|
|
|
|
- def has_collection(self, collection_name: str) -> bool:
|
|
|
- """检查集合是否存在"""
|
|
|
- return self.client.has_collection(collection_name=collection_name)
|
|
|
-
|
|
|
- def _get_embeddings(self) -> OpenAIEmbeddings:
|
|
|
- """获取 Embedding 模型配置"""
|
|
|
- return OpenAIEmbeddings(
|
|
|
- base_url=config_handler.get("admin_app", "EMBEDDING_BASE_URL", "http://192.168.91.253:9003/v1"),
|
|
|
- model=config_handler.get("admin_app", "EMBEDDING_MODEL", "Qwen3-Embedding-8B"),
|
|
|
- api_key=config_handler.get("admin_app", "EMBEDDING_API_KEY", "dummy"),
|
|
|
- )
|
|
|
+ def create_collection(self, name: str, dimension: int = None, description: str = "", fields: List[Dict] = None) -> None:
|
|
|
+ """
|
|
|
+ 创建 Milvus 集合
|
|
|
+ :param dimension: 向量维度,如果为None则使用默认值
|
|
|
+ :param fields: 自定义字段列表,每个元素为 {"name": "age", "type": "INT64", ...}
|
|
|
+ """
|
|
|
+ # 使用默认维度
|
|
|
+ if dimension is None:
|
|
|
+ dimension = self.DENSE_DIM
|
|
|
|
|
|
- # --- 切分工具方法 ---
|
|
|
-
|
|
|
- def _split_md_by_blank_lines(self, md: str) -> List[str]:
|
|
|
- md = md.replace("\r\n", "\n").replace("\r", "\n")
|
|
|
- parts = self.BLANK_SPLIT_RE.split(md)
|
|
|
- return [p.strip() for p in parts if p.strip()]
|
|
|
-
|
|
|
- def _is_heading_chunk(self, chunk: str) -> Optional[Tuple[int, str]]:
|
|
|
- first_line = chunk.split("\n", 1)[0].strip()
|
|
|
- m = re.match(r"^(#{1,6})\s+(.+?)\s*$", first_line)
|
|
|
- if not m:
|
|
|
- return None
|
|
|
- return len(m.group(1)), m.group(2).strip()
|
|
|
-
|
|
|
- def _split_md_by_h1_sections(self, md: str) -> List[Tuple[str, str]]:
|
|
|
- """按一级标题切分父块"""
|
|
|
- md = md.replace("\r\n", "\n").replace("\r", "\n")
|
|
|
- matches = list(self.H1_RE.finditer(md))
|
|
|
- if not matches:
|
|
|
- txt = md.strip()
|
|
|
- return [("__NO_H1__", txt)] if txt else []
|
|
|
-
|
|
|
- sections = []
|
|
|
- # 检查第一个#之前的内容
|
|
|
- first_match_start = matches[0].start()
|
|
|
- preamble = md[:first_match_start].strip()
|
|
|
- if preamble:
|
|
|
- sections.append(("__PREAMBLE__", preamble))
|
|
|
-
|
|
|
- for i, m in enumerate(matches):
|
|
|
- title = m.group(1).strip()
|
|
|
- start = m.start()
|
|
|
- end = matches[i + 1].start() if i + 1 < len(matches) else len(md)
|
|
|
- sec = md[start:end].strip()
|
|
|
- if sec:
|
|
|
- sections.append((title, sec))
|
|
|
- return sections
|
|
|
-
|
|
|
- def _make_parent_id(self, doc_id: str, doc_version: int, doc_name: str, h1_title: str, parent_seq: int) -> str:
|
|
|
- """生成稳定的 parent_id"""
|
|
|
- raw = f"{doc_id}|{doc_version}|{doc_name}|{parent_seq}|{h1_title}".encode("utf-8")
|
|
|
- return hashlib.sha1(raw).hexdigest()
|
|
|
-
|
|
|
- def _split_text_by_max_chars(self, text: str, max_chars: int) -> List[str]:
|
|
|
- """父段过长时切片"""
|
|
|
- text = (text or "").strip()
|
|
|
- if not text or len(text) <= max_chars:
|
|
|
- return [text] if text else []
|
|
|
-
|
|
|
- chunks = self._split_md_by_blank_lines(text)
|
|
|
- result = []
|
|
|
- current_slice = ""
|
|
|
-
|
|
|
- for chunk in chunks:
|
|
|
- if len(chunk) > max_chars:
|
|
|
- if current_slice.strip():
|
|
|
- result.append(current_slice.strip())
|
|
|
- current_slice = ""
|
|
|
- start = 0
|
|
|
- while start < len(chunk):
|
|
|
- result.append(chunk[start : start + max_chars].strip())
|
|
|
- start += max_chars
|
|
|
- else:
|
|
|
- test_slice = current_slice + "\n\n" + chunk if current_slice else chunk
|
|
|
- if len(test_slice) <= max_chars:
|
|
|
- current_slice = test_slice
|
|
|
- else:
|
|
|
- if current_slice.strip():
|
|
|
- result.append(current_slice.strip())
|
|
|
- current_slice = chunk
|
|
|
+ if self.client.has_collection(name):
|
|
|
+ logger.info(f"Collection {name} already exists.")
|
|
|
+ return
|
|
|
|
|
|
- if current_slice.strip():
|
|
|
- result.append(current_slice.strip())
|
|
|
- return [s for s in result if s]
|
|
|
-
|
|
|
- # --- 核心业务逻辑 ---
|
|
|
-
|
|
|
- def ensure_collections(self):
|
|
|
- """确保父子 Collection 已创建并配置索引"""
|
|
|
- for col_name in [self.parent_collection, self.child_collection]:
|
|
|
- if not self.client.has_collection(collection_name=col_name):
|
|
|
- schema = self.client.create_schema(auto_id=True, enable_dynamic_fields=False)
|
|
|
- schema.add_field("pk", DataType.INT64, is_primary=True, auto_id=True)
|
|
|
- schema.add_field("text", DataType.VARCHAR, max_length=65535, enable_analyzer=True)
|
|
|
- schema.add_field("dense", DataType.FLOAT_VECTOR, dim=self.DENSE_DIM)
|
|
|
- schema.add_field("sparse", DataType.SPARSE_FLOAT_VECTOR)
|
|
|
- schema.add_field("document_id", DataType.VARCHAR, max_length=256)
|
|
|
- schema.add_field("parent_id", DataType.VARCHAR, max_length=256)
|
|
|
- schema.add_field("index", DataType.INT64)
|
|
|
- schema.add_field("tag_list", DataType.VARCHAR, max_length=2048)
|
|
|
- schema.add_field("permission", DataType.JSON)
|
|
|
- schema.add_field("metadata", DataType.JSON)
|
|
|
- schema.add_field("is_deleted", DataType.INT64)
|
|
|
- schema.add_field("created_by", DataType.VARCHAR, max_length=256)
|
|
|
- schema.add_field("created_time", DataType.INT64)
|
|
|
- schema.add_field("updated_by", DataType.VARCHAR, max_length=256)
|
|
|
- schema.add_field("updated_time", DataType.INT64)
|
|
|
-
|
|
|
- schema.add_function(Function(
|
|
|
- name="bm25_fn",
|
|
|
- input_field_names=["text"],
|
|
|
- output_field_names=["sparse"],
|
|
|
- function_type=FunctionType.BM25,
|
|
|
- ))
|
|
|
-
|
|
|
- self.client.create_collection(collection_name=col_name, schema=schema)
|
|
|
-
|
|
|
- index_params = self.client.prepare_index_params()
|
|
|
- index_params.add_index(field_name="dense", index_name="dense_idx", index_type="AUTOINDEX", metric_type="COSINE")
|
|
|
- index_params.add_index(field_name="sparse", index_name="bm25_idx", index_type="SPARSE_INVERTED_INDEX", metric_type="BM25", params={"inverted_index_algo": "DAAT_MAXSCORE"})
|
|
|
- self.client.create_index(collection_name=col_name, index_params=index_params)
|
|
|
+ # 如果有自定义字段,使用 schema 创建
|
|
|
+ if fields:
|
|
|
+ from pymilvus import MilvusClient, DataType, Function, FunctionType
|
|
|
|
|
|
- self.client.load_collection(collection_name=col_name)
|
|
|
-
|
|
|
- async def insert_knowledge(self, md_text: str, doc_info: Dict[str, Any]):
|
|
|
- """执行切分、向量化并存入 Milvus"""
|
|
|
- doc_id = doc_info['doc_id']
|
|
|
- doc_name = doc_info.get('doc_name', 'unknown')
|
|
|
- doc_version = doc_info.get('doc_version', 20260127)
|
|
|
- tag_list = str(doc_info.get('tags') or '')
|
|
|
-
|
|
|
- # 公共字段准备
|
|
|
- created_by = doc_info.get('created_by', 'system')
|
|
|
- created_time = doc_info.get('created_time', int(time.time() * 1000))
|
|
|
- updated_by = doc_info.get('updated_by', 'system')
|
|
|
- updated_time = doc_info.get('updated_time', int(time.time() * 1000))
|
|
|
- permission = doc_info.get('permission', {})
|
|
|
-
|
|
|
- try:
|
|
|
- # 1. 幂等处理:清理旧数据
|
|
|
- try:
|
|
|
- self.client.delete(collection_name=self.parent_collection, filter=f"document_id == '{doc_id}'")
|
|
|
- self.client.delete(collection_name=self.child_collection, filter=f"document_id == '{doc_id}'")
|
|
|
- except Exception as e:
|
|
|
- logger.warning(f"清理旧数据失败 (doc_id: {doc_id}): {e}")
|
|
|
- # 继续执行,可能是第一次入库
|
|
|
+ # 1. 创建 Schema
|
|
|
+ schema = MilvusClient.create_schema(
|
|
|
+ auto_id=True,
|
|
|
+ enable_dynamic_field=True,
|
|
|
+ description=description
|
|
|
+ )
|
|
|
+
|
|
|
+ # 检查字段中是否定义了主键
|
|
|
+ has_primary = any(f.get("is_primary") for f in fields)
|
|
|
+ if not has_primary:
|
|
|
+ # 如果没有定义主键,添加默认主键
|
|
|
+ schema.add_field(field_name="id", datatype=DataType.INT64, is_primary=True, auto_id=True)
|
|
|
+
|
|
|
+ # 检查是否有默认向量列,如果没有则添加 (兼容旧逻辑,但如果fields里有vector则不添加)
|
|
|
+ has_vector = any(f.get("type") == "FLOAT_VECTOR" for f in fields)
|
|
|
+ if not has_vector:
|
|
|
+ schema.add_field(field_name="vector", datatype=DataType.FLOAT_VECTOR, dim=dimension)
|
|
|
+
|
|
|
+ # 3. 添加用户自定义字段
|
|
|
+ type_map = {
|
|
|
+ "BOOL": DataType.BOOL,
|
|
|
+ "INT8": DataType.INT8,
|
|
|
+ "INT16": DataType.INT16,
|
|
|
+ "INT32": DataType.INT32,
|
|
|
+ "INT64": DataType.INT64,
|
|
|
+ "FLOAT": DataType.FLOAT,
|
|
|
+ "DOUBLE": DataType.DOUBLE,
|
|
|
+ "VARCHAR": DataType.VARCHAR,
|
|
|
+ "JSON": DataType.JSON,
|
|
|
+ "FLOAT_VECTOR": DataType.FLOAT_VECTOR,
|
|
|
+ "SPARSE_FLOAT_VECTOR": DataType.SPARSE_FLOAT_VECTOR,
|
|
|
+ "BM25": DataType.SPARSE_FLOAT_VECTOR # BM25 特殊处理,映射为稀疏向量
|
|
|
+ }
|
|
|
+
|
|
|
+ bm25_field = None
|
|
|
+ text_field_name = "text" # 默认文本字段名
|
|
|
|
|
|
- # 2. 切分父子块
|
|
|
- try:
|
|
|
- parent_sections = self._split_md_by_h1_sections(md_text)
|
|
|
- parent_entities = []
|
|
|
- child_entities = []
|
|
|
+ for f in fields:
|
|
|
+ field_type_str = f.get("type", "").upper()
|
|
|
+ dtype = type_map.get(field_type_str)
|
|
|
+ if not dtype:
|
|
|
+ continue
|
|
|
|
|
|
- # 预生成所有 parent_id
|
|
|
- parent_seq_to_id = {}
|
|
|
- for seq, (title, _) in enumerate(parent_sections):
|
|
|
- parent_seq_to_id[seq] = self._make_parent_id(doc_id, doc_version, doc_name, title, seq)
|
|
|
-
|
|
|
- # 3. 处理子块
|
|
|
- for seq, (h1_title, sec_text) in enumerate(parent_sections):
|
|
|
- p_id = parent_seq_to_id[seq]
|
|
|
- chunks = self._split_md_by_blank_lines(sec_text)
|
|
|
- heading_path = []
|
|
|
-
|
|
|
- for c_idx, chunk in enumerate(chunks):
|
|
|
- h_info = self._is_heading_chunk(chunk)
|
|
|
- if h_info:
|
|
|
- level, title = h_info
|
|
|
- heading_path = heading_path[:level-1] + [title]
|
|
|
-
|
|
|
- outline_path = " > ".join(heading_path)
|
|
|
-
|
|
|
- child_entities.append({
|
|
|
- "text": chunk,
|
|
|
- "is_deleted": 0,
|
|
|
- "parent_id": p_id,
|
|
|
- "document_id": doc_id,
|
|
|
- "index": int(c_idx),
|
|
|
- "tag_list": tag_list,
|
|
|
- "permission": permission,
|
|
|
- "metadata": {
|
|
|
- "doc_name": doc_name,
|
|
|
- "outline_path": outline_path,
|
|
|
- "doc_version": doc_version
|
|
|
- },
|
|
|
- "created_by": created_by,
|
|
|
- "created_time": created_time,
|
|
|
- "updated_by": updated_by,
|
|
|
- "updated_time": updated_time
|
|
|
- })
|
|
|
-
|
|
|
- # 4. 处理父块
|
|
|
- for seq, (h1_title, sec_text) in enumerate(parent_sections):
|
|
|
- p_id = parent_seq_to_id[seq]
|
|
|
- slices = self._split_text_by_max_chars(sec_text, self.PARENT_MAX_CHARS)
|
|
|
- for s_idx, slice_text in enumerate(slices):
|
|
|
- parent_entities.append({
|
|
|
- "text": slice_text,
|
|
|
- "is_deleted": 0,
|
|
|
- "parent_id": p_id,
|
|
|
- "document_id": doc_id,
|
|
|
- "index": int(seq),
|
|
|
- "tag_list": tag_list,
|
|
|
- "permission": permission,
|
|
|
- "metadata": {
|
|
|
- "doc_name": doc_name,
|
|
|
- "outline_path": h1_title if h1_title not in ["__PREAMBLE__", "__NO_H1__"] else doc_name,
|
|
|
- "doc_version": doc_version
|
|
|
- },
|
|
|
- "created_by": created_by,
|
|
|
- "created_time": created_time,
|
|
|
- "updated_by": updated_by,
|
|
|
- "updated_time": updated_time
|
|
|
- })
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"文档切分失败 (doc_id: {doc_id}): {e}")
|
|
|
- raise RuntimeError(f"文档切分处理异常: {str(e)}")
|
|
|
+ # 记录文本字段名,供BM25使用
|
|
|
+ if f.get("name") in ["text", "content", "chunk"]:
|
|
|
+ text_field_name = f.get("name")
|
|
|
|
|
|
- # 5. 向量化并插入
|
|
|
- # 处理父块
|
|
|
- if parent_entities:
|
|
|
- try:
|
|
|
- p_texts = [e['text'] for e in parent_entities]
|
|
|
- p_vecs = self.emb.embed_documents(p_texts)
|
|
|
- for e, v in zip(parent_entities, p_vecs): e['dense'] = v
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"父块向量化失败 (Embedding Service): {e}")
|
|
|
- raise RuntimeError(f"Embedding 服务调用失败: {str(e)}")
|
|
|
+ kwargs = {
|
|
|
+ "field_name": f.get("name"),
|
|
|
+ "datatype": dtype,
|
|
|
+ "description": f.get("description", "")
|
|
|
+ }
|
|
|
|
|
|
- try:
|
|
|
- self.client.insert(collection_name=self.parent_collection, data=parent_entities)
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"父块存入 Milvus 失败: {e}")
|
|
|
- raise RuntimeError(f"向量数据库写入失败(Parent): {str(e)}")
|
|
|
+ if f.get("is_primary"):
|
|
|
+ kwargs["is_primary"] = True
|
|
|
+ kwargs["auto_id"] = True # 假设主键都是自增
|
|
|
|
|
|
- # 处理子块
|
|
|
- if child_entities:
|
|
|
- try:
|
|
|
- c_texts = [e['text'] for e in child_entities]
|
|
|
- c_vecs = self.emb.embed_documents(c_texts)
|
|
|
- for e, v in zip(child_entities, c_vecs): e['dense'] = v
|
|
|
- except Exception as e:
|
|
|
- logger.error(f"子块向量化失败 (Embedding Service): {e}")
|
|
|
- raise RuntimeError(f"Embedding 服务调用失败: {str(e)}")
|
|
|
+ if dtype == DataType.VARCHAR:
|
|
|
+ kwargs["max_length"] = f.get("max_length", 65535)
|
|
|
+ # 关键修复:如果要被 BM25 引用,必须启用 analyzer
|
|
|
+ if f.get("name") in ["text", "content", "chunk"]:
|
|
|
+ kwargs["enable_analyzer"] = True
|
|
|
+
|
|
|
+ if dtype == DataType.FLOAT_VECTOR:
|
|
|
+ kwargs["dim"] = dimension # 使用传入的 dimension
|
|
|
|
|
|
+ schema.add_field(**kwargs)
|
|
|
+
|
|
|
+ # 如果是 BM25 类型,记录下来以便后续添加 Function
|
|
|
+ if field_type_str == "BM25":
|
|
|
+ bm25_field = f.get("name")
|
|
|
+
|
|
|
+ # 处理 BM25 Function
|
|
|
+ if bm25_field:
|
|
|
try:
|
|
|
- self.client.insert(collection_name=self.child_collection, data=child_entities)
|
|
|
+ schema.add_function(Function(
|
|
|
+ name="bm25_fn",
|
|
|
+ input_field_names=[text_field_name],
|
|
|
+ output_field_names=[bm25_field],
|
|
|
+ function_type=FunctionType.BM25
|
|
|
+ ))
|
|
|
+ logger.info(f"Added BM25 function mapping {text_field_name} -> {bm25_field}")
|
|
|
except Exception as e:
|
|
|
- logger.error(f"子块存入 Milvus 失败: {e}")
|
|
|
- raise RuntimeError(f"向量数据库写入失败(Child): {str(e)}")
|
|
|
+ logger.error(f"Failed to add BM25 function: {e}")
|
|
|
|
|
|
- logger.info(f"Successfully entered knowledge base for doc_id: {doc_id}, parents: {len(parent_entities)}, children: {len(child_entities)}")
|
|
|
- return len(parent_entities), len(child_entities)
|
|
|
+ # 4. 准备索引参数
|
|
|
+ index_params = self.client.prepare_index_params()
|
|
|
+
|
|
|
+ # 5. 为所有向量字段添加索引
|
|
|
+ for f in fields:
|
|
|
+ ftype = f.get("type", "").upper()
|
|
|
+ if ftype == "FLOAT_VECTOR":
|
|
|
+ index_params.add_index(
|
|
|
+ field_name=f.get("name"),
|
|
|
+ index_type="AUTOINDEX",
|
|
|
+ metric_type="COSINE"
|
|
|
+ )
|
|
|
+ elif ftype == "BM25" or ftype == "SPARSE_FLOAT_VECTOR":
|
|
|
+ index_params.add_index(
|
|
|
+ field_name=f.get("name"),
|
|
|
+ index_type="SPARSE_INVERTED_INDEX", # 稀疏向量索引
|
|
|
+ metric_type="BM25"
|
|
|
+ )
|
|
|
|
|
|
+ # 6. 为自定义标量字段添加索引
|
|
|
+ for f in fields:
|
|
|
+ if f.get("type", "").upper() in ["VARCHAR", "INT64", "INT32", "BOOL"] and not f.get("is_primary"):
|
|
|
+ # 排除主键,主键自动索引
|
|
|
+ index_params.add_index(
|
|
|
+ field_name=f.get("name"),
|
|
|
+ index_type="INVERTED"
|
|
|
+ )
|
|
|
+
|
|
|
+ # 7. 创建集合
|
|
|
+ self.client.create_collection(
|
|
|
+ collection_name=name,
|
|
|
+ schema=schema,
|
|
|
+ index_params=index_params
|
|
|
+ )
|
|
|
+
|
|
|
+ else:
|
|
|
+ # 使用简化的 create_collection API
|
|
|
+ self.client.create_collection(
|
|
|
+ collection_name=name,
|
|
|
+ dimension=dimension,
|
|
|
+ description=description,
|
|
|
+ auto_id=True, # 自动生成 ID
|
|
|
+ id_type="int", # ID 类型
|
|
|
+ metric_type="COSINE" # 默认使用余弦相似度
|
|
|
+ )
|
|
|
+
|
|
|
+ logger.info(f"Created collection {name} with dimension {dimension}")
|
|
|
+
|
|
|
+ def drop_collection(self, name: str) -> None:
|
|
|
+ """删除 Milvus 集合"""
|
|
|
+ if self.client.has_collection(name):
|
|
|
+ self.client.drop_collection(name)
|
|
|
+ logger.info(f"Dropped collection {name}")
|
|
|
+
|
|
|
    def has_collection(self, name: str) -> bool:
        """Check whether a collection named *name* exists in Milvus."""
        return self.client.has_collection(name)
|
|
|
+ def get_collection_details(self) -> List[Dict[str, Any]]:
|
|
|
+ """
|
|
|
+ 获取所有 Collections 详细信息
|
|
|
+ """
|
|
|
+ details: List[Dict[str, Any]] = []
|
|
|
+
|
|
|
+ names = self.client.list_collections()
|
|
|
+
|
|
|
+ for name in names:
|
|
|
+ desc = self.client.describe_collection(collection_name=name)
|
|
|
+ stats = self.client.get_collection_stats(collection_name=name)
|
|
|
+ load_state = self.client.get_load_state(collection_name=name)
|
|
|
+
|
|
|
+ # ===== 时间戳转换(按你指定写法,无封装)=====
|
|
|
+ created_time = None
|
|
|
+ updated_time = None
|
|
|
+
|
|
|
+ if desc.get("created_timestamp") is not None:
|
|
|
+ ts_int = int(desc["created_timestamp"])
|
|
|
+ physical_ms = ts_int >> 18
|
|
|
+ created_time = datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S")
|
|
|
+
|
|
|
+ if desc.get("update_timestamp") is not None:
|
|
|
+ ts_int = int(desc["update_timestamp"])
|
|
|
+ physical_ms = ts_int >> 18
|
|
|
+ updated_time = datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S")
|
|
|
+
|
|
|
+ # ===== 数量:不保底(要求返回结构必须有 row_count)=====
|
|
|
+ entity_count = stats["row_count"]
|
|
|
+
|
|
|
+ # ===== 状态:不保底(要求返回结构必须有 state)=====
|
|
|
+ status = load_state["state"]
|
|
|
+
|
|
|
+ details.append(
|
|
|
+ {
|
|
|
+ "name": name,
|
|
|
+ "status": status,
|
|
|
+ "entity_count": entity_count,
|
|
|
+ "description": desc.get("description", ""),
|
|
|
+ "created_time": created_time,
|
|
|
+ "updated_time": updated_time,
|
|
|
+ }
|
|
|
+ )
|
|
|
+
|
|
|
+ logger.info(f"成功获取Collections详细信息,共{len(details)}个")
|
|
|
+ return details
|
|
|
+
|
|
|
+ def set_collection_state(self, name: str, action: str) -> Dict[str, Any]:
|
|
|
+ """
|
|
|
+ 改变指定 Collection 的加载状态。
|
|
|
+
|
|
|
+ 参数:
|
|
|
+ - name: 集合名称
|
|
|
+ - action: 操作,取值 'load' 或 'release'
|
|
|
+
|
|
|
+ 返回:
|
|
|
+ - 包含集合名称和当前状态的字典,例如: {"name": name, "state": "Loaded"}
|
|
|
+ """
|
|
|
+ action_norm = (action or "").strip().lower()
|
|
|
+ if action_norm not in {"load", "release"}:
|
|
|
+ raise ValueError("action 必须为 'load' 或 'release'")
|
|
|
+
|
|
|
+ # 执行加载/释放
|
|
|
+ if action_norm == "load":
|
|
|
+ self.client.load_collection(collection_name=name)
|
|
|
+ else:
|
|
|
+ self.client.release_collection(collection_name=name)
|
|
|
+
|
|
|
+ # 返回最新状态
|
|
|
+ load_state = self.client.get_load_state(collection_name=name)
|
|
|
+ state = load_state.get("state") if isinstance(load_state, dict) else load_state
|
|
|
+ result = {"name": name, "state": state, "action": action_norm}
|
|
|
+ logger.info(f"集合 {name} 状态更新为 {state} (action={action_norm})")
|
|
|
+ return result
|
|
|
+
|
|
|
+ def delete_collection_if_empty(self, name: str) -> Dict[str, Any]:
|
|
|
+ """仅当集合内容为空时删除集合,否则抛出异常"""
|
|
|
+ stats = self.client.get_collection_stats(collection_name=name)
|
|
|
+ row_count = stats.get("row_count") if isinstance(stats, dict) else None
|
|
|
+ if row_count is None:
|
|
|
+ raise ValueError("无法获取集合行数,禁止删除")
|
|
|
+ if int(row_count) > 0:
|
|
|
+ raise ValueError("集合内容不为空,不能删除")
|
|
|
+
|
|
|
+ self.client.drop_collection(collection_name=name)
|
|
|
+ logger.info(f"集合 {name} 已删除")
|
|
|
+ return {"name": name, "deleted": True}
|
|
|
+
|
|
|
+ def get_collection_detail(self, name: str) -> Dict[str, Any]:
|
|
|
+ """获取单个集合的详细信息,包含schema、索引等所有desc字段"""
|
|
|
+ desc = self.client.describe_collection(collection_name=name)
|
|
|
+ stats = self.client.get_collection_stats(collection_name=name)
|
|
|
+ load_state = self.client.get_load_state(collection_name=name)
|
|
|
+
|
|
|
+ # 时间戳转换
|
|
|
+ created_time = None
|
|
|
+ updated_time = None
|
|
|
+
|
|
|
+ if desc.get("created_timestamp") is not None:
|
|
|
+ ts_int = int(desc["created_timestamp"])
|
|
|
+ physical_ms = ts_int >> 18
|
|
|
+ created_time = datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S")
|
|
|
+
|
|
|
+ if desc.get("update_timestamp") is not None:
|
|
|
+ ts_int = int(desc["update_timestamp"])
|
|
|
+ physical_ms = ts_int >> 18
|
|
|
+ updated_time = datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S")
|
|
|
+
|
|
|
+ entity_count = stats.get("row_count", 0)
|
|
|
+ status = load_state.get("state") if isinstance(load_state, dict) else load_state
|
|
|
+
|
|
|
+ # 提取字段schema
|
|
|
+ fields = []
|
|
|
+ if "fields" in desc:
|
|
|
+ for field in desc["fields"]:
|
|
|
+ field_info = {
|
|
|
+ "name": field.get("name"),
|
|
|
+ "type": str(field.get("type")),
|
|
|
+ "description": field.get("description", ""),
|
|
|
+ "is_primary": field.get("is_primary", False),
|
|
|
+ "auto_id": field.get("auto_id"),
|
|
|
+ }
|
|
|
+ # 向量维度
|
|
|
+ if "params" in field and "dim" in field["params"]:
|
|
|
+ field_info["dim"] = field["params"]["dim"]
|
|
|
+ # 字符串长度
|
|
|
+ if "params" in field and "max_length" in field["params"]:
|
|
|
+ field_info["max_length"] = field["params"]["max_length"]
|
|
|
+ # 其他params
|
|
|
+ if "params" in field:
|
|
|
+ field_info["params"] = field["params"]
|
|
|
+ fields.append(field_info)
|
|
|
+
|
|
|
+ # 提取索引信息
|
|
|
+ indices = []
|
|
|
+
|
|
|
+ # 尝试从 describe_collection 结果中获取 (兼容旧逻辑)
|
|
|
+ if "indexes" in desc:
|
|
|
+ for idx in desc["indexes"]:
|
|
|
+ index_info = {
|
|
|
+ "field_name": idx.get("field_name"),
|
|
|
+ "index_name": idx.get("index_name"),
|
|
|
+ "index_type": idx.get("index_type"),
|
|
|
+ "metric_type": idx.get("metric_type"),
|
|
|
+ "params": idx.get("params"),
|
|
|
+ }
|
|
|
+ indices.append(index_info)
|
|
|
+
|
|
|
+ # 如果没有获取到索引信息,尝试主动查询 list_indexes
|
|
|
+ if not indices:
|
|
|
+ try:
|
|
|
+ # 获取索引列表 (通常返回索引名称列表)
|
|
|
+ index_names = self.client.list_indexes(collection_name=name)
|
|
|
+ if index_names:
|
|
|
+ for idx_name in index_names:
|
|
|
+ try:
|
|
|
+ # 获取索引详情
|
|
|
+ idx_desc = self.client.describe_index(collection_name=name, index_name=idx_name)
|
|
|
+ if idx_desc:
|
|
|
+ indices.append({
|
|
|
+ "field_name": idx_desc.get("field_name"),
|
|
|
+ "index_name": idx_desc.get("index_name"),
|
|
|
+ "index_type": idx_desc.get("index_type"),
|
|
|
+ "metric_type": idx_desc.get("metric_type"),
|
|
|
+ "params": idx_desc.get("params"),
|
|
|
+ })
|
|
|
+ except Exception:
|
|
|
+ continue
|
|
|
+ except Exception as e:
|
|
|
+ logger.warning(f"Failed to list/describe indexes for {name}: {e}")
|
|
|
+
|
|
|
+ detail = {
|
|
|
+ "name": name,
|
|
|
+ "description": desc.get("description", ""),
|
|
|
+ "status": status,
|
|
|
+ "entity_count": entity_count,
|
|
|
+ "created_time": created_time,
|
|
|
+ "updated_time": updated_time,
|
|
|
+ "fields": fields,
|
|
|
+ "enable_dynamic_field": desc.get("enable_dynamic_field", False),
|
|
|
+ "consistency_level": desc.get("consistency_level"),
|
|
|
+ "num_shards": desc.get("num_shards"),
|
|
|
+ "num_partitions": desc.get("num_partitions"),
|
|
|
+ "indices": indices,
|
|
|
+ "properties": desc.get("properties"),
|
|
|
+ "aliases": desc.get("aliases", []),
|
|
|
+ }
|
|
|
+
|
|
|
+ logger.info(f"成功获取集合 {name} 的详细信息")
|
|
|
+ return detail
|
|
|
+
|
|
|
+
|
|
|
+ def update_collection_description(self, name: str, description: str) -> Dict[str, Any]:
|
|
|
+ """使用 alter_collection_properties 更新集合描述"""
|
|
|
+ description = description or ""
|
|
|
+
|
|
|
+ # 1. 更新集合 description(唯一修改点)
|
|
|
+ self.client.alter_collection_properties(
|
|
|
+ collection_name=name,
|
|
|
+ properties={"collection.description": description},
|
|
|
+ )
|
|
|
+
|
|
|
+ # 2. 重新获取集合信息
|
|
|
+ desc = self.client.describe_collection(collection_name=name)
|
|
|
+ print(desc)
|
|
|
+ stats = self.client.get_collection_stats(collection_name=name)
|
|
|
+ load_state = self.client.get_load_state(collection_name=name)
|
|
|
+
|
|
|
+ # 3. 时间戳转换(Milvus TSO -> 物理时间)
|
|
|
+ def ts_to_str(ts):
|
|
|
+ if ts is None:
|
|
|
+ return None
|
|
|
+ ts_int = int(ts)
|
|
|
+ physical_ms = ts_int >> 18
|
|
|
+ return datetime.fromtimestamp(physical_ms / 1000).strftime("%Y-%m-%d %H:%M:%S")
|
|
|
+
|
|
|
+ created_time = ts_to_str(desc.get("created_timestamp"))
|
|
|
+ updated_time = ts_to_str(desc.get("update_timestamp"))
|
|
|
+
|
|
|
+ entity_count = stats.get("row_count") if isinstance(stats, dict) else None
|
|
|
+ status = load_state.get("state") if isinstance(load_state, dict) else load_state
|
|
|
+
|
|
|
+ return {
|
|
|
+ "name": name,
|
|
|
+ "status": status,
|
|
|
+ "entity_count": entity_count,
|
|
|
+ "description": desc.get("description", ""),
|
|
|
+ "created_time": created_time,
|
|
|
+ "updated_time": updated_time,
|
|
|
+ }
|
|
|
+
|
|
|
+ def hybrid_search(self, collection_name: str, query_text: str,
|
|
|
+ top_k: int = 3, ranker_type: str = "weighted",
|
|
|
+ dense_weight: float = 0.7, sparse_weight: float = 0.3):
|
|
|
+ """
|
|
|
+ 混合搜索(参考 test_hybrid_v2.6.py 的实现)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ param: 包含collection_name的参数字典
|
|
|
+ query_text: 查询文本
|
|
|
+ top_k: 返回结果数量
|
|
|
+ ranker_type: 重排序类型 "weighted" 或 "rrf"
|
|
|
+ dense_weight: 密集向量权重(当ranker_type="weighted"时使用)
|
|
|
+ sparse_weight: 稀疏向量权重(当ranker_type="weighted"时使用)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ List[Dict]: 搜索结果列表
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ collection_name = collection_name
|
|
|
+
|
|
|
+ # 获取 vectorstore 实例(包含 Milvus 和 BM25BuiltInFunction)
|
|
|
+ vectorstore = get_milvus_vectorstore(
|
|
|
+ collection_name=collection_name,
|
|
|
+ consistency_level="Strong"
|
|
|
+ )
|
|
|
+
|
|
|
+ # 执行混合搜索 (完全按照 test_hybrid_v2.6.py 的逻辑)
|
|
|
+ if ranker_type == "weighted":
|
|
|
+ results = vectorstore.similarity_search(
|
|
|
+ query=query_text,
|
|
|
+ k=top_k,
|
|
|
+ ranker_type="weighted",
|
|
|
+ ranker_params={"weights": [dense_weight, sparse_weight]}
|
|
|
+ )
|
|
|
+ else: # rrf
|
|
|
+ results = vectorstore.similarity_search(
|
|
|
+ query=query_text,
|
|
|
+ k=top_k,
|
|
|
+ ranker_type="rrf",
|
|
|
+ ranker_params={"k": 60}
|
|
|
+ )
|
|
|
+
|
|
|
+ # 格式化结果,保持与其他搜索方法一致
|
|
|
+ formatted_results = []
|
|
|
+ for doc in results:
|
|
|
+ formatted_results.append({
|
|
|
+ 'id': doc.metadata.get('pk', 0),
|
|
|
+ 'text_content': doc.page_content,
|
|
|
+ 'metadata': doc.metadata,
|
|
|
+ 'distance': 0.0,
|
|
|
+ 'similarity': 1.0
|
|
|
+ })
|
|
|
+
|
|
|
+ logger.info(f"Hybrid search returned {len(formatted_results)} results")
|
|
|
+ return formatted_results
|
|
|
+
|
|
|
except Exception as e:
|
|
|
- # 重新抛出已处理的异常或包装未处理的异常
|
|
|
- if not isinstance(e, RuntimeError):
|
|
|
- logger.exception(f"入库流程发生未知异常 (doc_id: {doc_id})")
|
|
|
- raise RuntimeError(f"入库未知错误: {str(e)}")
|
|
|
- raise e
|
|
|
-
|
|
|
-# 全局 Milvus 服务实例
|
|
|
-milvus_host = config_handler.get("admin_app", "MILVUS_HOST", "192.168.92.61")
|
|
|
-milvus_port = config_handler.get("admin_app", "MILVUS_PORT", "19530")
|
|
|
-milvus_service = MilvusService(
|
|
|
- uri=f"http://{milvus_host}:{milvus_port}",
|
|
|
- db_name=config_handler.get("admin_app", "MILVUS_DB", "lq_db"),
|
|
|
- parent_collection=config_handler.get("admin_app", "PARENT_COLLECTION_NAME", "test_27_parent"),
|
|
|
- child_collection=config_handler.get("admin_app", "CHILD_COLLECTION_NAME", "test_27_child")
|
|
|
-)
|
|
|
+ logger.error(f"Error in hybrid search: {e}")
|
|
|
+ # 回退到传统的向量搜索
|
|
|
+ logger.info("Falling back to traditional vector search")
|
|
|
+
|
|
|
+
|
|
|
# Optional module-level singleton (constructed eagerly at import time).
milvus_service = MilvusService()
|
|
|
+
|
|
|
+
|
|
|
if __name__ == "__main__":
    # Recommended invocation:
    # uv run python -m src.app.services.milvus_service
    import json

    service = MilvusService()

    # Exercise hybrid_search against a live Milvus instance.
    print("=" * 50)
    print("测试混合检索 (Hybrid Search)")
    print("=" * 50)

    try:
        # Example parameters; adjust to the actual deployment.
        collection_name = "first_bfp_collection_status"
        query_text = "《公路水运工程临时用电技术规程》(JTT1499-2024)状态为现行"  # replace with a real query

        # Weighted reranking mode.
        print("\n1. 测试 Weighted 重排序模式:")
        print(f"   集合: {collection_name}")
        print(f"   查询: {query_text}")
        print(f"   密集权重: 0.7, 稀疏权重: 0.3")

        results_weighted = service.hybrid_search(
            collection_name=collection_name,
            query_text=query_text,
            top_k=5,
            ranker_type="weighted",
            dense_weight=0.7,
            sparse_weight=0.3
        )

        print(f"\n   结果数量: {len(results_weighted)}")
        for i, result in enumerate(results_weighted, 1):
            print(f"   [{i}] ID: {result.get('id')}, Text: {result.get('text_content')[:50]}...")

        # RRF (Reciprocal Rank Fusion) reranking mode.
        print("\n2. 测试 RRF (Reciprocal Rank Fusion) 重排序模式:")
        print(f"   集合: {collection_name}")
        print(f"   查询: {query_text}")

        results_rrf = service.hybrid_search(
            collection_name=collection_name,
            query_text=query_text,
            top_k=5,
            ranker_type="rrf"
        )

        print(f"\n   结果数量: {len(results_rrf)}")
        for i, result in enumerate(results_rrf, 1):
            print(f"   [{i}] ID: {result.get('id')}, Text: {result.get('text_content')[:50]}...")

        print("\n✓ 混合检索测试完成")

    except Exception as e:
        print(f"\n✗ 混合检索测试失败: {e}")
        import traceback
        traceback.print_exc()

    # Also dump details for all collections.
    print("\n" + "=" * 50)
    print("获取所有集合信息:")
    print("=" * 50)
    data = service.get_collection_details()
    for item in data:
        print(json.dumps(item, ensure_ascii=False, indent=2))
|