#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Utility function module.

Provides commonly used helper functions and helper classes.
"""

import datetime
import hashlib
import json
import re
import uuid
from typing import Any, Dict, List, Optional, Union


# Hash constructors accepted by ToolUtils.hash_string, keyed by algorithm name.
_HASH_ALGORITHMS = {
    "md5": hashlib.md5,
    "sha1": hashlib.sha1,
    "sha256": hashlib.sha256,
}

# Runs of whitespace (spaces, tabs, newlines, ...) collapsed by clean_text.
_WHITESPACE_RE = re.compile(r"\s+")

# Email addresses. The TLD class is [A-Za-z]: inside a character class "|"
# is a literal pipe, not alternation, so "[A-Z|a-z]" wrongly accepted "|".
_EMAIL_RE = re.compile(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b")

# Mainland-China mobile numbers: 11 digits starting with 13-19. The
# lookarounds stop matches inside longer digit runs (order IDs, card numbers);
# a number immediately after a "+86" country code is still accepted.
_PHONE_RE = re.compile(r"(?:(?<!\d)|(?<=\+86))1[3-9]\d{9}(?!\d)")

# Units used by ToolUtils.format_file_size, in ascending powers of 1024.
_SIZE_UNITS = ("B", "KB", "MB", "GB", "TB")


class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that also serializes date/time values and plain objects.

    - ``datetime.datetime``, ``datetime.date`` and ``datetime.time`` values
      are encoded as ``obj.isoformat()``.
    - Any other object with a ``__dict__`` is encoded as that dict (i.e. all
      of its instance attributes, including underscore-prefixed ones).
    - Everything else is delegated to ``json.JSONEncoder.default``, which
      raises ``TypeError``.

    Usage: ``json.dumps(data, cls=DateTimeEncoder)``.
    """

    def default(self, obj: Any) -> Any:
        # datetime.datetime subclasses datetime.date; all three types share
        # the same isoformat() serialization.
        if isinstance(obj, (datetime.datetime, datetime.date, datetime.time)):
            return obj.isoformat()
        if hasattr(obj, "__dict__"):
            return obj.__dict__
        return super().default(obj)


class ToolUtils:
    """Collection of stateless helper functions exposed as static methods."""

    @staticmethod
    def generate_uuid() -> str:
        """Return a random (version 4) UUID in its canonical hyphenated form."""
        return str(uuid.uuid4())

    @staticmethod
    def generate_trace_id() -> str:
        """Return a 16-character lowercase hex trace ID derived from a UUID4."""
        # .hex is the UUID without hyphens, identical to str(...).replace('-', '').
        return uuid.uuid4().hex[:16]

    @staticmethod
    def hash_string(text: str, algorithm: str = "md5") -> str:
        """Compute the hex digest of ``text`` encoded as UTF-8.

        Args:
            text: Text to hash.
            algorithm: One of ``'md5'``, ``'sha1'``, ``'sha256'``.

        Returns:
            Lowercase hexadecimal digest string.

        Raises:
            ValueError: If ``algorithm`` is not one of the supported names.
        """
        constructor = _HASH_ALGORITHMS.get(algorithm)
        if constructor is None:
            raise ValueError(f"Unsupported algorithm: {algorithm}")
        return constructor(text.encode("utf-8")).hexdigest()

    @staticmethod
    def clean_text(text: str) -> str:
        """Strip ``text`` and collapse every internal whitespace run to one space.

        Args:
            text: Text to clean.

        Returns:
            The cleaned text.
        """
        return _WHITESPACE_RE.sub(" ", text.strip())

    @staticmethod
    def truncate_text(text: str, max_length: int = 100, suffix: str = "...") -> str:
        """Shorten ``text`` to at most ``max_length`` characters.

        Text that already fits is returned unchanged. Otherwise the text is cut
        and ``suffix`` appended so the total length equals ``max_length``.
        If ``max_length`` is too small to hold the suffix, the text is simply
        cut to ``max_length`` characters (empty for ``max_length <= 0``).

        Args:
            text: Text to truncate.
            max_length: Maximum length of the result.
            suffix: Marker appended to truncated text.

        Returns:
            The (possibly) truncated text; never longer than ``max_length``
            (or empty when ``max_length`` is negative).
        """
        if len(text) <= max_length:
            return text
        keep = max_length - len(suffix)
        if keep < 0:
            # The suffix alone would exceed the limit; a negative slice index
            # here would otherwise keep almost the whole text.
            return text[:max(max_length, 0)]
        return text[:keep] + suffix

    @staticmethod
    def extract_emails(text: str) -> List[str]:
        """Return all email-address-like substrings of ``text``, in order.

        Args:
            text: Text to scan.

        Returns:
            List of matched email addresses (may be empty).
        """
        return _EMAIL_RE.findall(text)

    @staticmethod
    def extract_phone_numbers(text: str) -> List[str]:
        """Return mainland-China mobile numbers found in ``text``, in order.

        A match is 11 digits starting with ``13``-``19`` that is not part of a
        longer digit sequence; a number directly preceded by ``+86`` is
        returned without the country code.

        Args:
            text: Text to scan.

        Returns:
            List of 11-digit phone number strings (may be empty).
        """
        return _PHONE_RE.findall(text)

    @staticmethod
    def format_file_size(size_bytes: int) -> str:
        """Format a byte count as a human-readable size using 1024-based units.

        Args:
            size_bytes: Number of bytes (negative values are scaled by
                magnitude and keep their sign).

        Returns:
            ``"0B"`` for zero, otherwise one decimal place plus a unit from
            B/KB/MB/GB/TB, e.g. ``"1.5KB"``. Sizes beyond TB stay in TB.
        """
        if size_bytes == 0:
            return "0B"
        size = float(size_bytes)
        unit_index = 0
        # Compare the magnitude so negative sizes (e.g. deltas) scale too.
        while abs(size) >= 1024 and unit_index < len(_SIZE_UNITS) - 1:
            size /= 1024.0
            unit_index += 1
        return f"{size:.1f}{_SIZE_UNITS[unit_index]}"

    @staticmethod
    def deep_merge_dict(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]:
        """Recursively merge ``dict2`` into a copy of ``dict1``.

        When both sides hold a dict under the same key, those dicts are merged
        recursively; otherwise the value from ``dict2`` wins. Neither input is
        mutated, but values that are not merged are shared, not copied.

        Args:
            dict1: Base dictionary.
            dict2: Dictionary whose values take precedence.

        Returns:
            A new merged dictionary.
        """
        result = dict1.copy()
        for key, value in dict2.items():
            if key in result and isinstance(result[key], dict) and isinstance(value, dict):
                result[key] = ToolUtils.deep_merge_dict(result[key], value)
            else:
                result[key] = value
        return result

    @staticmethod
    def safe_get_nested(data: Union[Dict, List], path: str, default: Any = None) -> Any:
        """Follow a dot-separated ``path`` into nested dicts/lists.

        Dict levels are indexed by the path segment as a string key; list
        levels by the segment converted with ``int()`` (so ``"-1"`` works).

        Args:
            data: Root dict or list.
            path: Dot-separated path, e.g. ``'user.profile.name'`` or ``'items.0'``.
            default: Value returned when any step cannot be resolved.

        Returns:
            The value at ``path``, or ``default`` if a key/index is missing,
            a segment is not a valid list index, or a non-container is reached.
        """
        current = data
        try:
            for key in path.split("."):
                if isinstance(current, dict):
                    current = current[key]
                elif isinstance(current, list):
                    current = current[int(key)]
                else:
                    return default
        except (KeyError, IndexError, TypeError, ValueError):
            return default
        return current

    @staticmethod
    def chunk_list(lst: List[Any], chunk_size: int) -> List[List[Any]]:
        """Split ``lst`` into consecutive chunks of ``chunk_size`` items.

        The last chunk may be shorter. An empty list yields ``[]``.

        Args:
            lst: List to split.
            chunk_size: Items per chunk; must be at least 1.

        Returns:
            List of sub-lists (slices of ``lst``).

        Raises:
            ValueError: If ``chunk_size`` is less than 1 (previously 0 gave an
                opaque range() error and negatives silently returned ``[]``).
        """
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        return [lst[i:i + chunk_size] for i in range(0, len(lst), chunk_size)]

    @staticmethod
    def flatten_dict(d: Dict[str, Any], parent_key: str = "", sep: str = ".") -> Dict[str, Any]:
        """Flatten nested dicts into one level with joined keys.

        ``{'a': {'b': 1}}`` becomes ``{'a.b': 1}``. Lists are left as values;
        an empty nested dict contributes no keys.

        Args:
            d: Dictionary to flatten.
            parent_key: Prefix for all keys (used by the recursion).
            sep: Separator placed between key segments.

        Returns:
            A new flat dictionary.
        """
        flat: Dict[str, Any] = {}
        for key, value in d.items():
            new_key = f"{parent_key}{sep}{key}" if parent_key else key
            if isinstance(value, dict):
                flat.update(ToolUtils.flatten_dict(value, new_key, sep=sep))
            else:
                flat[new_key] = value
        return flat


# Convenience module-level wrappers around ToolUtils.

def generate_uuid() -> str:
    """Return a random UUID string (wrapper for ToolUtils.generate_uuid)."""
    return ToolUtils.generate_uuid()


def generate_trace_id() -> str:
    """Return a 16-char trace ID (wrapper for ToolUtils.generate_trace_id)."""
    return ToolUtils.generate_trace_id()


def clean_text(text: str) -> str:
    """Normalize whitespace in text (wrapper for ToolUtils.clean_text)."""
    return ToolUtils.clean_text(text)


def format_file_size(size_bytes: int) -> str:
    """Format a byte count (wrapper for ToolUtils.format_file_size)."""
    return ToolUtils.format_file_size(size_bytes)


# Public API of this module.
__all__ = [
    "DateTimeEncoder",
    "ToolUtils",
    "generate_uuid",
    "generate_trace_id",
    "clean_text",
    "format_file_size",
]