| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446 |
- import os
- import re
- import json
- import hashlib
- import logging
- import time
- from typing import List, Dict, Any, Optional, Tuple
- from langchain_core.documents import Document
- from pymilvus import (
- MilvusClient,
- DataType,
- Function,
- FunctionType,
- )
# Project configuration and shared connection imports
- from app.core.config import config_handler
- from app.base.milvus_connection import get_milvus_manager
- from app.base.embedding_connection import get_embedding_model
logger = logging.getLogger(__name__)

# =============================
# 1. Configuration (defaults read from project config)
# =============================
# Root directory scanned for markdown files to import.
ROOT_DIR = config_handler.get("admin_app", "MILVUS_IMPORT_ROOT", r"C:\Users\ZengChao\Desktop\新建文件夹")

# Parent / child collection names (parent-child chunking scheme).
PARENT_COLLECTION_NAME = config_handler.get("admin_app", "PARENT_COLLECTION_NAME", "test_27_parent")
CHILD_COLLECTION_NAME = config_handler.get("admin_app", "CHILD_COLLECTION_NAME", "test_27_child")

# Dense vector dimension used when the embedding model cannot be probed.
DENSE_DIM_FALLBACK = 4096
CHUNK_ID_START = 0

# Maximum parent-section length in characters. A longer section is split into
# several parent-table records that all share the same parent_id.
PARENT_MAX_CHARS = 6000

# Default scalar fields stored on every record (used for filtered search).
# NOTE(review): created_time / updated_time are evaluated once at import time,
# not per record — confirm this is intended.
BASE_SCALAR_FIELDS = {
    "is_deleted": 0,
    "parent_id": "",  # SHA-1 hex string, overwritten per record
    "document_id": "DOC_123",
    "index": 0,
    "tag_list": "policy,hr",
    "permission": {},
    "created_by": "system",
    "created_time": int(time.time() * 1000),
    "updated_by": "system",
    "updated_time": int(time.time() * 1000),
}

# =============================
# 2. Utilities: Markdown chunking
# =============================
# Matches one or more blank lines (chunk separator).
BLANK_SPLIT_RE = re.compile(r"\n\s*\n+")
# Matches a markdown H1 line: "# title".
H1_RE = re.compile(r"^#\s+(.+?)\s*$", re.MULTILINE)
def split_md_by_blank_lines(md: str) -> List[str]:
    """Normalize newlines, then split *md* on blank lines into stripped, non-empty chunks."""
    normalized = md.replace("\r\n", "\n").replace("\r", "\n")
    stripped_parts = (part.strip() for part in BLANK_SPLIT_RE.split(normalized))
    return [part for part in stripped_parts if part]
def is_heading_chunk(chunk: str):
    """Return (level, title) when *chunk*'s first line is a markdown heading, else None."""
    first_line = chunk.split("\n", 1)[0].strip()
    match = re.match(r"^(#{1,6})\s+(.+?)\s*$", first_line)
    if match is None:
        return None
    hashes, title = match.group(1), match.group(2)
    return len(hashes), title.strip()
def outline_path_str(path: List[str]) -> str:
    """Join heading titles into a single ' > '-delimited outline path."""
    separator = " > "
    return separator.join(path)
def guess_doc_name_from_filename(file_name: str) -> str:
    """Derive the document name from a file name by stripping its extension."""
    stem, _extension = os.path.splitext(file_name)
    return stem
def split_md_by_h1_sections(md: str) -> List[Tuple[str, str]]:
    """
    Split markdown into parent sections at H1 ('# title') headings.

    Returns a list of (h1_title, section_text) pairs where:
    - text before the first '#' becomes a ("__PREAMBLE__", text) entry,
    - each section_text keeps its '#' line and runs until the next '#',
    - a document without any '#' yields [("__NO_H1__", whole_text)],
      or [] when the document is blank.
    """
    md = md.replace("\r\n", "\n").replace("\r", "\n")
    headings = list(H1_RE.finditer(md))

    # No H1 at all: the whole document is one unnamed section (if non-blank).
    if not headings:
        whole = md.strip()
        return [("__NO_H1__", whole)] if whole else []

    sections: List[Tuple[str, str]] = []

    # Anything before the first heading becomes the preamble section.
    preamble = md[: headings[0].start()].strip()
    if preamble:
        sections.append(("__PREAMBLE__", preamble))

    # Each heading owns the text up to the next heading (or end of document).
    starts = [m.start() for m in headings]
    ends = starts[1:] + [len(md)]
    for match, start, end in zip(headings, starts, ends):
        body = md[start:end].strip()
        if body:
            sections.append((match.group(1).strip(), body))
    return sections
def make_parent_id(doc_id: str, doc_version: int, doc_name: str, h1_title: str, parent_seq: int) -> str:
    """
    Derive a stable SHA-1 parent-section id (hex string).

    Every parent-table record sliced from the same H1 section shares this id,
    because the hash covers only document identity and section position,
    never the slice contents.
    """
    parts = (doc_id, str(doc_version), doc_name, str(parent_seq), h1_title)
    digest_input = "|".join(parts).encode("utf-8")
    return hashlib.sha1(digest_input).hexdigest()
def split_text_by_max_chars(text: str, max_chars: int) -> List[str]:
    """
    Slice an over-long parent section into pieces of at most *max_chars* chars.

    Strategy:
    - If the text already fits, return it as a single slice.
    - Otherwise split on blank lines and greedily pack paragraphs into slices,
      preferring blank-line boundaries near the limit.
    - Only a single paragraph longer than *max_chars* is hard-cut mid-text.

    Args:
        text: Raw section text; None or blank yields [].
        max_chars: Maximum slice length; must be positive.

    Returns:
        List of non-empty, stripped slices.

    Raises:
        ValueError: If max_chars <= 0 (the original code would loop forever
            on the hard-cut path because `start += max_chars` never advances).
    """
    if max_chars <= 0:
        raise ValueError(f"max_chars must be positive, got {max_chars}")
    text = (text or "").strip()
    if not text:
        return []
    if len(text) <= max_chars:
        return [text]

    paragraphs = split_md_by_blank_lines(text)
    slices: List[str] = []
    current = ""

    def _flush():
        # Persist the accumulated slice, if any, and reset the accumulator.
        nonlocal current
        if current.strip():
            slices.append(current.strip())
        current = ""

    for para in paragraphs:
        if len(para) > max_chars:
            # A single paragraph longer than the limit must be hard-cut.
            _flush()
            for start in range(0, len(para), max_chars):
                piece = para[start : start + max_chars].strip()
                if piece:
                    slices.append(piece)
        else:
            candidate = f"{current}\n\n{para}" if current else para
            if len(candidate) <= max_chars:
                # Paragraph still fits in the current slice.
                current = candidate
            else:
                # Would overflow: close the current slice, start a new one.
                _flush()
                current = para

    _flush()
    return slices
def build_parent_and_child_documents_from_md(md_text: str, file_name: str) -> Tuple[List[Document], List[Document]]:
    """
    Build parent- and child-table Documents from one markdown file.

    Chunking order:
    1. Split the markdown into parent sections at H1 ('# ...') headings.
    2. Within each parent section, split on blank lines to get child chunks.
    3. Re-slice over-long parent sections into multiple parent records that
       share the same parent_id.

    Args:
        md_text: Raw markdown content of the file.
        file_name: File name; its stem becomes the document name.

    Returns:
        (parent_docs, child_docs) ready for insertion into the parent and
        child collections respectively.
    """
    doc_name = guess_doc_name_from_filename(file_name)
    doc_version = 20260127  # hard-coded default version — TODO confirm its source
    # 1) Split into parent sections at H1 headings.
    parent_sections = split_md_by_h1_sections(md_text)
    parent_seq_to_id: Dict[int, str] = {}
    # Pre-compute every parent_id so children and parents agree on it.
    for parent_seq, (h1_title, sec_text) in enumerate(parent_sections):
        p_id = make_parent_id(
            doc_id=str(BASE_SCALAR_FIELDS["document_id"]),
            doc_version=doc_version,
            doc_name=doc_name,
            h1_title=h1_title,
            parent_seq=parent_seq,
        )
        parent_seq_to_id[parent_seq] = p_id
    # 2) Child chunks: blank-line splits inside each parent section.
    child_docs: List[Document] = []
    for parent_seq, (h1_title, sec_text) in enumerate(parent_sections):
        p_id = parent_seq_to_id[parent_seq]

        # Split this parent section into child chunks on blank lines.
        chunks = split_md_by_blank_lines(sec_text)
        heading_path: List[str] = []  # running heading hierarchy, reset per section
        for c_idx, chunk in enumerate(chunks):
            # Track the outline path: a heading chunk replaces the path at its
            # level; a body chunk inherits the current path.
            heading_info = is_heading_chunk(chunk)
            if heading_info:
                level, title = heading_info
                parent_path = heading_path[: level - 1]
                outline_path = outline_path_str(parent_path)
                heading_path = parent_path + [title]
            else:
                outline_path = outline_path_str(heading_path)
            scalar_md = dict(BASE_SCALAR_FIELDS)
            scalar_md["index"] = int(c_idx)
            scalar_md["parent_id"] = p_id
            # JSON metadata carries: doc_name, outline_path, doc_version.
            metadata_json = {
                "doc_name": doc_name,
                "outline_path": outline_path,
                "doc_version": doc_version,
            }
            child_docs.append(
                Document(
                    page_content=chunk,
                    metadata={**scalar_md, "metadata": metadata_json},
                )
            )
    # 3) Parent records: re-slice over-long sections (slices share one parent_id).
    parent_docs: List[Document] = []
    for parent_seq, (h1_title, sec_text) in enumerate(parent_sections):
        p_id = parent_seq_to_id[parent_seq]

        # Slice the section when it exceeds the per-record limit.
        slices = split_text_by_max_chars(sec_text, PARENT_MAX_CHARS)
        for slice_idx, slice_text in enumerate(slices):
            # NOTE(review): slice_idx is unused — every slice of a section
            # stores parent_seq in "index"; confirm that is intended.
            scalar_md = dict(BASE_SCALAR_FIELDS)
            scalar_md["index"] = int(parent_seq)
            scalar_md["parent_id"] = p_id
            # Outline path for sentinel sections: preamble -> doc name; no-H1 -> "".
            if h1_title == "__PREAMBLE__":
                outline_path = doc_name
            elif h1_title == "__NO_H1__":
                outline_path = ""
            else:
                outline_path = h1_title

            metadata_json = {
                "doc_name": doc_name,
                "outline_path": outline_path,
                "doc_version": doc_version,
            }
            parent_docs.append(
                Document(
                    page_content=slice_text,
                    metadata={**scalar_md, "metadata": metadata_json},
                )
            )
    return parent_docs, child_docs
def save_docs_to_json(docs: List[Document], out_path: str) -> str:
    """Serialize documents to a pretty-printed UTF-8 JSON file; return the path ("" if no docs)."""
    if not docs:
        return ""
    payload = [{"page_content": doc.page_content, "metadata": doc.metadata} for doc in docs]
    with open(out_path, "w", encoding="utf-8") as handle:
        json.dump(payload, handle, ensure_ascii=False, indent=2)
    return out_path
# =============================
# 3. Milvus: create collections (dense + BM25 + scalar fields + JSON metadata)
# =============================
def detect_dense_dim(emb) -> int:
    """
    Probe the embedding model for its dense vector dimension.

    Embeds a short probe string and returns the resulting vector's length;
    falls back to DENSE_DIM_FALLBACK when the probe fails.
    """
    try:
        return len(emb.embed_query("dim probe"))
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt and
        # SystemExit; Exception is broad enough for any probe failure.
        return DENSE_DIM_FALLBACK
def ensure_collection(client: MilvusClient, collection_name: str, dense_dim: int) -> None:
    """
    Create (if missing) a collection with dense + BM25-sparse vector fields plus
    the scalar/JSON fields used for filtering, then index and load it.

    Idempotent: returns immediately when the collection already exists.

    Args:
        client: Milvus client.
        collection_name: Target collection name.
        dense_dim: Dimension of the "dense" embedding vector field.
    """
    if client.has_collection(collection_name=collection_name):
        return
    schema = client.create_schema(auto_id=True, enable_dynamic_fields=False)
    schema.add_field("pk", DataType.INT64, is_primary=True, auto_id=True)
    # enable_analyzer lets the BM25 function tokenize "text" server-side.
    schema.add_field("text", DataType.VARCHAR, max_length=65535, enable_analyzer=True)
    schema.add_field("dense", DataType.FLOAT_VECTOR, dim=dense_dim)
    # "sparse" is populated by the BM25 function below, not by the client.
    schema.add_field("sparse", DataType.SPARSE_FLOAT_VECTOR)
    # Scalar fields used for filtered search.
    schema.add_field("document_id", DataType.VARCHAR, max_length=256)
    schema.add_field("parent_id", DataType.VARCHAR, max_length=256)
    schema.add_field("index", DataType.INT64)
    schema.add_field("tag_list", DataType.VARCHAR, max_length=2048)
    schema.add_field("permission", DataType.JSON)
    schema.add_field("metadata", DataType.JSON)
    schema.add_field("is_deleted", DataType.INT64)
    schema.add_field("created_by", DataType.VARCHAR, max_length=256)
    schema.add_field("created_time", DataType.INT64)
    schema.add_field("updated_by", DataType.VARCHAR, max_length=256)
    schema.add_field("updated_time", DataType.INT64)
    # Server-side BM25: derives the sparse vector from "text" on insert.
    schema.add_function(
        Function(
            name="bm25_fn",
            input_field_names=["text"],
            output_field_names=["sparse"],
            function_type=FunctionType.BM25,
        )
    )
    client.create_collection(collection_name=collection_name, schema=schema)
    index_params = client.prepare_index_params()
    index_params.add_index(
        field_name="dense",
        index_name="dense_idx",
        index_type="AUTOINDEX",
        metric_type="COSINE",
    )
    index_params.add_index(
        field_name="sparse",
        index_name="bm25_idx",
        index_type="SPARSE_INVERTED_INDEX",
        metric_type="BM25",
        params={"inverted_index_algo": "DAAT_MAXSCORE"},
    )
    client.create_index(collection_name=collection_name, index_params=index_params)
    # Load into memory so the collection is immediately searchable.
    client.load_collection(collection_name=collection_name)
# =============================
# 4. Ingest: dense vectors from the embedding model; BM25 generated by Milvus
# =============================
def docs_to_entities(docs: List[Document], emb) -> List[Dict[str, Any]]:
    """
    Convert Documents into Milvus insert rows.

    Dense vectors are computed here via the embedding model; the BM25 sparse
    vector is generated server-side by Milvus and therefore not included.
    """
    texts = [doc.page_content for doc in docs]
    vectors = emb.embed_documents(texts)
    rows: List[Dict[str, Any]] = []
    for doc, vec in zip(docs, vectors):
        meta = doc.metadata or {}
        permission = meta.get("permission", {})
        extra = meta.get("metadata", {})
        now_ms = int(time.time() * 1000)
        rows.append(
            {
                "text": doc.page_content,
                "dense": vec,
                "is_deleted": int(meta.get("is_deleted", 0)),
                "parent_id": str(meta.get("parent_id", "")),
                "document_id": str(meta.get("document_id", "")),
                "index": int(meta.get("index", 0)),
                "tag_list": str(meta.get("tag_list", "")),
                # JSON fields fall back to {} when missing or not dicts.
                "permission": permission if isinstance(permission, dict) else {},
                "metadata": extra if isinstance(extra, dict) else {},
                "created_by": str(meta.get("created_by", "system")),
                "created_time": int(meta.get("created_time", now_ms)),
                "updated_by": str(meta.get("updated_by", "system")),
                "updated_time": int(meta.get("updated_time", now_ms)),
            }
        )
    return rows
def insert_docs(client: MilvusClient, emb, docs: List[Document], collection_name: str):
    """Embed *docs* and insert them into *collection_name*; no-op for an empty list."""
    if not docs:
        return
    client.insert(collection_name=collection_name, data=docs_to_entities(docs, emb))
# =============================
# 5. Entry point: ingest only (parent table + child table)
# =============================
if __name__ == "__main__":
    # Use the project-wide embedding model.
    emb = get_embedding_model()
    try:
        dense_dim = detect_dense_dim(emb)
    except Exception:
        dense_dim = DENSE_DIM_FALLBACK
    # Use the project-wide Milvus manager.
    milvus_manager = get_milvus_manager()
    client = milvus_manager.client
    # Create both collections: parent + child.
    ensure_collection(client, PARENT_COLLECTION_NAME, dense_dim=dense_dim)
    ensure_collection(client, CHILD_COLLECTION_NAME, dense_dim=dense_dim)
    if not os.path.exists(ROOT_DIR):
        print(f"❌ 目录不存在:{ROOT_DIR}")
    else:
        # Expected layout: ROOT_DIR/<folder>/<file>.md — exactly one level deep.
        for folder_name in os.listdir(ROOT_DIR):
            folder_path = os.path.join(ROOT_DIR, folder_name)
            if not os.path.isdir(folder_path):
                continue
            for file_name in os.listdir(folder_path):
                if not file_name.lower().endswith(".md"):
                    continue
                md_path = os.path.join(folder_path, file_name)
                try:
                    print(f"\n📄 正在处理:{md_path}")
                    with open(md_path, "r", encoding="utf-8") as f:
                        text = f.read()
                    parent_docs, child_docs = build_parent_and_child_documents_from_md(text, file_name)
                    # Optional: dump the chunking result next to the source file
                    # so the split can be inspected.
                    out_dir = os.path.dirname(md_path)
                    base = os.path.splitext(os.path.basename(md_path))[0]
                    save_docs_to_json(parent_docs, os.path.join(out_dir, f"{base}_parents.json"))
                    save_docs_to_json(child_docs, os.path.join(out_dir, f"{base}_children.json"))
                    # Write the parent table and the child table.
                    insert_docs(client, emb, parent_docs, PARENT_COLLECTION_NAME)
                    insert_docs(client, emb, child_docs, CHILD_COLLECTION_NAME)
                    print(f"✅ 父表写入:parents={len(parent_docs)} -> {PARENT_COLLECTION_NAME}")
                    print(f"✅ 子表写入:children={len(child_docs)} -> {CHILD_COLLECTION_NAME}")
                except Exception as e:
                    # Best-effort batch import: report and continue with the next file.
                    print(f"❌ 处理失败:{md_path}")
                    print(e)
|