zip_split_handle.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: zip_split_handle.py
    @date:2024/3/27 18:19
    @desc: Split handle for zip archives; delegates each archive entry to the other split handles.
"""
import io
import logging
import os
import re
import zipfile
from typing import List
from urllib.parse import urljoin

import uuid_utils.compat as uuid
from charset_normalizer import detect
from django.utils.translation import gettext_lazy as _

from common.handle.base_split_handle import BaseSplitHandle
from common.handle.impl.text.csv_split_handle import CsvSplitHandle
from common.handle.impl.text.doc_split_handle import DocSplitHandle
from common.handle.impl.text.html_split_handle import HTMLSplitHandle
from common.handle.impl.text.pdf_split_handle import PdfSplitHandle
from common.handle.impl.text.text_split_handle import TextSplitHandle
from common.handle.impl.text.xls_split_handle import XlsSplitHandle
from common.handle.impl.text.xlsx_split_handle import XlsxSplitHandle
from common.utils.common import parse_md_image
from knowledge.models import File
  28. class FileBufferHandle:
  29. buffer = None
  30. def get_buffer(self, file):
  31. if self.buffer is None:
  32. self.buffer = file.read()
  33. return self.buffer
  34. default_split_handle = TextSplitHandle()
  35. split_handles = [
  36. HTMLSplitHandle(),
  37. DocSplitHandle(),
  38. PdfSplitHandle(),
  39. XlsxSplitHandle(),
  40. XlsSplitHandle(),
  41. CsvSplitHandle(),
  42. default_split_handle
  43. ]
  44. def file_to_paragraph(file, pattern_list: List, with_filter: bool, limit: int, save_inner_image):
  45. get_buffer = FileBufferHandle().get_buffer
  46. for split_handle in split_handles:
  47. if split_handle.support(file, get_buffer):
  48. return split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_inner_image)
  49. raise Exception(_('Unsupported file format'))
  50. def is_valid_uuid(uuid_str: str):
  51. try:
  52. uuid.UUID(uuid_str)
  53. except ValueError:
  54. return False
  55. return True
  56. def get_image_list(result_list: list, zip_files: List[str]):
  57. image_file_list = []
  58. for result in result_list:
  59. for p in result.get('content', []):
  60. content: str = p.get('content', '')
  61. image_list = parse_md_image(content)
  62. for image in image_list:
  63. search = re.search("\(.*\)", image)
  64. if search:
  65. new_image_id = str(uuid.uuid7())
  66. source_image_path = search.group().replace('(', '').replace(')', '')
  67. source_image_path = source_image_path.strip().split(" ")[0]
  68. image_path = urljoin(result.get('name'), '.' + source_image_path if source_image_path.startswith(
  69. '/') else source_image_path)
  70. if not zip_files.__contains__(image_path):
  71. continue
  72. if image_path.startswith('oss/file/') or image_path.startswith('oss/image/'):
  73. image_id = image_path.replace('oss/file/', '').replace('oss/image/', '')
  74. if is_valid_uuid(image_id):
  75. image_file_list.append({'source_file': image_path,
  76. 'image_id': image_id})
  77. else:
  78. image_file_list.append({'source_file': image_path,
  79. 'image_id': new_image_id})
  80. content = content.replace(source_image_path, f'./oss/file/{new_image_id}')
  81. p['content'] = content
  82. else:
  83. image_file_list.append({'source_file': image_path,
  84. 'image_id': new_image_id})
  85. content = content.replace(source_image_path, f'./oss/file/{new_image_id}')
  86. p['content'] = content
  87. return image_file_list
  88. def get_image_list_by_content(name: str, content: str, zip_files: List[str]):
  89. image_file_list = []
  90. image_list = parse_md_image(content)
  91. for image in image_list:
  92. search = re.search("\(.*\)", image)
  93. if search:
  94. new_image_id = str(uuid.uuid7())
  95. source_image_path = search.group().replace('(', '').replace(')', '')
  96. source_image_path = source_image_path.strip().split(" ")[0]
  97. image_path = urljoin(name, '.' + source_image_path if source_image_path.startswith(
  98. '/') else source_image_path)
  99. if not zip_files.__contains__(image_path):
  100. continue
  101. if image_path.startswith('oss/file/') or image_path.startswith('oss/image/'):
  102. image_id = image_path.replace('oss/file/', '').replace('oss/image/', '')
  103. if is_valid_uuid(image_id):
  104. image_file_list.append({'source_file': image_path,
  105. 'image_id': image_id})
  106. else:
  107. image_file_list.append({'source_file': image_path,
  108. 'image_id': new_image_id})
  109. content = content.replace(source_image_path, f'./oss/file/{new_image_id}')
  110. else:
  111. image_file_list.append({'source_file': image_path,
  112. 'image_id': new_image_id})
  113. content = content.replace(source_image_path, f'./oss/file/{new_image_id}')
  114. return image_file_list, content
  115. def get_file_name(file_name):
  116. try:
  117. file_name_code = file_name.encode('cp437')
  118. charset = detect(file_name_code)['encoding']
  119. return file_name_code.decode(charset)
  120. except Exception as e:
  121. return file_name
  122. def filter_image_file(result_list: list, image_list):
  123. image_source_file_list = [image.get('source_file') for image in image_list]
  124. return [r for r in result_list if not image_source_file_list.__contains__(r.get('name', ''))]
  125. class ZipSplitHandle(BaseSplitHandle):
  126. def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
  127. if type(limit) is str:
  128. limit = int(limit)
  129. if type(with_filter) is str:
  130. with_filter = with_filter.lower() == 'true'
  131. buffer = get_buffer(file)
  132. bytes_io = io.BytesIO(buffer)
  133. result = []
  134. # 打开zip文件
  135. with zipfile.ZipFile(bytes_io, 'r') as zip_ref:
  136. # 获取压缩包中的文件名列表
  137. files = zip_ref.namelist()
  138. # 读取压缩包中的文件内容
  139. for file in files:
  140. if file.endswith('/') or file.startswith('__MACOSX'):
  141. continue
  142. with zip_ref.open(file) as f:
  143. # 对文件内容进行处理
  144. try:
  145. # 处理一下文件名
  146. f.name = get_file_name(f.name)
  147. value = file_to_paragraph(f, pattern_list, with_filter, limit, save_image)
  148. if isinstance(value, list):
  149. result = [*result, *value]
  150. else:
  151. result.append(value)
  152. except Exception:
  153. pass
  154. image_list = get_image_list(result, files)
  155. result = filter_image_file(result, image_list)
  156. image_mode_list = []
  157. for image in image_list:
  158. with zip_ref.open(image.get('source_file')) as f:
  159. i = File(
  160. id=image.get('image_id'),
  161. file_name=os.path.basename(image.get('source_file')),
  162. meta={'debug': False, 'content': f.read()} # 这里的content是二进制数据
  163. )
  164. image_mode_list.append(i)
  165. save_image(image_mode_list)
  166. return result
  167. def support(self, file, get_buffer):
  168. file_name: str = file.name.lower()
  169. if file_name.endswith(".zip") or file_name.endswith(".ZIP"):
  170. return True
  171. return False
  172. def get_content(self, file, save_image):
  173. """
  174. 从 zip 中提取并返回拼接的 md 文本,同时收集并保存内嵌图片(通过 save_image 回调)。
  175. 使用 posixpath 来正确处理 zip 内部的路径拼接与规范化。
  176. """
  177. buffer = file.read() if hasattr(file, 'read') else None
  178. bytes_io = io.BytesIO(buffer) if buffer is not None else io.BytesIO(file)
  179. image_list = []
  180. content_parts = []
  181. with zipfile.ZipFile(bytes_io, 'r') as zip_ref:
  182. files = zip_ref.namelist()
  183. file_content_list = []
  184. for inner_name in files:
  185. if inner_name.endswith('/') or inner_name.startswith('__MACOSX'):
  186. continue
  187. with zip_ref.open(inner_name) as zf:
  188. try:
  189. real_name = get_file_name(zf.name)
  190. except Exception:
  191. real_name = zf.name
  192. # 为 split_handle 提供可重复读取的 file-like 对象
  193. zf.name = real_name
  194. get_buffer = FileBufferHandle().get_buffer
  195. for split_handle in split_handles:
  196. if split_handle.support(zf, get_buffer):
  197. row = get_buffer(zf)
  198. md_text = split_handle.get_content(io.BytesIO(row), save_image)
  199. file_content_list.append({'content': md_text, 'name': real_name})
  200. break
  201. for file_content in file_content_list:
  202. _image_list, content = get_image_list_by_content(file_content.get('name'), file_content.get("content"),
  203. files)
  204. content_parts.append(content)
  205. for image in _image_list:
  206. image_list.append(image)
  207. # 将收集到的图片通过回调保存(一次性)
  208. if image_list:
  209. image_mode_list = []
  210. for image in image_list:
  211. with zip_ref.open(image.get('source_file')) as f:
  212. i = File(
  213. id=image.get('image_id'),
  214. file_name=os.path.basename(image.get('source_file')),
  215. meta={'debug': False, 'content': f.read()} # 这里的content是二进制数据
  216. )
  217. image_mode_list.append(i)
  218. save_image(image_mode_list)
  219. return '\n\n'.join(content_parts)