doc_split_handle.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. # coding=utf-8
  2. """
  3. @project: maxkb
  4. @Author:虎
  5. @file: doc_split_handle.py
  6. @date:2024/3/27 18:19
  7. @desc: Converts Word (.docx) documents to Markdown text and splits the result with SplitModel
  8. """
  9. import io
  10. import os
  11. import re
  12. import traceback
  13. from functools import reduce
  14. from typing import List
  15. import uuid_utils.compat as uuid
  16. from docx import Document, ImagePart
  17. from docx.oxml import ns
  18. from docx.table import Table
  19. from docx.text.paragraph import Paragraph
  20. from common.handle.base_split_handle import BaseSplitHandle
  21. from common.utils.logger import maxkb_logger
  22. from common.utils.split_model import SplitModel
  23. from knowledge.models import File
  24. default_pattern_list = [re.compile('(?<=^)# .*|(?<=\\n)# .*'),
  25. re.compile('(?<=\\n)(?<!#)## (?!#).*|(?<=^)(?<!#)## (?!#).*'),
  26. re.compile("(?<=\\n)(?<!#)### (?!#).*|(?<=^)(?<!#)### (?!#).*"),
  27. re.compile("(?<=\\n)(?<!#)#### (?!#).*|(?<=^)(?<!#)#### (?!#).*"),
  28. re.compile("(?<=\\n)(?<!#)##### (?!#).*|(?<=^)(?<!#)##### (?!#).*"),
  29. re.compile("(?<=\\n)(?<!#)###### (?!#).*|(?<=^)(?<!#)###### (?!#).*")]
  30. old_docx_nsmap = {'v': 'urn:schemas-microsoft-com:vml'}
  31. combine_nsmap = {**ns.nsmap, **old_docx_nsmap}
  32. def image_to_mode(image, doc: Document, images_list, get_image_id):
  33. image_ids = image['get_image_id_handle'](image.get('image'))
  34. for img_id in image_ids: # 获取图片id
  35. part = doc.part.related_parts[img_id] # 根据图片id获取对应的图片
  36. if isinstance(part, ImagePart):
  37. image_uuid = get_image_id(img_id)
  38. if len([i for i in images_list if i.id == image_uuid]) == 0:
  39. image = File(id=image_uuid, file_name=part.filename, meta={'debug': False, 'content': part.blob})
  40. images_list.append(image)
  41. return f'![{part.filename.replace("[", "").replace("]", "")}](./oss/file/{image_uuid})'
  42. return None
  43. return None
  44. def get_paragraph_element_images(paragraph_element, doc: Document, images_list, get_image_id):
  45. images_xpath_list = [(".//pic:pic", lambda img: img.xpath('.//a:blip/@r:embed')),
  46. (".//w:pict", lambda img: img.xpath('.//v:imagedata/@r:id', namespaces=combine_nsmap))]
  47. images = []
  48. for images_xpath, get_image_id_handle in images_xpath_list:
  49. try:
  50. _images = paragraph_element.xpath(images_xpath)
  51. if _images is not None and len(_images) > 0:
  52. for image in _images:
  53. images.append({'image': image, 'get_image_id_handle': get_image_id_handle})
  54. except Exception as e:
  55. pass
  56. return images
  57. def images_to_string(images, doc: Document, images_list, get_image_id):
  58. return "".join(
  59. [item for item in [image_to_mode(image, doc, images_list, get_image_id) for image in images] if
  60. item is not None])
  61. def get_paragraph_element_txt(paragraph_element, doc: Document, images_list, get_image_id):
  62. try:
  63. images = get_paragraph_element_images(paragraph_element, doc, images_list, get_image_id)
  64. if len(images) > 0:
  65. return images_to_string(images, doc, images_list, get_image_id)
  66. elif paragraph_element.text is not None:
  67. return paragraph_element.text
  68. return ""
  69. except Exception as e:
  70. maxkb_logger.error(f'Error getting paragraph element text: {e}')
  71. return ""
  72. def get_paragraph_txt(paragraph: Paragraph, doc: Document, images_list, get_image_id):
  73. try:
  74. return "".join([get_paragraph_element_txt(e, doc, images_list, get_image_id) for e in paragraph._element])
  75. except Exception as e:
  76. return ""
  77. def get_cell_text(cell, doc: Document, images_list, get_image_id):
  78. try:
  79. return "".join(
  80. [get_paragraph_txt(paragraph, doc, images_list, get_image_id) for paragraph in cell.paragraphs]).replace(
  81. "\n", '</br>')
  82. except Exception as e:
  83. return ""
  84. def get_image_id_func():
  85. image_map = {}
  86. def get_image_id(image_id):
  87. _v = image_map.get(image_id)
  88. if _v is None:
  89. image_map[image_id] = uuid.uuid7()
  90. return image_map.get(image_id)
  91. return _v
  92. return get_image_id
  93. title_font_list = [
  94. [36, 100],
  95. [26, 36],
  96. [24, 26],
  97. [22, 24],
  98. [18, 22],
  99. [16, 18]
  100. ]
  101. def get_title_level(paragraph: Paragraph):
  102. try:
  103. if paragraph.style is not None:
  104. psn = paragraph.style.name
  105. if psn.startswith('Heading') or psn.startswith('TOC 标题') or psn.startswith('标题'):
  106. return int(psn.replace("Heading ", '').replace('TOC 标题', '').replace('标题',
  107. ''))
  108. if len(paragraph.runs) >= 1:
  109. font_size = paragraph.runs[0].font.size
  110. pt = font_size.pt
  111. if pt >= 16:
  112. for _value, index in zip(title_font_list, range(len(title_font_list))):
  113. if pt >= _value[0] and pt < _value[1] and any([run.font.bold for run in paragraph.runs]):
  114. return index + 1
  115. except Exception as e:
  116. pass
  117. return None
  118. class DocSplitHandle(BaseSplitHandle):
  119. @staticmethod
  120. def paragraph_to_md(paragraph: Paragraph, doc: Document, images_list, get_image_id):
  121. try:
  122. title_level = get_title_level(paragraph)
  123. if title_level is not None:
  124. title = "".join(["#" for i in range(title_level)]) + " " + paragraph.text
  125. images = reduce(lambda x, y: [*x, *y],
  126. [get_paragraph_element_images(e, doc, images_list, get_image_id) for e in
  127. paragraph._element],
  128. [])
  129. if len(images) > 0:
  130. return title + '\n' + images_to_string(images, doc, images_list, get_image_id) if len(
  131. paragraph.text) > 0 else images_to_string(images, doc, images_list, get_image_id)
  132. return title
  133. except Exception as e:
  134. maxkb_logger.error(f"Error processing DOC file: {e}, {traceback.format_exc()}")
  135. return paragraph.text
  136. return get_paragraph_txt(paragraph, doc, images_list, get_image_id)
  137. @staticmethod
  138. def table_to_md(table, doc: Document, images_list, get_image_id):
  139. rows = table.rows
  140. # 创建 Markdown 格式的表格
  141. md_table = '| ' + ' | '.join(
  142. [get_cell_text(cell, doc, images_list, get_image_id) for cell in rows[0].cells]) + ' |\n'
  143. md_table += '| ' + ' | '.join(['---' for i in range(len(rows[0].cells))]) + ' |\n'
  144. for row in rows[1:]:
  145. md_table += '| ' + ' | '.join(
  146. [get_cell_text(cell, doc, images_list, get_image_id) for cell in row.cells]) + ' |\n'
  147. return md_table
  148. def to_md(self, doc, images_list, get_image_id):
  149. elements = []
  150. for element in doc.element.body:
  151. tag = str(element.tag)
  152. if tag.endswith('tbl'):
  153. # 处理表格
  154. table = Table(element, doc)
  155. elements.append(table)
  156. elif tag.endswith('p'):
  157. # 处理段落
  158. paragraph = Paragraph(element, doc)
  159. elements.append(paragraph)
  160. return "\n".join(
  161. [self.paragraph_to_md(element, doc, images_list, get_image_id) if isinstance(element,
  162. Paragraph) else self.table_to_md(
  163. element,
  164. doc,
  165. images_list, get_image_id)
  166. for element
  167. in elements])
  168. def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
  169. file_name = os.path.basename(file.name)
  170. try:
  171. if type(limit) is str:
  172. limit = int(limit)
  173. if type(with_filter) is str:
  174. with_filter = with_filter.lower() == 'true'
  175. image_list = []
  176. buffer = get_buffer(file)
  177. doc = Document(io.BytesIO(buffer))
  178. content = self.to_md(doc, image_list, get_image_id_func())
  179. if len(image_list) > 0:
  180. save_image(image_list)
  181. if pattern_list is not None and len(pattern_list) > 0:
  182. split_model = SplitModel(pattern_list, with_filter, limit)
  183. else:
  184. split_model = SplitModel(default_pattern_list, with_filter=with_filter, limit=limit)
  185. except BaseException as e:
  186. maxkb_logger.error(f"Error processing XLSX file {file.name}: {e}, {traceback.format_exc()}")
  187. return {
  188. 'name': file_name,
  189. 'content': []
  190. }
  191. return {
  192. 'name': file_name,
  193. 'content': split_model.parse(content)
  194. }
  195. def support(self, file, get_buffer):
  196. file_name: str = file.name.lower()
  197. if file_name.endswith(".docx") or file_name.endswith(".doc") or file_name.endswith(
  198. ".DOC") or file_name.endswith(".DOCX"):
  199. return True
  200. return False
  201. def get_content(self, file, save_image):
  202. try:
  203. image_list = []
  204. buffer = file.read()
  205. doc = Document(io.BytesIO(buffer))
  206. content = self.to_md(doc, image_list, get_image_id_func())
  207. if len(image_list) > 0:
  208. save_image(image_list)
  209. return content
  210. except BaseException as e:
  211. traceback.print_exception(e)
  212. return f'{e}'