common.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. # coding=utf-8
  2. """
  3. @project: MaxKB
  4. @Author:虎虎
  5. @file: common.py
  6. @date:2025/4/14 18:23
  7. @desc:
  8. """
  9. import hashlib
  10. import io
  11. import mimetypes
  12. import pickle
  13. import random
  14. import re
  15. import shutil
  16. import uuid
  17. from functools import reduce
  18. from typing import List, Dict
  19. from django.core.files.uploadedfile import InMemoryUploadedFile
  20. from django.db.models import QuerySet
  21. from django.utils.translation import gettext as _
  22. from pydub import AudioSegment
  23. from ..database_model_manage.database_model_manage import DatabaseModelManage
  24. from ..exception.app_exception import AppApiException
  def password_encrypt(row_password):
      """
      Hash a password with MD5.

      NOTE(review): this is an unsalted MD5 hex digest, not encryption, and is
      weak for password storage. Changing the algorithm would presumably
      invalidate previously stored digests, so the behavior is kept.

      :param row_password: plain-text password
      :return: 32-character lowercase hexadecimal MD5 digest
      """
      return hashlib.md5(row_password.encode()).hexdigest()
  def group_by(list_source: List, key):
      """
      Group the elements of a list by a key function.

      :param list_source: elements to group
      :param key: function mapping an element to its (hashable) group key
      :return: dict of key -> list of elements; keys appear in first-seen order
               and each list keeps the original element order
      """
      groups = {}
      for element in list_source:
          groups.setdefault(key(element), []).append(element)
      return groups
  # Character pool used by get_random_chars. Presumably to avoid visually
  # confusable characters, it leaves out uppercase I and O, lowercase i, l and o,
  # and the digits 0, 1 and 7.
  SAFE_CHAR_SET = [
      *(c for c in map(chr, range(ord('A'), ord('Z') + 1)) if c not in ('I', 'O')),  # A-H, J-N, P-Z
      *(c for c in map(chr, range(ord('a'), ord('z') + 1)) if c not in ('i', 'l', 'o')),  # a-h, j-k, m-n, p-z
      *(d for d in map(str, range(10)) if d not in ('0', '1', '7')),  # 2-6, 8-9
  ]
  def get_random_chars(number=4):
      """
      Return `number` random characters drawn, with replacement, from SAFE_CHAR_SET.

      Uses the non-cryptographic `random` module, so the result should not be
      relied on as a security token.

      :param number: how many characters to generate
      :return: the random string, or "" when `number` is zero or negative
      """
      if number <= 0:
          return ""
      picked = random.choices(SAFE_CHAR_SET, k=number)
      return ''.join(picked)
  def encryption(message: str):
      """
      Mask a sensitive value for display, keeping only a short prefix and suffix.

      The prefix keeps 2/5 of the characters (rounded down, at least 1, at most 8).
      The suffix keeps 1/5 (rounded down, at most 4). Everything in between is
      replaced by a fixed run of 15 asterisks, whatever the message length.
      Example: "1234567890" -> "1234***************90".

      :param message: value to mask
      :return: the masked string; an empty (or None) message yields only the asterisks
      """
      mask = "***************"
      if not message:  # empty string / None
          return mask
      max_pre_len = 8
      max_post_len = 4
      message_len = len(message)
      pre_len = min(max(message_len * 2 // 5, 1), max_pre_len)
      # Bug fix: the suffix length used to be clamped by testing pre_len instead of
      # post_len. That exposed 4 trailing characters for every message of 10+ chars
      # (e.g. 8 of 10 characters were revealed).
      post_len = min(message_len // 5, max_post_len)
      return message[:pre_len] + mask + message[message_len - post_len:]
  def _remove_empty_lines(text):
      """
      Drop blank lines from text, then convert the remainder to plain text.

      Used for text-to-speech input (see the error messages below). Lines
      containing only whitespace count as blank.

      :param text: text to clean
      :return: markdown_to_plain_text applied to the non-blank lines
      :raises AppApiException: (code 500) if text is not a str or is empty
      """
      if not isinstance(text, str):
          raise AppApiException(500, _('Text-to-speech node, the text content must be of string type'))
      if not text:
          raise AppApiException(500, _('Text-to-speech node, the text content cannot be empty'))
      non_blank = '\n'.join(filter(str.strip, text.split('\n')))
      return markdown_to_plain_text(non_blank)
  def markdown_to_plain_text(md: str) -> str:
      """
      Convert Markdown/HTML-formatted text to plain text.

      Some elements must disappear together with their content: form renders,
      fenced code blocks and audio/video elements. These are removed first.
      Otherwise the later inline-markup and generic HTML-tag passes would strip
      only their delimiters and let the content leak into the result.

      :param md: text containing Markdown and/or HTML markup
      :return: plain text with every whitespace run collapsed to one space and
               leading/trailing whitespace removed
      """
      # Remove rendered forms together with their payload. (Previously this
      # substitution's result was discarded, and it ran after tags were stripped.)
      text = re.sub(r'<form_rander>[\s\S]*?</form_rander>', '', md)
      # Remove fenced code blocks before the inline-code pass consumes their backticks.
      text = re.sub(r'```[\s\S]*?```', '', text)
      # Remove audio/video elements including their inner (fallback) content. This
      # must run before the generic tag pass, which would strip the tags and leave
      # the content behind.
      text = re.sub(r'<(audio|video)\b[^>]*>[\s\S]*?</\1\s*>', '', text, flags=re.IGNORECASE)
      # Remove images ![alt](url).
      text = re.sub(r'!\[.*?\]\(.*?\)', '', text)
      # Replace links [text](url) with their text.
      text = re.sub(r'\[([^\]]+)\]\([^)]+\)', r'\1', text)
      # Remove heading markers (# ... ######) at line starts.
      text = re.sub(r'^#{1,6}\s+', '', text, flags=re.MULTILINE)
      # Unwrap bold **text** / __text__, then italic *text* / _text_.
      text = re.sub(r'\*\*(.*?)\*\*', r'\1', text)
      text = re.sub(r'__(.*?)__', r'\1', text)
      text = re.sub(r'\*(.*?)\*', r'\1', text)
      text = re.sub(r'_(.*?)_', r'\1', text)
      # Unwrap inline code.
      text = re.sub(r'`(.*?)`', r'\1', text)
      # Strip every remaining HTML tag, including <img> and unpaired media tags.
      text = re.sub(r'<[^>]+>', '', text)
      # Collapse all whitespace (newlines, tabs, ...) to single spaces and trim.
      text = re.sub(r'\s+', ' ', text)
      return text.strip()
  def get_file_content(path):
      """
      Read a whole text file decoded as UTF-8.

      :param path: file path
      :return: the file's full content as str
      """
      with open(path, encoding='utf-8') as handle:
          return handle.read()
  def sub_array(array: List, item_num=10):
      """
      Split elements into consecutive chunks of `item_num` items.

      The last chunk holds the remainder and may be shorter; no empty chunk is
      produced. An `item_num` below 1 puts each element in its own chunk.

      :param array: elements to split (any iterable)
      :param item_num: chunk size
      :return: list of chunks (lists)
      """
      chunks, bucket = [], []
      for element in array:
          bucket.append(element)
          if len(bucket) >= item_num:
              chunks.append(bucket)
              bucket = []
      return chunks + [bucket] if bucket else chunks
  def bytes_to_uploaded_file(file_bytes, file_name="file.txt"):
      """
      Wrap raw bytes in a Django InMemoryUploadedFile.

      The content type is guessed from file_name's extension. It falls back to
      "application/octet-stream" when it cannot be determined.

      :param file_bytes: file content
      :param file_name: name given to the uploaded file (also used for type guessing)
      :return: InMemoryUploadedFile backed by an in-memory byte stream
      """
      # Named _encoding (not _) so the module-level gettext alias is not shadowed.
      guessed_type, _encoding = mimetypes.guess_type(file_name)
      return InMemoryUploadedFile(
          file=io.BytesIO(file_bytes),
          field_name=None,
          name=file_name,
          content_type="application/octet-stream" if guessed_type is None else guessed_type,
          size=len(file_bytes),
          charset=None,
      )
  def any_to_amr(any_path, amr_path):
      """
      Convert an audio file of any format to AMR.

      :param any_path: source audio file path
      :param amr_path: destination path
      :return: None when the source already ends in .amr (it is copied as-is);
               otherwise the converted audio's duration in milliseconds
      :raises NotImplementedError: for SILK sources (.sil / .silk / .slk)
      """
      if any_path.endswith(".amr"):
          shutil.copy2(any_path, amr_path)
          return None
      if any_path.endswith((".sil", ".silk", ".slk")):
          raise NotImplementedError("Not support file type: {}".format(any_path))
      # The output is resampled to 8000 Hz (the original notes "only support 8000").
      audio = AudioSegment.from_file(any_path).set_frame_rate(8000)
      audio.export(amr_path, format="amr")
      return audio.duration_seconds * 1000
  def any_to_mp3(any_path, mp3_path):
      """
      Convert an audio file of any format to MP3, resampled to 16000 Hz.

      :param any_path: source audio file path. A .mp3 source is copied as-is;
                       a SILK source (.sil / .silk / .slk) is decoded to WAV first.
      :param mp3_path: destination path
      """
      if any_path.endswith(".mp3"):
          shutil.copy2(any_path, mp3_path)
          return
      source_format = None  # let pydub/ffmpeg detect the format
      if any_path.endswith((".sil", ".silk", ".slk")):
          # Bug fix: the SILK source used to be decoded onto itself, destroying the
          # input, and the never-written mp3_path was then loaded. Decode into
          # mp3_path as a scratch WAV file instead. It is overwritten by the MP3
          # export below once the audio has been loaded.
          sil_to_wav(any_path, mp3_path)
          any_path = mp3_path
          source_format = "wav"
      audio = AudioSegment.from_file(any_path, format=source_format)
      audio = audio.set_frame_rate(16000)
      audio.export(mp3_path, format="mp3")
  def sil_to_wav(silk_path, wav_path, rate: int = 24000):
      """
      Decode a SILK audio file and write the result as WAV.

      :param silk_path: source SILK file path
      :param wav_path: destination WAV file path
      :param rate: sample rate passed to the pysilk decoder
      :raises AppApiException: (code 500) if the optional pysilk package is missing
      """
      try:
          import pysilk
      except ImportError as e:
          # AppApiException is constructed as (code, message) everywhere else in this
          # module; the message-only call did not match that signature.
          raise AppApiException(500, "import pysilk failed, wechaty voice message will not be supported.") from e
      wav_data = pysilk.decode_file(silk_path, to_wav=True, sample_rate=rate)
      with open(wav_path, "wb") as f:
          f.write(wav_data)
  def _export_audio_bytes(segment, audio_format):
      """Encode an AudioSegment as `audio_format` and return the bytes wrapped in a BytesIO."""
      # export() without a destination returns an open file handle. Close it
      # deterministically instead of leaving it to garbage collection.
      with segment.export(format=audio_format) as exported:
          return io.BytesIO(exported.read())


  def split_and_transcribe(file_path, model, max_segment_length_ms=59000, audio_format="mp3"):
      """
      Transcribe an audio file with `model`, splitting long audio into segments.

      :param file_path: path of the audio file to transcribe
      :param model: object exposing speech_to_text(file_like)
      :param max_segment_length_ms: maximum segment length in milliseconds
      :param audio_format: format used both to read the file and to encode segments
      :return: for audio no longer than one segment, the model's result as-is;
               otherwise the str results of all segments joined with single
               spaces (non-str segment results are skipped)
      """
      audio_data = AudioSegment.from_file(file_path, format=audio_format)
      audio_length_ms = len(audio_data)
      if audio_length_ms <= max_segment_length_ms:
          return model.speech_to_text(_export_audio_bytes(audio_data, audio_format))
      full_text = []
      for start_ms in range(0, audio_length_ms, max_segment_length_ms):
          end_ms = min(audio_length_ms, start_ms + max_segment_length_ms)
          text = model.speech_to_text(_export_audio_bytes(audio_data[start_ms:end_ms], audio_format))
          if isinstance(text, str):
              full_text.append(text)
      return ' '.join(full_text)
  def query_params_to_single_dict(query_params: Dict):
      """
      Build a plain dict from query parameters, dropping empty entries.

      Entries whose value is None or has length 0 (e.g. "" or []) are omitted.

      :param query_params: mapping of parameter name -> value
      :return: new dict with the remaining key/value pairs, in the original order
      """
      # One comprehension replaces the filter + reduce chain. That chain rebuilt
      # the accumulated dict once per key, which is quadratic in the number of keys.
      return {key: value for key, value in query_params.items()
              if value is not None and len(value) > 0}
  def valid_license(model=None, count=None, message=None):
      """
      Decorator factory that limits a function when no valid license is present.

      Before the wrapped function runs, the license check is evaluated. If the
      license is not valid and `model` already has at least `count` rows, an
      AppApiException(400) is raised instead of calling the function.

      :param model: Django model whose rows are counted
      :param count: number of rows allowed without a valid license
      :param message: custom error message; defaults to a translated limit message
      :return: decorator that preserves the wrapped function's metadata
      """
      import functools

      def inner(func):
          @functools.wraps(func)
          def run(*args, **kwargs):
              license_checker = DatabaseModelManage.get_model('license_is_valid')
              # Bug fix: the checker used to be called (twice) before being tested
              # for None, so a missing checker raised TypeError instead of meaning
              # "not licensed".
              is_license_valid = license_checker() if license_checker is not None else False
              # Count rows only when the license is not valid (avoids the query otherwise).
              if not is_license_valid and QuerySet(model).count() >= count:
                  error_message = message or _(
                      'Limit {count} exceeded, please contact us (https://fit2cloud.com/).').format(
                      count=count)
                  raise AppApiException(400, error_message)
              return func(*args, **kwargs)

          return run

      return inner
  def post(post_function):
      """
      Decorator factory that pipes a function's result into `post_function`.

      The wrapped function must return an iterable (typically a tuple). Its items
      are passed to `post_function` as positional arguments, and that call's
      result is returned.

      :param post_function: callable applied to the unpacked result
      :return: decorator that preserves the wrapped function's metadata
      """
      import functools

      def inner(func):
          @functools.wraps(func)
          def run(*args, **kwargs):
              result = func(*args, **kwargs)
              return post_function(*result)

          return run

      return inner
  def parse_md_image(content: str):
      """
      Find all Markdown image references (![alt](url)) in content.

      :param content: Markdown text
      :return: list of the full matched image snippets, in order of appearance
      """
      # Raw string: the previous plain literal relied on invalid escape sequences
      # such as "\[", which are a SyntaxWarning since Python 3.12.
      return [match.group() for match in re.finditer(r"!\[.*?\]\(.*?\)", content)]
  def bulk_create_in_batches(model, data, batch_size=1000):
      """
      Insert `data` through `model.objects.bulk_create` in slices of at most `batch_size`.

      :param model: Django model class
      :param data: sliceable sequence of unsaved model instances
      :param batch_size: maximum number of objects per bulk_create call
      """
      total = len(data)
      if total == 0:
          return
      for offset in range(0, total, batch_size):
          model.objects.bulk_create(data[offset:offset + batch_size])
  236. def get_sha256_hash(_v: str | bytes):
  237. sha256 = hashlib.sha256()
  238. if isinstance(_v, str):
  239. sha256.update(_v.encode())
  240. else:
  241. sha256.update(_v)
  242. return sha256.hexdigest()
  # Globals that RestrictedUnpickler may resolve while unpickling, given as
  # (module, name) pairs; lookups of anything else are rejected.
  ALLOWED_CLASSES = {
      ('builtins', 'dict'),
      ('uuid', 'UUID'),
      # Project classes (presumably serialized workflow/tool instances).
      ('application.serializers.application', 'MKInstance'),
      ('knowledge.serializers.knowledge_workflow', 'KBWFInstance'),
      ('tools.serializers.tool', 'ToolInstance'),
  }
  class RestrictedUnpickler(pickle.Unpickler):
      """
      Unpickler that only resolves globals listed in ALLOWED_CLASSES.

      Any other (module, name) lookup raises pickle.UnpicklingError. This
      restricts which classes and callables a pickle payload can reference.
      """

      def find_class(self, module, name):
          if (module, name) not in ALLOWED_CLASSES:
              raise pickle.UnpicklingError("global '%s.%s' is forbidden" % (module, name))
          return super().find_class(module, name)
  def restricted_loads(s):
      """
      Deserialize pickle bytes like pickle.loads(), but through RestrictedUnpickler.

      :param s: pickled data (bytes-like)
      :return: the reconstructed object
      :raises pickle.UnpicklingError: if the payload references a disallowed global
      """
      unpickler = RestrictedUnpickler(io.BytesIO(s))
      return unpickler.load()
  def flat_map(array: List[List]):
      """
      Flatten a two-dimensional list into a one-dimensional list.

      :param array: list of iterables (typically lists)
      :return: new list containing every inner element, in order
      """
      return [item for row in array for item in row]
  def parse_image(content: str):
      """
      Find Markdown image references whose target lies under ./oss/image/ or ./oss/file/.

      :param content: Markdown text
      :return: list of the full matched snippets, e.g. "![a](./oss/image/1.png)",
               in order of appearance
      """
      # Raw string: the previous plain literal relied on invalid escape sequences
      # such as "\[" and "\/", which are a SyntaxWarning since Python 3.12.
      matches = re.finditer(r"!\[.*?\]\(\./oss/(image|file)/.*?\)", content)
      return [match.group() for match in matches]
  def generate_uuid(tag: str):
      """
      Derive a deterministic name-based (version 5) UUID string from a tag.

      The DNS namespace is used as the fixed namespace, so the same tag always
      yields the same UUID.

      :param tag: name to hash
      :return: UUID in canonical hyphenated string form
      """
      derived = uuid.uuid5(uuid.NAMESPACE_DNS, tag)
      return str(derived)
  def filter_workspace(query_list):
      """
      Drop entries named "workspace_id" from a list of query objects.

      :param query_list: iterable of objects exposing a `name` attribute
      :return: new list without the elements whose name is "workspace_id"
      """
      return list(filter(lambda query: query.name != "workspace_id", query_list))
  def filter_special_character(_str):
      """
      Remove blocked special-character sequences from a string.

      :param _str: text to clean
      :return: the text with every blocked sequence removed
      """
      # The only blocked sequence is the six-character escape TEXT "\u0000"
      # (a backslash, "u", then four zeros). The NUL character itself is left
      # untouched. NOTE(review): confirm whether real NUL characters should be
      # stripped as well.
      blocked_sequences = ("\\u0000",)
      return reduce(lambda text, seq: text.replace(seq, ''), blocked_sequences, _str)
  def is_valid_uuid(uuid_string):
      """
      Check whether a value is a UUID string in canonical form.

      Only the lowercase, hyphenated 36-character form is accepted, e.g.
      "886313e1-3b8a-5372-9b90-0c9aee199e5d". Some other spellings are also
      parsed by uuid.UUID: uppercase, braces, a "urn:uuid:" prefix, or no
      hyphens. These are rejected because they do not round-trip through str().

      :param uuid_string: value to check
      :return: True if it is a canonical UUID string, False otherwise
      """
      # Non-string input (e.g. None or an int) used to escape as TypeError or
      # AttributeError from uuid.UUID; treat it as simply invalid.
      if not isinstance(uuid_string, str):
          return False
      try:
          return str(uuid.UUID(uuid_string)) == uuid_string
      except ValueError:
          return False