pymilvus_store_database.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
import os
import re
import json
import hashlib
import logging
import time
from typing import List, Dict, Any, Optional, Tuple
from langchain_core.documents import Document
from pymilvus import (
    MilvusClient,
    DataType,
    Function,
    FunctionType,
)

# Project configuration and connection helpers
from app.core.config import config_handler
from app.base.milvus_connection import get_milvus_manager
from app.base.embedding_connection import get_embedding_model

logger = logging.getLogger(__name__)
# =============================
# 1. Configuration (defaults read from project config)
# =============================
# Default directory scanned for Markdown files to import.
ROOT_DIR = config_handler.get("admin_app", "MILVUS_IMPORT_ROOT", r"C:\Users\ZengChao\Desktop\新建文件夹")

# Parent table / child table collection names.
PARENT_COLLECTION_NAME = config_handler.get("admin_app", "PARENT_COLLECTION_NAME", "test_27_parent")
CHILD_COLLECTION_NAME = config_handler.get("admin_app", "CHILD_COLLECTION_NAME", "test_27_child")

# Dense vector dimension used when probing the embedding model fails.
DENSE_DIM_FALLBACK = 4096
CHUNK_ID_START = 0  # NOTE(review): appears unused in this file — confirm before removing.

# Max parent-section length; a longer section is split into several parent
# records that all share the same parent_id.
PARENT_MAX_CHARS = 6000

# Default scalar fields (used for filtering at query time).
# NOTE: created_time / updated_time are evaluated once at import time, so every
# row ingested by one process run shares the same default timestamps.
BASE_SCALAR_FIELDS = {
    "is_deleted": 0,
    "parent_id": "",  # SHA-1 hex string (see make_parent_id)
    "document_id": "DOC_123",
    "index": 0,
    "tag_list": "policy,hr",
    "permission": {},
    "created_by": "system",
    "created_time": int(time.time() * 1000),
    "updated_by": "system",
    "updated_time": int(time.time() * 1000),
}
# =============================
# 2. Utilities: Markdown chunking
# =============================
# One or more blank (whitespace-only) lines — the chunk separator.
BLANK_SPLIT_RE = re.compile(r"\n\s*\n+")
# A level-1 Markdown heading ("# Title") on its own line.
H1_RE = re.compile(r"^#\s+(.+?)\s*$", re.MULTILINE)
  50. def split_md_by_blank_lines(md: str) -> List[str]:
  51. md = md.replace("\r\n", "\n").replace("\r", "\n")
  52. parts = BLANK_SPLIT_RE.split(md)
  53. return [p.strip() for p in parts if p.strip()]
  54. def is_heading_chunk(chunk: str):
  55. first_line = chunk.split("\n", 1)[0].strip()
  56. m = re.match(r"^(#{1,6})\s+(.+?)\s*$", first_line)
  57. if not m:
  58. return None
  59. return len(m.group(1)), m.group(2).strip()
  60. def outline_path_str(path: List[str]) -> str:
  61. return " > ".join(path)
  62. def guess_doc_name_from_filename(file_name: str) -> str:
  63. return os.path.splitext(file_name)[0]
  64. def split_md_by_h1_sections(md: str) -> List[Tuple[str, str]]:
  65. """
  66. 按 '# 一级标题' 切成父段:
  67. return: [(h1_title, section_text), ...]
  68. - 如果最开始有内容(第一个#之前),将其作为 "__PREAMBLE__" 段
  69. - section_text 包含该 # 行本身 + 直到下一个 # 之前的所有内容
  70. - 如果全文没有任何 #,则返回一个默认段 ("__NO_H1__", 全文)
  71. """
  72. md = md.replace("\r\n", "\n").replace("\r", "\n")
  73. matches = list(H1_RE.finditer(md))
  74. if not matches:
  75. txt = md.strip()
  76. if not txt:
  77. return []
  78. return [("__NO_H1__", txt)]
  79. sections: List[Tuple[str, str]] = []
  80. # 检查第一个#之前是否有内容
  81. first_match_start = matches[0].start()
  82. preamble = md[:first_match_start].strip()
  83. if preamble:
  84. sections.append(("__PREAMBLE__", preamble))
  85. # 处理所有#标题段
  86. for i, m in enumerate(matches):
  87. title = m.group(1).strip()
  88. start = m.start()
  89. end = matches[i + 1].start() if i + 1 < len(matches) else len(md)
  90. sec = md[start:end].strip()
  91. if sec:
  92. sections.append((title, sec))
  93. return sections
  94. def make_parent_id(doc_id: str, doc_version: int, doc_name: str, h1_title: str, parent_seq: int) -> str:
  95. """
  96. ✅ 生成稳定 parent_id(父段ID)
  97. 同一个 # 一级标题段无论父表切成几条记录,都共享同一个 parent_id
  98. """
  99. raw = f"{doc_id}|{doc_version}|{doc_name}|{parent_seq}|{h1_title}".encode("utf-8")
  100. return hashlib.sha1(raw).hexdigest()
  101. def split_text_by_max_chars(text: str, max_chars: int) -> List[str]:
  102. """
  103. 父段过长时切片:
  104. - 优先在最大长度附近的空行边界切割
  105. - 单个段落超过max_chars时才硬切
  106. """
  107. text = (text or "").strip()
  108. if not text:
  109. return []
  110. if len(text) <= max_chars:
  111. return [text]
  112. # 先按空行切割
  113. chunks = split_md_by_blank_lines(text)
  114. result = []
  115. current_slice = ""
  116. for chunk in chunks:
  117. # 如果单个chunk超过max_chars,必须硬切
  118. if len(chunk) > max_chars:
  119. # 先保存当前累积的内容
  120. if current_slice.strip():
  121. result.append(current_slice.strip())
  122. current_slice = ""
  123. # 对超长chunk硬切
  124. start = 0
  125. while start < len(chunk):
  126. result.append(chunk[start : start + max_chars].strip())
  127. start += max_chars
  128. else:
  129. # 尝试把chunk加入current_slice
  130. test_slice = current_slice + "\n\n" + chunk if current_slice else chunk
  131. if len(test_slice) <= max_chars:
  132. # 可以加入
  133. current_slice = test_slice
  134. else:
  135. # 超过了,保存current_slice,开始新的
  136. if current_slice.strip():
  137. result.append(current_slice.strip())
  138. current_slice = chunk
  139. # 保存最后的current_slice
  140. if current_slice.strip():
  141. result.append(current_slice.strip())
  142. return [s for s in result if s]
  143. def build_parent_and_child_documents_from_md(md_text: str, file_name: str) -> Tuple[List[Document], List[Document]]:
  144. """
  145. ✅ 切分顺序:
  146. 1. 先按 # 一级标题切父块
  147. 2. 用切好的父块来处理子块(按空行切)
  148. 3. 最后处理超长的父块(父块太长再切成多条父记录,共享同一个 parent_id)
  149. """
  150. doc_name = guess_doc_name_from_filename(file_name)
  151. doc_version = 20260127 # 默认版本
  152. # 1) 按 # 一级标题切父块
  153. parent_sections = split_md_by_h1_sections(md_text)
  154. parent_seq_to_id: Dict[int, str] = {}
  155. # 先生成所有 parent_id
  156. for parent_seq, (h1_title, sec_text) in enumerate(parent_sections):
  157. p_id = make_parent_id(
  158. doc_id=str(BASE_SCALAR_FIELDS["document_id"]),
  159. doc_version=doc_version,
  160. doc_name=doc_name,
  161. h1_title=h1_title,
  162. parent_seq=parent_seq,
  163. )
  164. parent_seq_to_id[parent_seq] = p_id
  165. # 2) 用切好的父块来处理子块(按空行切,但在父块范围内)
  166. child_docs: List[Document] = []
  167. for parent_seq, (h1_title, sec_text) in enumerate(parent_sections):
  168. p_id = parent_seq_to_id[parent_seq]
  169. # 在该父块范围内按空行切子块
  170. chunks = split_md_by_blank_lines(sec_text)
  171. heading_path: List[str] = []
  172. for c_idx, chunk in enumerate(chunks):
  173. # 子 chunk outline_path
  174. heading_info = is_heading_chunk(chunk)
  175. if heading_info:
  176. level, title = heading_info
  177. parent_path = heading_path[: level - 1]
  178. outline_path = outline_path_str(parent_path)
  179. heading_path = parent_path + [title]
  180. else:
  181. outline_path = outline_path_str(heading_path)
  182. scalar_md = dict(BASE_SCALAR_FIELDS)
  183. scalar_md["index"] = int(c_idx)
  184. scalar_md["parent_id"] = p_id
  185. # ✅ metadata 包含:doc_name, outline_path, doc_version
  186. metadata_json = {
  187. "doc_name": doc_name,
  188. "outline_path": outline_path,
  189. "doc_version": doc_version,
  190. }
  191. child_docs.append(
  192. Document(
  193. page_content=chunk,
  194. metadata={**scalar_md, "metadata": metadata_json},
  195. )
  196. )
  197. # 3) 处理超长的父块(父块太长再切成多条父记录)
  198. parent_docs: List[Document] = []
  199. for parent_seq, (h1_title, sec_text) in enumerate(parent_sections):
  200. p_id = parent_seq_to_id[parent_seq]
  201. # 如果父块过长,切成多条
  202. slices = split_text_by_max_chars(sec_text, PARENT_MAX_CHARS)
  203. for slice_idx, slice_text in enumerate(slices):
  204. scalar_md = dict(BASE_SCALAR_FIELDS)
  205. scalar_md["index"] = int(parent_seq)
  206. scalar_md["parent_id"] = p_id
  207. # ✅ metadata 包含:doc_name, outline_path, doc_version
  208. if h1_title == "__PREAMBLE__":
  209. outline_path = doc_name
  210. elif h1_title == "__NO_H1__":
  211. outline_path = ""
  212. else:
  213. outline_path = h1_title
  214. metadata_json = {
  215. "doc_name": doc_name,
  216. "outline_path": outline_path,
  217. "doc_version": doc_version,
  218. }
  219. parent_docs.append(
  220. Document(
  221. page_content=slice_text,
  222. metadata={**scalar_md, "metadata": metadata_json},
  223. )
  224. )
  225. return parent_docs, child_docs
  226. def save_docs_to_json(docs: List[Document], out_path: str) -> str:
  227. if not docs:
  228. return ""
  229. docs_data = [{"page_content": d.page_content, "metadata": d.metadata} for d in docs]
  230. with open(out_path, "w", encoding="utf-8") as f:
  231. json.dump(docs_data, f, ensure_ascii=False, indent=2)
  232. return out_path
  233. # =============================
  234. # 三、Milvus:建 collection(dense + BM25 + 标量字段 + JSON metadata)
  235. # =============================
  236. def detect_dense_dim(emb) -> int:
  237. try:
  238. return len(emb.embed_query("dim probe"))
  239. except:
  240. return DENSE_DIM_FALLBACK
  241. def ensure_collection(client: MilvusClient, collection_name: str, dense_dim: int):
  242. if client.has_collection(collection_name=collection_name):
  243. return
  244. schema = client.create_schema(auto_id=True, enable_dynamic_fields=False)
  245. schema.add_field("pk", DataType.INT64, is_primary=True, auto_id=True)
  246. schema.add_field("text", DataType.VARCHAR, max_length=65535, enable_analyzer=True)
  247. schema.add_field("dense", DataType.FLOAT_VECTOR, dim=dense_dim)
  248. schema.add_field("sparse", DataType.SPARSE_FLOAT_VECTOR)
  249. schema.add_field("document_id", DataType.VARCHAR, max_length=256)
  250. schema.add_field("parent_id", DataType.VARCHAR, max_length=256)
  251. schema.add_field("index", DataType.INT64)
  252. schema.add_field("tag_list", DataType.VARCHAR, max_length=2048)
  253. schema.add_field("permission", DataType.JSON)
  254. schema.add_field("metadata", DataType.JSON)
  255. schema.add_field("is_deleted", DataType.INT64)
  256. schema.add_field("created_by", DataType.VARCHAR, max_length=256)
  257. schema.add_field("created_time", DataType.INT64)
  258. schema.add_field("updated_by", DataType.VARCHAR, max_length=256)
  259. schema.add_field("updated_time", DataType.INT64)
  260. schema.add_function(
  261. Function(
  262. name="bm25_fn",
  263. input_field_names=["text"],
  264. output_field_names=["sparse"],
  265. function_type=FunctionType.BM25,
  266. )
  267. )
  268. client.create_collection(collection_name=collection_name, schema=schema)
  269. index_params = client.prepare_index_params()
  270. index_params.add_index(
  271. field_name="dense",
  272. index_name="dense_idx",
  273. index_type="AUTOINDEX",
  274. metric_type="COSINE",
  275. )
  276. index_params.add_index(
  277. field_name="sparse",
  278. index_name="bm25_idx",
  279. index_type="SPARSE_INVERTED_INDEX",
  280. metric_type="BM25",
  281. params={"inverted_index_algo": "DAAT_MAXSCORE"},
  282. )
  283. client.create_index(collection_name=collection_name, index_params=index_params)
  284. client.load_collection(collection_name=collection_name)
  285. # =============================
  286. # 四、写入:dense 由 embedding 生成;BM25 由 Milvus 自动生成
  287. # =============================
  288. def docs_to_entities(docs: List[Document], emb) -> List[Dict[str, Any]]:
  289. texts = [d.page_content for d in docs]
  290. dense_vecs = emb.embed_documents(texts)
  291. entities: List[Dict[str, Any]] = []
  292. for d, vec in zip(docs, dense_vecs):
  293. md = d.metadata or {}
  294. entities.append(
  295. {
  296. "text": d.page_content,
  297. "dense": vec,
  298. "is_deleted": int(md.get("is_deleted", 0)),
  299. "parent_id": str(md.get("parent_id", "")),
  300. "document_id": str(md.get("document_id", "")),
  301. "index": int(md.get("index", 0)),
  302. "tag_list": str(md.get("tag_list", "")),
  303. "permission": md.get("permission", {}) if isinstance(md.get("permission", {}), dict) else {},
  304. "metadata": md.get("metadata", {}) if isinstance(md.get("metadata", {}), dict) else {},
  305. "created_by": str(md.get("created_by", "system")),
  306. "created_time": int(md.get("created_time", int(time.time() * 1000))),
  307. "updated_by": str(md.get("updated_by", "system")),
  308. "updated_time": int(md.get("updated_time", int(time.time() * 1000))),
  309. }
  310. )
  311. return entities
  312. def insert_docs(client: MilvusClient, emb, docs: List[Document], collection_name: str):
  313. if not docs:
  314. return
  315. entities = docs_to_entities(docs, emb)
  316. client.insert(collection_name=collection_name, data=entities)
  317. # =============================
  318. # 五、主程序:只负责入库(父表 + 子表)
  319. # =============================
  320. if __name__ == "__main__":
  321. # 使用项目统一的 embedding 模型
  322. emb = get_embedding_model()
  323. try:
  324. dense_dim = detect_dense_dim(emb)
  325. except Exception:
  326. dense_dim = DENSE_DIM_FALLBACK
  327. # 使用项目统一的 Milvus 管理器
  328. milvus_manager = get_milvus_manager()
  329. client = milvus_manager.client
  330. # ✅ 建两个表:父表 + 子表
  331. ensure_collection(client, PARENT_COLLECTION_NAME, dense_dim=dense_dim)
  332. ensure_collection(client, CHILD_COLLECTION_NAME, dense_dim=dense_dim)
  333. if not os.path.exists(ROOT_DIR):
  334. print(f"❌ 目录不存在:{ROOT_DIR}")
  335. else:
  336. for folder_name in os.listdir(ROOT_DIR):
  337. folder_path = os.path.join(ROOT_DIR, folder_name)
  338. if not os.path.isdir(folder_path):
  339. continue
  340. for file_name in os.listdir(folder_path):
  341. if not file_name.lower().endswith(".md"):
  342. continue
  343. md_path = os.path.join(folder_path, file_name)
  344. try:
  345. print(f"\n📄 正在处理:{md_path}")
  346. with open(md_path, "r", encoding="utf-8") as f:
  347. text = f.read()
  348. parent_docs, child_docs = build_parent_and_child_documents_from_md(text, file_name)
  349. # 可选:落盘看切分效果
  350. out_dir = os.path.dirname(md_path)
  351. base = os.path.splitext(os.path.basename(md_path))[0]
  352. save_docs_to_json(parent_docs, os.path.join(out_dir, f"{base}_parents.json"))
  353. save_docs_to_json(child_docs, os.path.join(out_dir, f"{base}_children.json"))
  354. # ✅ 写父表 & 子表
  355. insert_docs(client, emb, parent_docs, PARENT_COLLECTION_NAME)
  356. insert_docs(client, emb, child_docs, CHILD_COLLECTION_NAME)
  357. print(f"✅ 父表写入:parents={len(parent_docs)} -> {PARENT_COLLECTION_NAME}")
  358. print(f"✅ 子表写入:children={len(child_docs)} -> {CHILD_COLLECTION_NAME}")
  359. except Exception as e:
  360. print(f"❌ 处理失败:{md_path}")
  361. print(e)