# milvus_service.py
  1. import time
  2. import re
  3. import hashlib
  4. import logging
  5. import json
  6. from typing import List, Dict, Any, Tuple, Optional
  7. from langchain_core.documents import Document
  8. from langchain_openai import OpenAIEmbeddings
  9. from pymilvus import MilvusClient, DataType, Function, FunctionType
  10. from app.core.config import config_handler
  11. logger = logging.getLogger(__name__)
  12. class MilvusService:
  13. """Milvus 向量库服务类,实现父子块切分与混合检索存储"""
  14. def __init__(self, uri: str, db_name: str, parent_collection: str, child_collection: str):
  15. self.client = MilvusClient(uri=uri, db_name=db_name)
  16. self.parent_collection = parent_collection
  17. self.child_collection = child_collection
  18. self.emb = self._get_embeddings()
  19. # 配置参数
  20. self.PARENT_MAX_CHARS = 6000
  21. self.DENSE_DIM = 4096
  22. self.H1_RE = re.compile(r"^#\s+(.+?)\s*$", re.MULTILINE)
  23. self.BLANK_SPLIT_RE = re.compile(r"\n\s*\n+")
  24. # 确保集合已创建
  25. self.ensure_collections()
  26. def has_collection(self, collection_name: str) -> bool:
  27. """检查集合是否存在"""
  28. return self.client.has_collection(collection_name=collection_name)
  29. def _get_embeddings(self) -> OpenAIEmbeddings:
  30. """获取 Embedding 模型配置"""
  31. return OpenAIEmbeddings(
  32. base_url=config_handler.get("admin_app", "EMBEDDING_BASE_URL", "http://192.168.91.253:9003/v1"),
  33. model=config_handler.get("admin_app", "EMBEDDING_MODEL", "Qwen3-Embedding-8B"),
  34. api_key=config_handler.get("admin_app", "EMBEDDING_API_KEY", "dummy"),
  35. )
  36. # --- 切分工具方法 ---
  37. def _split_md_by_blank_lines(self, md: str) -> List[str]:
  38. md = md.replace("\r\n", "\n").replace("\r", "\n")
  39. parts = self.BLANK_SPLIT_RE.split(md)
  40. return [p.strip() for p in parts if p.strip()]
  41. def _is_heading_chunk(self, chunk: str) -> Optional[Tuple[int, str]]:
  42. first_line = chunk.split("\n", 1)[0].strip()
  43. m = re.match(r"^(#{1,6})\s+(.+?)\s*$", first_line)
  44. if not m:
  45. return None
  46. return len(m.group(1)), m.group(2).strip()
  47. def _split_md_by_h1_sections(self, md: str) -> List[Tuple[str, str]]:
  48. """按一级标题切分父块"""
  49. md = md.replace("\r\n", "\n").replace("\r", "\n")
  50. matches = list(self.H1_RE.finditer(md))
  51. if not matches:
  52. txt = md.strip()
  53. return [("__NO_H1__", txt)] if txt else []
  54. sections = []
  55. # 检查第一个#之前的内容
  56. first_match_start = matches[0].start()
  57. preamble = md[:first_match_start].strip()
  58. if preamble:
  59. sections.append(("__PREAMBLE__", preamble))
  60. for i, m in enumerate(matches):
  61. title = m.group(1).strip()
  62. start = m.start()
  63. end = matches[i + 1].start() if i + 1 < len(matches) else len(md)
  64. sec = md[start:end].strip()
  65. if sec:
  66. sections.append((title, sec))
  67. return sections
  68. def _make_parent_id(self, doc_id: str, doc_version: int, doc_name: str, h1_title: str, parent_seq: int) -> str:
  69. """生成稳定的 parent_id"""
  70. raw = f"{doc_id}|{doc_version}|{doc_name}|{parent_seq}|{h1_title}".encode("utf-8")
  71. return hashlib.sha1(raw).hexdigest()
  72. def _split_text_by_max_chars(self, text: str, max_chars: int) -> List[str]:
  73. """父段过长时切片"""
  74. text = (text or "").strip()
  75. if not text or len(text) <= max_chars:
  76. return [text] if text else []
  77. chunks = self._split_md_by_blank_lines(text)
  78. result = []
  79. current_slice = ""
  80. for chunk in chunks:
  81. if len(chunk) > max_chars:
  82. if current_slice.strip():
  83. result.append(current_slice.strip())
  84. current_slice = ""
  85. start = 0
  86. while start < len(chunk):
  87. result.append(chunk[start : start + max_chars].strip())
  88. start += max_chars
  89. else:
  90. test_slice = current_slice + "\n\n" + chunk if current_slice else chunk
  91. if len(test_slice) <= max_chars:
  92. current_slice = test_slice
  93. else:
  94. if current_slice.strip():
  95. result.append(current_slice.strip())
  96. current_slice = chunk
  97. if current_slice.strip():
  98. result.append(current_slice.strip())
  99. return [s for s in result if s]
  100. # --- 核心业务逻辑 ---
  101. def ensure_collections(self):
  102. """确保父子 Collection 已创建并配置索引"""
  103. for col_name in [self.parent_collection, self.child_collection]:
  104. if not self.client.has_collection(collection_name=col_name):
  105. schema = self.client.create_schema(auto_id=True, enable_dynamic_fields=False)
  106. schema.add_field("pk", DataType.INT64, is_primary=True, auto_id=True)
  107. schema.add_field("text", DataType.VARCHAR, max_length=65535, enable_analyzer=True)
  108. schema.add_field("dense", DataType.FLOAT_VECTOR, dim=self.DENSE_DIM)
  109. schema.add_field("sparse", DataType.SPARSE_FLOAT_VECTOR)
  110. schema.add_field("document_id", DataType.VARCHAR, max_length=256)
  111. schema.add_field("parent_id", DataType.VARCHAR, max_length=256)
  112. schema.add_field("index", DataType.INT64)
  113. schema.add_field("tag_list", DataType.VARCHAR, max_length=2048)
  114. schema.add_field("permission", DataType.JSON)
  115. schema.add_field("metadata", DataType.JSON)
  116. schema.add_field("is_deleted", DataType.INT64)
  117. schema.add_field("created_by", DataType.VARCHAR, max_length=256)
  118. schema.add_field("created_time", DataType.INT64)
  119. schema.add_field("updated_by", DataType.VARCHAR, max_length=256)
  120. schema.add_field("updated_time", DataType.INT64)
  121. schema.add_function(Function(
  122. name="bm25_fn",
  123. input_field_names=["text"],
  124. output_field_names=["sparse"],
  125. function_type=FunctionType.BM25,
  126. ))
  127. self.client.create_collection(collection_name=col_name, schema=schema)
  128. index_params = self.client.prepare_index_params()
  129. index_params.add_index(field_name="dense", index_name="dense_idx", index_type="AUTOINDEX", metric_type="COSINE")
  130. index_params.add_index(field_name="sparse", index_name="bm25_idx", index_type="SPARSE_INVERTED_INDEX", metric_type="BM25", params={"inverted_index_algo": "DAAT_MAXSCORE"})
  131. self.client.create_index(collection_name=col_name, index_params=index_params)
  132. self.client.load_collection(collection_name=col_name)
  133. async def insert_knowledge(self, md_text: str, doc_info: Dict[str, Any]):
  134. """执行切分、向量化并存入 Milvus"""
  135. doc_id = doc_info['doc_id']
  136. doc_name = doc_info.get('doc_name', 'unknown')
  137. doc_version = doc_info.get('doc_version', 20260127)
  138. tag_list = str(doc_info.get('tags') or '')
  139. # 公共字段准备
  140. created_by = doc_info.get('created_by', 'system')
  141. created_time = doc_info.get('created_time', int(time.time() * 1000))
  142. updated_by = doc_info.get('updated_by', 'system')
  143. updated_time = doc_info.get('updated_time', int(time.time() * 1000))
  144. permission = doc_info.get('permission', {})
  145. try:
  146. # 1. 幂等处理:清理旧数据
  147. try:
  148. self.client.delete(collection_name=self.parent_collection, filter=f"document_id == '{doc_id}'")
  149. self.client.delete(collection_name=self.child_collection, filter=f"document_id == '{doc_id}'")
  150. except Exception as e:
  151. logger.warning(f"清理旧数据失败 (doc_id: {doc_id}): {e}")
  152. # 继续执行,可能是第一次入库
  153. # 2. 切分父子块
  154. try:
  155. parent_sections = self._split_md_by_h1_sections(md_text)
  156. parent_entities = []
  157. child_entities = []
  158. # 预生成所有 parent_id
  159. parent_seq_to_id = {}
  160. for seq, (title, _) in enumerate(parent_sections):
  161. parent_seq_to_id[seq] = self._make_parent_id(doc_id, doc_version, doc_name, title, seq)
  162. # 3. 处理子块
  163. for seq, (h1_title, sec_text) in enumerate(parent_sections):
  164. p_id = parent_seq_to_id[seq]
  165. chunks = self._split_md_by_blank_lines(sec_text)
  166. heading_path = []
  167. for c_idx, chunk in enumerate(chunks):
  168. h_info = self._is_heading_chunk(chunk)
  169. if h_info:
  170. level, title = h_info
  171. heading_path = heading_path[:level-1] + [title]
  172. outline_path = " > ".join(heading_path)
  173. child_entities.append({
  174. "text": chunk,
  175. "is_deleted": 0,
  176. "parent_id": p_id,
  177. "document_id": doc_id,
  178. "index": int(c_idx),
  179. "tag_list": tag_list,
  180. "permission": permission,
  181. "metadata": {
  182. "doc_name": doc_name,
  183. "outline_path": outline_path,
  184. "doc_version": doc_version
  185. },
  186. "created_by": created_by,
  187. "created_time": created_time,
  188. "updated_by": updated_by,
  189. "updated_time": updated_time
  190. })
  191. # 4. 处理父块
  192. for seq, (h1_title, sec_text) in enumerate(parent_sections):
  193. p_id = parent_seq_to_id[seq]
  194. slices = self._split_text_by_max_chars(sec_text, self.PARENT_MAX_CHARS)
  195. for s_idx, slice_text in enumerate(slices):
  196. parent_entities.append({
  197. "text": slice_text,
  198. "is_deleted": 0,
  199. "parent_id": p_id,
  200. "document_id": doc_id,
  201. "index": int(seq),
  202. "tag_list": tag_list,
  203. "permission": permission,
  204. "metadata": {
  205. "doc_name": doc_name,
  206. "outline_path": h1_title if h1_title not in ["__PREAMBLE__", "__NO_H1__"] else doc_name,
  207. "doc_version": doc_version
  208. },
  209. "created_by": created_by,
  210. "created_time": created_time,
  211. "updated_by": updated_by,
  212. "updated_time": updated_time
  213. })
  214. except Exception as e:
  215. logger.error(f"文档切分失败 (doc_id: {doc_id}): {e}")
  216. raise RuntimeError(f"文档切分处理异常: {str(e)}")
  217. # 5. 向量化并插入
  218. # 处理父块
  219. if parent_entities:
  220. try:
  221. p_texts = [e['text'] for e in parent_entities]
  222. p_vecs = self.emb.embed_documents(p_texts)
  223. for e, v in zip(parent_entities, p_vecs): e['dense'] = v
  224. except Exception as e:
  225. logger.error(f"父块向量化失败 (Embedding Service): {e}")
  226. raise RuntimeError(f"Embedding 服务调用失败: {str(e)}")
  227. try:
  228. self.client.insert(collection_name=self.parent_collection, data=parent_entities)
  229. except Exception as e:
  230. logger.error(f"父块存入 Milvus 失败: {e}")
  231. raise RuntimeError(f"向量数据库写入失败(Parent): {str(e)}")
  232. # 处理子块
  233. if child_entities:
  234. try:
  235. c_texts = [e['text'] for e in child_entities]
  236. c_vecs = self.emb.embed_documents(c_texts)
  237. for e, v in zip(child_entities, c_vecs): e['dense'] = v
  238. except Exception as e:
  239. logger.error(f"子块向量化失败 (Embedding Service): {e}")
  240. raise RuntimeError(f"Embedding 服务调用失败: {str(e)}")
  241. try:
  242. self.client.insert(collection_name=self.child_collection, data=child_entities)
  243. except Exception as e:
  244. logger.error(f"子块存入 Milvus 失败: {e}")
  245. raise RuntimeError(f"向量数据库写入失败(Child): {str(e)}")
  246. logger.info(f"Successfully entered knowledge base for doc_id: {doc_id}, parents: {len(parent_entities)}, children: {len(child_entities)}")
  247. return len(parent_entities), len(child_entities)
  248. except Exception as e:
  249. # 重新抛出已处理的异常或包装未处理的异常
  250. if not isinstance(e, RuntimeError):
  251. logger.exception(f"入库流程发生未知异常 (doc_id: {doc_id})")
  252. raise RuntimeError(f"入库未知错误: {str(e)}")
  253. raise e
  254. # 全局 Milvus 服务实例
  255. milvus_host = config_handler.get("admin_app", "MILVUS_HOST", "192.168.92.61")
  256. milvus_port = config_handler.get("admin_app", "MILVUS_PORT", "19530")
  257. milvus_service = MilvusService(
  258. uri=f"http://{milvus_host}:{milvus_port}",
  259. db_name=config_handler.get("admin_app", "MILVUS_DB", "lq_db"),
  260. parent_collection=config_handler.get("admin_app", "PARENT_COLLECTION_NAME", "test_27_parent"),
  261. child_collection=config_handler.get("admin_app", "CHILD_COLLECTION_NAME", "test_27_child")
  262. )