tool_utils.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. """
  4. 工具函数模块
  5. 提供常用的工具函数和辅助类
  6. """
  7. import json
  8. import datetime
  9. from typing import Any, Dict, List, Optional, Union
  10. import hashlib
  11. import uuid
  12. import re
  class DateTimeEncoder(json.JSONEncoder):
      """JSON encoder that additionally understands dates, times and plain objects.

      ``datetime.datetime`` / ``datetime.date`` / ``datetime.time`` values are
      emitted as ISO 8601 strings via ``isoformat()``. Any other object that
      exposes ``__dict__`` is emitted as that attribute mapping. Everything else
      is delegated to ``json.JSONEncoder.default``, which raises ``TypeError``.
      """

      def default(self, obj):
          # datetime.datetime is a subclass of datetime.date, so this single
          # isinstance check covers all three temporal types.
          temporal_types = (datetime.date, datetime.time)
          if isinstance(obj, temporal_types):
              return obj.isoformat()
          if hasattr(obj, '__dict__'):
              # NOTE: this exposes every instance attribute, private ones included.
              return obj.__dict__
          return super().default(obj)
  class ToolUtils:
      """Collection of stateless helper functions, exposed as static methods."""

      # Email addresses. Inside a character class ``|`` is a literal pipe, not
      # alternation, so the top-level-domain part is spelled ``[A-Za-z]``.
      _EMAIL_PATTERN = re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b')

      # Mainland-China mobile numbers: 11 digits, a ``1`` followed by 3-9. The
      # digit lookarounds stop the pattern from matching an 11-digit slice of a
      # longer digit run (ID numbers, order numbers, ...).
      _PHONE_PATTERN = re.compile(r'(?<!\d)1[3-9]\d{9}(?!\d)')

      @staticmethod
      def generate_uuid() -> str:
          """Return a random version-4 UUID in its canonical 36-character form."""
          return str(uuid.uuid4())

      @staticmethod
      def generate_trace_id() -> str:
          """Return a 16-character lowercase hex trace ID taken from a random UUID4."""
          # UUID.hex is the 32-digit representation without hyphens.
          return uuid.uuid4().hex[:16]

      @staticmethod
      def hash_string(text: str, algorithm: str = 'md5') -> str:
          """Return the hexadecimal digest of ``text`` encoded as UTF-8.

          Args:
              text: Text to hash.
              algorithm: One of ``'md5'``, ``'sha1'`` or ``'sha256'``. md5 and
                  sha1 are not collision-resistant; do not use them for
                  security purposes.

          Returns:
              Lowercase hexadecimal digest string.

          Raises:
              ValueError: If ``algorithm`` is not one of the supported names.
          """
          if algorithm == 'md5':
              return hashlib.md5(text.encode('utf-8')).hexdigest()
          if algorithm == 'sha1':
              return hashlib.sha1(text.encode('utf-8')).hexdigest()
          if algorithm == 'sha256':
              return hashlib.sha256(text.encode('utf-8')).hexdigest()
          raise ValueError(f"Unsupported algorithm: {algorithm}")

      @staticmethod
      def clean_text(text: str) -> str:
          """Strip ``text`` and collapse every run of whitespace into one space.

          For ``str`` input the regex whitespace class is Unicode-aware, so tabs,
          newlines and the full-width space (U+3000) are collapsed as well.

          Args:
              text: Text to clean.

          Returns:
              The cleaned text.
          """
          return re.sub(r'\s+', ' ', text.strip())

      @staticmethod
      def truncate_text(text: str, max_length: int = 100, suffix: str = "...") -> str:
          """Shorten ``text`` so that it is at most ``max_length`` characters long.

          When truncation is needed, ``suffix`` is appended and counts toward
          ``max_length``. If ``max_length`` is too small to hold the suffix, the
          text is cut to ``max_length`` characters without a suffix, so the
          result never exceeds the limit.

          Args:
              text: Text to truncate.
              max_length: Maximum length of the returned string.
              suffix: Marker appended to truncated text.

          Returns:
              ``text`` unchanged if it already fits, otherwise the truncated text.
          """
          if len(text) <= max_length:
              return text
          if max_length < len(suffix):
              # Without this guard ``max_length - len(suffix)`` is negative and
              # the slice counts from the end, returning an over-long string.
              return text[:max(max_length, 0)]
          return text[:max_length - len(suffix)] + suffix

      @staticmethod
      def extract_emails(text: str) -> List[str]:
          """Return all email-address-like substrings of ``text``, in order.

          Args:
              text: Text to scan.

          Returns:
              List of matched email addresses (possibly empty).
          """
          return ToolUtils._EMAIL_PATTERN.findall(text)

      @staticmethod
      def extract_phone_numbers(text: str) -> List[str]:
          """Return all mainland-China mobile numbers found in ``text``, in order.

          A match must be exactly 11 digits (``1`` then 3-9 then nine digits) and
          must not be directly preceded or followed by another digit.

          Args:
              text: Text to scan.

          Returns:
              List of matched phone numbers (possibly empty).
          """
          return ToolUtils._PHONE_PATTERN.findall(text)

      @staticmethod
      def format_file_size(size_bytes: int) -> str:
          """Format a byte count as a human-readable size using 1024-based units.

          Zero is rendered as ``"0B"``; other values use one decimal place, e.g.
          ``1536 -> "1.5KB"``. Values of 1024 TB and above stay expressed in TB.

          Args:
              size_bytes: Number of bytes.

          Returns:
              Formatted size string.
          """
          if size_bytes == 0:
              return "0B"
          size_names = ["B", "KB", "MB", "GB", "TB"]
          i = 0
          while size_bytes >= 1024 and i < len(size_names) - 1:
              size_bytes /= 1024.0
              i += 1
          return f"{size_bytes:.1f}{size_names[i]}"

      @staticmethod
      def deep_merge_dict(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]:
          """Recursively merge ``dict2`` into a copy of ``dict1``.

          Where both sides hold a dict under the same key, those dicts are merged
          recursively; otherwise the value from ``dict2`` wins. ``dict1`` itself
          is not modified, but values that are not merged are shared by reference
          with the inputs (the copy is shallow at each level).

          Args:
              dict1: Base dictionary.
              dict2: Dictionary whose values take precedence.

          Returns:
              The merged dictionary.
          """
          result = dict1.copy()
          for key, value in dict2.items():
              if key in result and isinstance(result[key], dict) and isinstance(value, dict):
                  result[key] = ToolUtils.deep_merge_dict(result[key], value)
              else:
                  result[key] = value
          return result

      @staticmethod
      def safe_get_nested(data: Union[Dict, List], path: str, default: Any = None) -> Any:
          """Look up a value in nested dicts/lists by a dot-separated path.

          Dict levels are indexed by the path segment as a string; list levels by
          the segment converted with ``int()`` (so ``'-1'`` selects the last item).

          Args:
              data: Root dict or list.
              path: Dot-separated path, e.g. ``'user.profile.name'`` or ``'items.0.id'``.
              default: Value returned when any step of the path cannot be resolved.

          Returns:
              The value at ``path``, or ``default``.
          """
          keys = path.split('.')
          current = data
          try:
              for key in keys:
                  if isinstance(current, dict):
                      current = current[key]
                  elif isinstance(current, list):
                      current = current[int(key)]
                  else:
                      # Path continues into a scalar: cannot descend further.
                      return default
              return current
          except (KeyError, IndexError, TypeError, ValueError):
              return default

      @staticmethod
      def chunk_list(lst: List[Any], chunk_size: int) -> List[List[Any]]:
          """Split ``lst`` into consecutive chunks of ``chunk_size`` items.

          The last chunk may be shorter. An empty list yields ``[]``.

          Args:
              lst: List to split.
              chunk_size: Number of items per chunk; must be positive.

          Returns:
              List of chunks (each a new list).

          Raises:
              ValueError: If ``chunk_size`` is less than 1.
          """
          if chunk_size < 1:
              # A negative step would otherwise silently yield [] and drop data.
              raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")
          return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

      @staticmethod
      def flatten_dict(d: Dict[str, Any], parent_key: str = '', sep: str = '.') -> Dict[str, Any]:
          """Flatten nested dicts into one level, joining keys with ``sep``.

          Example: ``{'a': {'b': 1}, 'c': 2} -> {'a.b': 1, 'c': 2}``. Nested empty
          dicts contribute no keys.

          Args:
              d: Dictionary to flatten.
              parent_key: Prefix for every key produced (used by the recursion).
              sep: Separator placed between key segments.

          Returns:
              The flattened dictionary.
          """
          flat: Dict[str, Any] = {}
          for k, v in d.items():
              new_key = f"{parent_key}{sep}{k}" if parent_key else k
              if isinstance(v, dict):
                  flat.update(ToolUtils.flatten_dict(v, new_key, sep=sep))
              else:
                  flat[new_key] = v
          return flat
  # Convenience module-level wrappers around ToolUtils.
  def generate_uuid() -> str:
      """Return a new random UUID string; shortcut for ``ToolUtils.generate_uuid``."""
      new_id = ToolUtils.generate_uuid()
      return new_id
  def generate_trace_id() -> str:
      """Return a new trace ID; shortcut for ``ToolUtils.generate_trace_id``."""
      trace_id = ToolUtils.generate_trace_id()
      return trace_id
  def clean_text(text: str) -> str:
      """Normalize whitespace in ``text``; shortcut for ``ToolUtils.clean_text``."""
      cleaned = ToolUtils.clean_text(text)
      return cleaned
  def format_file_size(size_bytes: int) -> str:
      """Render a byte count for humans; shortcut for ``ToolUtils.format_file_size``."""
      formatted = ToolUtils.format_file_size(size_bytes)
      return formatted
  # Public API of this module (the names bound by ``from <module> import *``).
  __all__ = [
      "DateTimeEncoder",
      "ToolUtils",
      "generate_uuid",
      "generate_trace_id",
      "clean_text",
      "format_file_size",
  ]