split_model.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501
  1. # coding=utf-8
  2. """
  3. @project: qabot
  4. @Author:虎
  5. @file: split_model.py
  6. @date:2023/9/1 15:12
  7. @desc:
  8. """
  9. import re
  10. from functools import reduce
  11. from typing import List, Dict
  12. import jieba
  13. def get_level_block(text, level_content_list, level_content_index, cursor):
  14. """
  15. 从文本中获取块数据
  16. :param text: 文本
  17. :param level_content_list: 拆分的title数组
  18. :param level_content_index: 指定的下标
  19. :param cursor: 开始的下标位置
  20. :return: 拆分后的文本数据
  21. """
  22. start_content: str = level_content_list[level_content_index].get('content')
  23. next_content = level_content_list[level_content_index + 1].get("content") if level_content_index + 1 < len(
  24. level_content_list) else None
  25. start_index = text.index(start_content, cursor)
  26. end_index = text.index(next_content, start_index + 1) if next_content is not None else len(text)
  27. return text[start_index + len(start_content):end_index], end_index
  28. def to_tree_obj(content, state='title'):
  29. """
  30. 转换为树形对象
  31. :param content: 文本数据
  32. :param state: 状态: title block
  33. :return: 转换后的数据
  34. """
  35. return {'content': content, 'state': state}
  36. def remove_special_symbol(str_source: str):
  37. """
  38. 删除特殊字符
  39. :param str_source: 需要删除的文本数据
  40. :return: 删除后的数据
  41. """
  42. return str_source
  43. def filter_special_symbol(content: dict):
  44. """
  45. 过滤文本中的特殊字符
  46. :param content: 需要过滤的对象
  47. :return: 过滤后返回
  48. """
  49. content['content'] = remove_special_symbol(content['content'])
  50. return content
  51. def flat(tree_data_list: List[dict], parent_chain: List[dict], result: List[dict]):
  52. """
  53. 扁平化树形结构数据
  54. :param tree_data_list: 树形接口数据
  55. :param parent_chain: 父级数据 传[] 用于递归存储数据
  56. :param result: 响应数据 传[] 用于递归存放数据
  57. :return: result 扁平化后的数据
  58. """
  59. if parent_chain is None:
  60. parent_chain = []
  61. if result is None:
  62. result = []
  63. for tree_data in tree_data_list:
  64. p = parent_chain.copy()
  65. p.append(tree_data)
  66. result.append(to_flat_obj(parent_chain, content=tree_data["content"], state=tree_data["state"]))
  67. children = tree_data.get('children')
  68. if children is not None and len(children) > 0:
  69. flat(children, p, result)
  70. return result
  71. def to_paragraph(obj: dict):
  72. """
  73. 转换为段落
  74. :param obj: 需要转换的对象
  75. :return: 段落对象
  76. """
  77. content = obj['content']
  78. return {"keywords": get_keyword(content),
  79. 'parent_chain': list(map(lambda p: p['content'], obj['parent_chain'])),
  80. 'content': ",".join(list(map(lambda p: p['content'], obj['parent_chain']))) + content}
  81. def get_keyword(content: str):
  82. """
  83. 获取content中的关键词
  84. :param content: 文本
  85. :return: 关键词数组
  86. """
  87. stopwords = [':', '“', '!', '”', '\n', '\\s']
  88. cutworms = jieba.lcut(content)
  89. return list(set(list(filter(lambda k: (k not in stopwords) | len(k) > 1, cutworms))))
  90. def titles_to_paragraph(list_title: List[dict]):
  91. """
  92. 将同一父级的title转换为块段落
  93. :param list_title: 同父级title
  94. :return: 块段落
  95. """
  96. if len(list_title) > 0:
  97. content = "\n,".join(
  98. list(map(lambda d: d['content'].strip("\r\n").strip("\n").strip("\\s"), list_title)))
  99. return {'keywords': '',
  100. 'parent_chain': list(
  101. map(lambda p: p['content'].strip("\r\n").strip("\n").strip("\\s"), list_title[0]['parent_chain'])),
  102. 'content': ",".join(list(
  103. map(lambda p: p['content'].strip("\r\n").strip("\n").strip("\\s"),
  104. list_title[0]['parent_chain']))) + content}
  105. return None
  106. def parse_group_key(level_list: List[dict]):
  107. """
  108. 将同级别同父级的title生成段落,加上本身的段落数据形成新的数据
  109. :param level_list: title n 级数据
  110. :return: 根据title生成的数据 + 段落数据
  111. """
  112. result = []
  113. group_data = group_by(list(filter(lambda f: f['state'] == 'title' and len(f['parent_chain']) > 0, level_list)),
  114. key=lambda d: ",".join(list(map(lambda p: p['content'], d['parent_chain']))))
  115. result += list(map(lambda group_data_key: titles_to_paragraph(group_data[group_data_key]), group_data))
  116. result += list(map(to_paragraph, list(filter(lambda f: f['state'] == 'block', level_list))))
  117. return result
  118. def to_block_paragraph(tree_data_list: List[dict]):
  119. """
  120. 转换为块段落对象
  121. :param tree_data_list: 树数据
  122. :return: 块段落
  123. """
  124. flat_list = flat(tree_data_list, [], [])
  125. level_group_dict: dict = group_by(flat_list, key=lambda f: f['level'])
  126. return list(map(lambda level: parse_group_key(level_group_dict[level]), level_group_dict))
  127. def parse_title_level(text, content_level_pattern: List, index):
  128. if index >= len(content_level_pattern):
  129. return []
  130. result = parse_level(text, content_level_pattern[index])
  131. if len(result) == 0 and len(content_level_pattern) > index:
  132. return parse_title_level(text, content_level_pattern, index + 1)
  133. return result
  134. def mask_code_blocks(text: str) -> str:
  135. """
  136. 将代码块内容替换为等长空格,防止代码块内的#被识别为标题
  137. """
  138. result = list(text)
  139. for match in re.finditer(r'```[^\n]*\n.*?```', text, re.DOTALL):
  140. start = match.start()
  141. end = match.end()
  142. inner_start = text.index('\n', start) + 1
  143. closing_fence_start = text.rindex('```', start, end)
  144. for i in range(inner_start, closing_fence_start):
  145. if result[i] != '\n':
  146. result[i] = ' '
  147. return ''.join(result)
  148. def parse_level(text, pattern: str):
  149. """
  150. 获取正则匹配到的文本
  151. :param text: 需要匹配的文本
  152. :param pattern: 正则
  153. :return: 符合正则的文本
  154. """
  155. masked_text = mask_code_blocks(text)
  156. level_content_list = list(map(to_tree_obj, [r[0:255] for r in re_findall(pattern, masked_text) if r is not None]))
  157. # 过滤掉空标题或只包含#和空白字符的标题
  158. filtered_list = [item for item in level_content_list
  159. if item['content'].strip(' ') and item['content'].replace('#', '').strip(' ')]
  160. return list(map(filter_special_symbol, filtered_list))
  161. def re_findall(pattern, text):
  162. # 检查 pattern 是否为空或无效
  163. if pattern is None:
  164. return []
  165. # 如果是字符串类型,检查是否为空字符串
  166. if isinstance(pattern, str) and (not pattern or not pattern.strip()):
  167. return []
  168. try:
  169. result = re.findall(pattern, text, flags=0)
  170. except re.error:
  171. return []
  172. return list(filter(lambda r: r is not None and len(r) > 0, reduce(lambda x, y: [*x, *y], list(
  173. map(lambda row: [*(row if isinstance(row, tuple) else [row])], result)),
  174. [])))
  175. def to_flat_obj(parent_chain: List[dict], content: str, state: str):
  176. """
  177. 将树形属性转换为扁平对象
  178. :param parent_chain:
  179. :param content:
  180. :param state:
  181. :return:
  182. """
  183. return {'parent_chain': parent_chain, 'level': len(parent_chain), "content": content, 'state': state}
  184. def flat_map(array: List[List]):
  185. """
  186. 将二位数组转为一维数组
  187. :param array: 二维数组
  188. :return: 一维数组
  189. """
  190. result = []
  191. for e in array:
  192. result += e
  193. return result
  194. def group_by(list_source: List, key):
  195. """
  196. 將數組分組
  197. :param list_source: 需要分組的數組
  198. :param key: 分組函數
  199. :return: key->[]
  200. """
  201. result = {}
  202. for e in list_source:
  203. k = key(e)
  204. array = result.get(k) if k in result else []
  205. array.append(e)
  206. result[k] = array
  207. return result
  208. def result_tree_to_paragraph(result_tree: List[dict], result, parent_chain, with_filter: bool):
  209. """
  210. 转换为分段对象
  211. :param result_tree: 解析文本的树
  212. :param result: 传[] 用于递归
  213. :param parent_chain: 传[] 用户递归存储数据
  214. :param with_filter: 是否过滤block
  215. :return: List[{'problem':'xx','content':'xx'}]
  216. """
  217. for item in result_tree:
  218. if item.get('state') == 'block':
  219. result.append({'title': " ".join(parent_chain),
  220. 'content': filter_special_char(item.get("content")) if with_filter else item.get("content")})
  221. children = item.get("children")
  222. if children is not None and len(children) > 0:
  223. result_tree_to_paragraph(children, result,
  224. [*parent_chain, remove_special_symbol(item.get('content'))], with_filter)
  225. return result
  226. def post_handler_paragraph(content: str, limit: int):
  227. """
  228. 根据文本的最大字符分段
  229. :param content: 需要分段的文本字段
  230. :param limit: 最大分段字符
  231. :return: 分段后数据
  232. """
  233. result = []
  234. temp_char, start = '', 0
  235. while (pos := content.find("\n", start)) != -1:
  236. split, start = content[start:pos + 1], pos + 1
  237. if len(temp_char + split) > limit:
  238. if len(temp_char) > 4096:
  239. pass
  240. result.append(temp_char)
  241. temp_char = ''
  242. temp_char = temp_char + split
  243. temp_char = temp_char + content[start:]
  244. if len(temp_char) > 0:
  245. if len(temp_char) > 4096:
  246. pass
  247. result.append(temp_char)
  248. pattern = "[\\S\\s]{1," + str(limit) + '}'
  249. # 如果\n 单段超过限制,则继续拆分
  250. return reduce(lambda x, y: [*x, *y], map(lambda row: re.findall(pattern, row), result), [])
  251. def smart_split_paragraph(content: str, limit: int):
  252. """
  253. 智能分段:在limit前找到合适的分割点(句号、回车等)
  254. :param content: 需要分段的文本
  255. :param limit: 最大字符限制
  256. :return: 分段后的文本列表
  257. """
  258. if len(content) <= limit:
  259. return [content]
  260. result = []
  261. start = 0
  262. while start < len(content):
  263. end = start + limit
  264. if end >= len(content):
  265. # 剩余文本不超过限制,直接添加
  266. result.append(content[start:])
  267. break
  268. # 在limit范围内寻找最佳分割点
  269. best_split = end
  270. # 优先级:句号 > 感叹号/问号 > 回车
  271. split_chars = [
  272. ('。', 0), ('.', 0), # 中英文句号
  273. ('!', 0), ('!', 0), # 中英文感叹号
  274. ('?', 0), ('?', 0), # 中英文问号
  275. ]
  276. # 从后往前找分割点
  277. for i in range(end - 1, start + limit // 2, -1): # 至少保留一半内容
  278. for char, offset in split_chars:
  279. if content[i] == char:
  280. best_split = i + 1 # 包含分隔符在当前段
  281. break
  282. if best_split != end:
  283. break
  284. # 如果找不到合适分割点,使用原始limit
  285. if best_split == end and end < len(content):
  286. best_split = end
  287. result.append(content[start:best_split])
  288. start = best_split
  289. return [text for text in result if text.strip()]
  290. replace_map = {
  291. re.compile('\n+'): '\n',
  292. re.compile(' +'): ' ',
  293. re.compile('#+'): "",
  294. re.compile("\t+"): ''
  295. }
  296. def filter_special_char(content: str):
  297. """
  298. 过滤特殊字段
  299. :param content: 文本
  300. :return: 过滤后字段
  301. """
  302. items = replace_map.items()
  303. for key, value in items:
  304. content = re.sub(key, value, content)
  305. return content
  306. class SplitModel:
  307. def __init__(self, content_level_pattern, with_filter=True, limit=100000):
  308. self.content_level_pattern = content_level_pattern
  309. self.with_filter = with_filter
  310. if type(limit) is not int:
  311. limit = int(limit)
  312. if limit is None or limit > 100000:
  313. limit = 100000
  314. if limit < 50:
  315. limit = 50
  316. self.limit = limit
  317. def parse_to_tree(self, text: str, index=0):
  318. """
  319. 解析文本
  320. :param text: 需要解析的文本
  321. :param index: 从那个正则开始解析
  322. :return: 解析后的树形结果数据
  323. """
  324. level_content_list = parse_title_level(text, self.content_level_pattern, index)
  325. if len(level_content_list) == 0:
  326. return [to_tree_obj(row, 'block') for row in smart_split_paragraph(text, limit=self.limit)]
  327. if index == 0 and text.lstrip().index(level_content_list[0]["content"].lstrip()) != 0:
  328. level_content_list.insert(0, to_tree_obj(""))
  329. cursor = 0
  330. level_title_content_list = [item for item in level_content_list if item.get('state') == 'title']
  331. for i in range(len(level_title_content_list)):
  332. start_content: str = level_title_content_list[i].get('content')
  333. if cursor < text.index(start_content, cursor):
  334. for row in smart_split_paragraph(text[cursor: text.index(start_content, cursor)], limit=self.limit):
  335. level_content_list.insert(0, to_tree_obj(row, 'block'))
  336. block, cursor = get_level_block(text, level_title_content_list, i, cursor)
  337. if len(block) == 0:
  338. continue
  339. children = self.parse_to_tree(text=block, index=index + 1)
  340. level_title_content_list[i]['children'] = children
  341. first_child_idx_in_block = block.lstrip().index(children[0]["content"].lstrip())
  342. if first_child_idx_in_block != 0:
  343. inner_children = self.parse_to_tree(block[:first_child_idx_in_block], index + 1)
  344. level_title_content_list[i]['children'].extend(inner_children)
  345. return level_content_list
  346. def parse(self, text: str):
  347. """
  348. 解析文本
  349. :param text: 文本数据
  350. :return: 解析后数据 {content:段落数据,keywords:[‘段落关键词’],parent_chain:['段落父级链路']}
  351. """
  352. text = text.replace('\r\n', '\n')
  353. text = text.replace('\r', '\n')
  354. text = text.replace("\0", '')
  355. result_tree = self.parse_to_tree(text, 0)
  356. result = result_tree_to_paragraph(result_tree, [], [], self.with_filter)
  357. for e in result:
  358. if len(e['content']) > 4096:
  359. pass
  360. title_list = list(set([row.get('title') for row in result]))
  361. return [item for item in [self.post_reset_paragraph(row, title_list) for row in result] if
  362. 'content' in item and len(item.get('content').strip()) > 0]
  363. def post_reset_paragraph(self, paragraph: Dict, title_list: List[str]):
  364. result = self.content_is_null(paragraph, title_list)
  365. result = self.filter_title_special_characters(result)
  366. result = self.sub_title(result)
  367. return result
  368. @staticmethod
  369. def sub_title(paragraph: Dict):
  370. if 'title' in paragraph:
  371. title = paragraph.get('title')
  372. if len(title) > 255:
  373. return {**paragraph, 'title': title[0:255], 'content': title[255:len(title)] + paragraph.get('content')}
  374. return paragraph
  375. @staticmethod
  376. def content_is_null(paragraph: Dict, title_list: List[str]):
  377. if 'title' in paragraph:
  378. title = paragraph.get('title')
  379. content = paragraph.get('content')
  380. if (content is None or len(content.strip()) == 0) and (title is not None and len(title) > 0):
  381. find = [t for t in title_list if t.__contains__(title) and t != title]
  382. if find:
  383. return {'title': '', 'content': ''}
  384. return {'title': '', 'content': title}
  385. return paragraph
  386. @staticmethod
  387. def filter_title_special_characters(paragraph: Dict):
  388. title = paragraph.get('title') if 'title' in paragraph else ''
  389. for title_special_characters in title_special_characters_list:
  390. title = title.replace(title_special_characters, '')
  391. return {**paragraph,
  392. 'title': title}
# Substrings removed from paragraph titles by
# SplitModel.filter_title_special_characters (plain str.replace).
# NOTE(review): '\\s' is the two-character text backslash + "s", not regex
# whitespace, because str.replace matches literal substrings — confirm intent.
title_special_characters_list = ['#', '\n', '\r', '\\s']
# Title patterns per document type, outermost level first.
# 'md': one pattern per markdown heading level ("# " through "###### "). Each
# is anchored by the (?<=^) / (?<=\n) lookbehinds at the start of the text or
# right after a newline; no re.MULTILINE flag is used.
# 'default': matches runs of two or more newlines; get_split_model in this file
# does not select it.
default_split_pattern = {
    'md': [re.compile('(?<=^)# .*|(?<=\\n)# .*'),
           re.compile('(?<=\\n)(?<!#)## (?!#).*|(?<=^)(?<!#)## (?!#).*'),
           re.compile("(?<=\\n)(?<!#)### (?!#).*|(?<=^)(?<!#)### (?!#).*"),
           re.compile("(?<=\\n)(?<!#)#### (?!#).*|(?<=^)(?<!#)#### (?!#).*"),
           re.compile("(?<=\\n)(?<!#)##### (?!#).*|(?<=^)(?<!#)##### (?!#).*"),
           re.compile("(?<=\\n)(?<!#)###### (?!#).*|(?<=^)(?<!#)###### (?!#).*")],
    'default': [re.compile("(?<!\n)\n\n+")]
}
  403. def get_split_model(filename: str, with_filter: bool = False, limit: int = 100000):
  404. """
  405. 根据文件名称获取分段模型
  406. :param limit: 每段大小
  407. :param with_filter: 是否过滤特殊字符
  408. :param filename: 文件名称
  409. :return: 分段模型
  410. """
  411. if filename.endswith(".md"):
  412. pattern_list = default_split_pattern.get('md')
  413. return SplitModel(pattern_list, with_filter=with_filter, limit=limit)
  414. pattern_list = default_split_pattern.get('md')
  415. return SplitModel(pattern_list, with_filter=with_filter, limit=limit)
  416. def to_title_tree_string(result_tree: List):
  417. f = flat(result_tree, [], [])
  418. return "\n│".join(list(map(lambda r: title_tostring(r), list(filter(lambda row: row.get('state') == 'title', f)))))
  419. def title_tostring(title_obj):
  420. f = "│ ".join(list(map(lambda index: " ", range(0, len(title_obj.get("parent_chain"))))))
  421. return f + "├───" + title_obj.get('content')