| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305 |
- import time
- import re
- import hashlib
- import logging
- import json
- from typing import List, Dict, Any, Tuple, Optional
- from langchain_core.documents import Document
- from langchain_openai import OpenAIEmbeddings
- from pymilvus import MilvusClient, DataType, Function, FunctionType
- from app.core.config import config_handler
- logger = logging.getLogger(__name__)
- class MilvusService:
- """Milvus 向量库服务类,实现父子块切分与混合检索存储"""
-
- def __init__(self, uri: str, db_name: str, parent_collection: str, child_collection: str):
- self.client = MilvusClient(uri=uri, db_name=db_name)
- self.parent_collection = parent_collection
- self.child_collection = child_collection
- self.emb = self._get_embeddings()
-
- # 配置参数
- self.PARENT_MAX_CHARS = 6000
- self.DENSE_DIM = 4096
- self.H1_RE = re.compile(r"^#\s+(.+?)\s*$", re.MULTILINE)
- self.BLANK_SPLIT_RE = re.compile(r"\n\s*\n+")
-
- # 确保集合已创建
- self.ensure_collections()
- def has_collection(self, collection_name: str) -> bool:
- """检查集合是否存在"""
- return self.client.has_collection(collection_name=collection_name)
- def _get_embeddings(self) -> OpenAIEmbeddings:
- """获取 Embedding 模型配置"""
- return OpenAIEmbeddings(
- base_url=config_handler.get("admin_app", "EMBEDDING_BASE_URL", "http://192.168.91.253:9003/v1"),
- model=config_handler.get("admin_app", "EMBEDDING_MODEL", "Qwen3-Embedding-8B"),
- api_key=config_handler.get("admin_app", "EMBEDDING_API_KEY", "dummy"),
- )
- # --- 切分工具方法 ---
-
- def _split_md_by_blank_lines(self, md: str) -> List[str]:
- md = md.replace("\r\n", "\n").replace("\r", "\n")
- parts = self.BLANK_SPLIT_RE.split(md)
- return [p.strip() for p in parts if p.strip()]
- def _is_heading_chunk(self, chunk: str) -> Optional[Tuple[int, str]]:
- first_line = chunk.split("\n", 1)[0].strip()
- m = re.match(r"^(#{1,6})\s+(.+?)\s*$", first_line)
- if not m:
- return None
- return len(m.group(1)), m.group(2).strip()
- def _split_md_by_h1_sections(self, md: str) -> List[Tuple[str, str]]:
- """按一级标题切分父块"""
- md = md.replace("\r\n", "\n").replace("\r", "\n")
- matches = list(self.H1_RE.finditer(md))
- if not matches:
- txt = md.strip()
- return [("__NO_H1__", txt)] if txt else []
- sections = []
- # 检查第一个#之前的内容
- first_match_start = matches[0].start()
- preamble = md[:first_match_start].strip()
- if preamble:
- sections.append(("__PREAMBLE__", preamble))
-
- for i, m in enumerate(matches):
- title = m.group(1).strip()
- start = m.start()
- end = matches[i + 1].start() if i + 1 < len(matches) else len(md)
- sec = md[start:end].strip()
- if sec:
- sections.append((title, sec))
- return sections
- def _make_parent_id(self, doc_id: str, doc_version: int, doc_name: str, h1_title: str, parent_seq: int) -> str:
- """生成稳定的 parent_id"""
- raw = f"{doc_id}|{doc_version}|{doc_name}|{parent_seq}|{h1_title}".encode("utf-8")
- return hashlib.sha1(raw).hexdigest()
- def _split_text_by_max_chars(self, text: str, max_chars: int) -> List[str]:
- """父段过长时切片"""
- text = (text or "").strip()
- if not text or len(text) <= max_chars:
- return [text] if text else []
- chunks = self._split_md_by_blank_lines(text)
- result = []
- current_slice = ""
-
- for chunk in chunks:
- if len(chunk) > max_chars:
- if current_slice.strip():
- result.append(current_slice.strip())
- current_slice = ""
- start = 0
- while start < len(chunk):
- result.append(chunk[start : start + max_chars].strip())
- start += max_chars
- else:
- test_slice = current_slice + "\n\n" + chunk if current_slice else chunk
- if len(test_slice) <= max_chars:
- current_slice = test_slice
- else:
- if current_slice.strip():
- result.append(current_slice.strip())
- current_slice = chunk
-
- if current_slice.strip():
- result.append(current_slice.strip())
- return [s for s in result if s]
- # --- 核心业务逻辑 ---
- def ensure_collections(self):
- """确保父子 Collection 已创建并配置索引"""
- for col_name in [self.parent_collection, self.child_collection]:
- if not self.client.has_collection(collection_name=col_name):
- schema = self.client.create_schema(auto_id=True, enable_dynamic_fields=False)
- schema.add_field("pk", DataType.INT64, is_primary=True, auto_id=True)
- schema.add_field("text", DataType.VARCHAR, max_length=65535, enable_analyzer=True)
- schema.add_field("dense", DataType.FLOAT_VECTOR, dim=self.DENSE_DIM)
- schema.add_field("sparse", DataType.SPARSE_FLOAT_VECTOR)
- schema.add_field("document_id", DataType.VARCHAR, max_length=256)
- schema.add_field("parent_id", DataType.VARCHAR, max_length=256)
- schema.add_field("index", DataType.INT64)
- schema.add_field("tag_list", DataType.VARCHAR, max_length=2048)
- schema.add_field("permission", DataType.JSON)
- schema.add_field("metadata", DataType.JSON)
- schema.add_field("is_deleted", DataType.INT64)
- schema.add_field("created_by", DataType.VARCHAR, max_length=256)
- schema.add_field("created_time", DataType.INT64)
- schema.add_field("updated_by", DataType.VARCHAR, max_length=256)
- schema.add_field("updated_time", DataType.INT64)
- schema.add_function(Function(
- name="bm25_fn",
- input_field_names=["text"],
- output_field_names=["sparse"],
- function_type=FunctionType.BM25,
- ))
-
- self.client.create_collection(collection_name=col_name, schema=schema)
-
- index_params = self.client.prepare_index_params()
- index_params.add_index(field_name="dense", index_name="dense_idx", index_type="AUTOINDEX", metric_type="COSINE")
- index_params.add_index(field_name="sparse", index_name="bm25_idx", index_type="SPARSE_INVERTED_INDEX", metric_type="BM25", params={"inverted_index_algo": "DAAT_MAXSCORE"})
- self.client.create_index(collection_name=col_name, index_params=index_params)
-
- self.client.load_collection(collection_name=col_name)
- async def insert_knowledge(self, md_text: str, doc_info: Dict[str, Any]):
- """执行切分、向量化并存入 Milvus"""
- doc_id = doc_info['doc_id']
- doc_name = doc_info.get('doc_name', 'unknown')
- doc_version = doc_info.get('doc_version', 20260127)
- tag_list = str(doc_info.get('tags') or '')
-
- # 公共字段准备
- created_by = doc_info.get('created_by', 'system')
- created_time = doc_info.get('created_time', int(time.time() * 1000))
- updated_by = doc_info.get('updated_by', 'system')
- updated_time = doc_info.get('updated_time', int(time.time() * 1000))
- permission = doc_info.get('permission', {})
- try:
- # 1. 幂等处理:清理旧数据
- try:
- self.client.delete(collection_name=self.parent_collection, filter=f"document_id == '{doc_id}'")
- self.client.delete(collection_name=self.child_collection, filter=f"document_id == '{doc_id}'")
- except Exception as e:
- logger.warning(f"清理旧数据失败 (doc_id: {doc_id}): {e}")
- # 继续执行,可能是第一次入库
- # 2. 切分父子块
- try:
- parent_sections = self._split_md_by_h1_sections(md_text)
- parent_entities = []
- child_entities = []
-
- # 预生成所有 parent_id
- parent_seq_to_id = {}
- for seq, (title, _) in enumerate(parent_sections):
- parent_seq_to_id[seq] = self._make_parent_id(doc_id, doc_version, doc_name, title, seq)
- # 3. 处理子块
- for seq, (h1_title, sec_text) in enumerate(parent_sections):
- p_id = parent_seq_to_id[seq]
- chunks = self._split_md_by_blank_lines(sec_text)
- heading_path = []
-
- for c_idx, chunk in enumerate(chunks):
- h_info = self._is_heading_chunk(chunk)
- if h_info:
- level, title = h_info
- heading_path = heading_path[:level-1] + [title]
-
- outline_path = " > ".join(heading_path)
-
- child_entities.append({
- "text": chunk,
- "is_deleted": 0,
- "parent_id": p_id,
- "document_id": doc_id,
- "index": int(c_idx),
- "tag_list": tag_list,
- "permission": permission,
- "metadata": {
- "doc_name": doc_name,
- "outline_path": outline_path,
- "doc_version": doc_version
- },
- "created_by": created_by,
- "created_time": created_time,
- "updated_by": updated_by,
- "updated_time": updated_time
- })
- # 4. 处理父块
- for seq, (h1_title, sec_text) in enumerate(parent_sections):
- p_id = parent_seq_to_id[seq]
- slices = self._split_text_by_max_chars(sec_text, self.PARENT_MAX_CHARS)
- for s_idx, slice_text in enumerate(slices):
- parent_entities.append({
- "text": slice_text,
- "is_deleted": 0,
- "parent_id": p_id,
- "document_id": doc_id,
- "index": int(seq),
- "tag_list": tag_list,
- "permission": permission,
- "metadata": {
- "doc_name": doc_name,
- "outline_path": h1_title if h1_title not in ["__PREAMBLE__", "__NO_H1__"] else doc_name,
- "doc_version": doc_version
- },
- "created_by": created_by,
- "created_time": created_time,
- "updated_by": updated_by,
- "updated_time": updated_time
- })
- except Exception as e:
- logger.error(f"文档切分失败 (doc_id: {doc_id}): {e}")
- raise RuntimeError(f"文档切分处理异常: {str(e)}")
- # 5. 向量化并插入
- # 处理父块
- if parent_entities:
- try:
- p_texts = [e['text'] for e in parent_entities]
- p_vecs = self.emb.embed_documents(p_texts)
- for e, v in zip(parent_entities, p_vecs): e['dense'] = v
- except Exception as e:
- logger.error(f"父块向量化失败 (Embedding Service): {e}")
- raise RuntimeError(f"Embedding 服务调用失败: {str(e)}")
-
- try:
- self.client.insert(collection_name=self.parent_collection, data=parent_entities)
- except Exception as e:
- logger.error(f"父块存入 Milvus 失败: {e}")
- raise RuntimeError(f"向量数据库写入失败(Parent): {str(e)}")
-
- # 处理子块
- if child_entities:
- try:
- c_texts = [e['text'] for e in child_entities]
- c_vecs = self.emb.embed_documents(c_texts)
- for e, v in zip(child_entities, c_vecs): e['dense'] = v
- except Exception as e:
- logger.error(f"子块向量化失败 (Embedding Service): {e}")
- raise RuntimeError(f"Embedding 服务调用失败: {str(e)}")
-
- try:
- self.client.insert(collection_name=self.child_collection, data=child_entities)
- except Exception as e:
- logger.error(f"子块存入 Milvus 失败: {e}")
- raise RuntimeError(f"向量数据库写入失败(Child): {str(e)}")
- logger.info(f"Successfully entered knowledge base for doc_id: {doc_id}, parents: {len(parent_entities)}, children: {len(child_entities)}")
- return len(parent_entities), len(child_entities)
-
- except Exception as e:
- # 重新抛出已处理的异常或包装未处理的异常
- if not isinstance(e, RuntimeError):
- logger.exception(f"入库流程发生未知异常 (doc_id: {doc_id})")
- raise RuntimeError(f"入库未知错误: {str(e)}")
- raise e
# Global Milvus service singleton.
# Connection details and collection names come from the admin_app config section.
milvus_host = config_handler.get("admin_app", "MILVUS_HOST", "192.168.92.61")
milvus_port = config_handler.get("admin_app", "MILVUS_PORT", "19530")
_db_name = config_handler.get("admin_app", "MILVUS_DB", "lq_db")
_parent_col = config_handler.get("admin_app", "PARENT_COLLECTION_NAME", "test_27_parent")
_child_col = config_handler.get("admin_app", "CHILD_COLLECTION_NAME", "test_27_child")
milvus_service = MilvusService(
    uri=f"http://{milvus_host}:{milvus_port}",
    db_name=_db_name,
    parent_collection=_parent_col,
    child_collection=_child_col,
)
|