| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266 |
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- """
- 工具函数模块
- 提供常用的工具函数和辅助类
- """
- import json
- import datetime
- from typing import Any, Dict, List, Optional, Union
- import hashlib
- import uuid
- import re
class DateTimeEncoder(json.JSONEncoder):
    """JSON encoder that understands date/time values and plain objects.

    ``datetime.datetime``, ``datetime.date`` and ``datetime.time`` values are
    emitted as their ``isoformat()`` string.  Any other object exposing a
    ``__dict__`` attribute is emitted as that attribute dictionary.  Everything
    else is delegated to ``json.JSONEncoder.default``.
    """

    def default(self, obj):
        """Return a JSON-serializable stand-in for ``obj``.

        Args:
            obj: A value the base encoder could not serialize on its own.

        Returns:
            An ISO-8601 string for date/time values, or ``obj.__dict__`` for
            objects that have one; otherwise whatever the base class returns
            (the base implementation is called as the final fallback).
        """
        temporal_types = (datetime.datetime, datetime.date, datetime.time)
        if isinstance(obj, temporal_types):
            return obj.isoformat()
        if hasattr(obj, '__dict__'):
            return obj.__dict__
        return super().default(obj)
class ToolUtils:
    """Collection of stateless helper functions exposed as static methods."""

    # Algorithms accepted by hash_string().
    _HASH_ALGORITHMS = ('md5', 'sha1', 'sha256')
    # Units used by format_file_size(); the last one is also used for anything larger.
    _SIZE_UNITS = ("B", "KB", "MB", "GB", "TB")

    # Patterns are compiled once at class creation instead of on every call.
    _WHITESPACE_RE = re.compile(r'\s+')
    # The TLD class is [A-Za-z]; the former [A-Z|a-z] also accepted a literal '|'.
    _EMAIL_RE = re.compile(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b')
    # Mainland China mobile numbers: '1', then a digit 3-9, then nine more digits.
    _PHONE_RE = re.compile(r'1[3-9]\d{9}')

    @staticmethod
    def generate_uuid() -> str:
        """Return a random UUID (version 4) in its canonical string form."""
        return str(uuid.uuid4())

    @staticmethod
    def generate_trace_id() -> str:
        """Return a 16-character trace ID.

        The ID is the first 16 hexadecimal characters of a random UUID
        (version 4) written without hyphens.
        """
        return uuid.uuid4().hex[:16]

    @staticmethod
    def hash_string(text: str, algorithm: str = 'md5') -> str:
        """Return the hexadecimal digest of ``text`` encoded as UTF-8.

        Args:
            text: The text to hash.
            algorithm: Hash algorithm name: 'md5', 'sha1' or 'sha256'.

        Returns:
            The digest as a hexadecimal string.

        Raises:
            ValueError: If ``algorithm`` is not one of the supported names.
        """
        if algorithm not in ToolUtils._HASH_ALGORITHMS:
            raise ValueError(f"Unsupported algorithm: {algorithm}")
        return hashlib.new(algorithm, text.encode('utf-8')).hexdigest()

    @staticmethod
    def clean_text(text: str) -> str:
        """Strip ``text`` and collapse every whitespace run into one space.

        Args:
            text: The text to clean.

        Returns:
            The cleaned text.
        """
        return ToolUtils._WHITESPACE_RE.sub(' ', text.strip())

    @staticmethod
    def truncate_text(text: str, max_length: int = 100, suffix: str = "...") -> str:
        """Shorten ``text`` so the result is at most ``max_length`` characters.

        Args:
            text: The text to truncate.
            max_length: Maximum length of the returned string, suffix included.
            suffix: Marker appended when the text is cut.

        Returns:
            ``text`` unchanged when it already fits; otherwise the leading part
            of ``text`` followed by ``suffix``.  When ``max_length`` leaves no
            room for any text, the suffix itself is cut to ``max_length``
            characters (an empty string for ``max_length <= 0``).
        """
        if len(text) <= max_length:
            return text
        if max_length <= len(suffix):
            # Without this guard the slice index below would be negative and
            # the result would be longer than max_length.
            return suffix[:max(max_length, 0)]
        return text[:max_length - len(suffix)] + suffix

    @staticmethod
    def extract_emails(text: str) -> List[str]:
        """Return every e-mail address found in ``text``, in order of appearance.

        Args:
            text: The text to scan.

        Returns:
            List of matched e-mail address strings (empty when none are found).
        """
        return ToolUtils._EMAIL_RE.findall(text)

    @staticmethod
    def extract_phone_numbers(text: str) -> List[str]:
        """Return every mainland-China mobile number found in ``text``.

        A match is an 11-digit sequence starting with '1' followed by a digit
        in 3-9.  The pattern is not anchored on digit boundaries, so a match
        may be part of a longer digit run (for example after an '86' prefix).

        Args:
            text: The text to scan.

        Returns:
            List of matched number strings (empty when none are found).
        """
        return ToolUtils._PHONE_RE.findall(text)

    @staticmethod
    def format_file_size(size_bytes: int) -> str:
        """Format a byte count using binary (1024-based) units up to TB.

        Args:
            size_bytes: Number of bytes.  Negative values are formatted by
                magnitude and prefixed with '-'.

        Returns:
            "0B" for zero; otherwise the value with one decimal place followed
            by the unit, e.g. "1.5KB".
        """
        if size_bytes == 0:
            return "0B"
        units = ToolUtils._SIZE_UNITS
        sign = "-" if size_bytes < 0 else ""
        # Scale the magnitude in a local so the argument itself is not rebound.
        magnitude = abs(size_bytes)
        unit_index = 0
        while magnitude >= 1024 and unit_index < len(units) - 1:
            magnitude /= 1024.0
            unit_index += 1
        return f"{sign}{magnitude:.1f}{units[unit_index]}"

    @staticmethod
    def deep_merge_dict(dict1: Dict[str, Any], dict2: Dict[str, Any]) -> Dict[str, Any]:
        """Recursively merge ``dict2`` into a copy of ``dict1``.

        When both sides hold a dict under the same key the two are merged
        recursively; in every other case the value from ``dict2`` wins.
        Neither input is modified, and nested dictionaries in the result are
        copies, so changing them does not affect the inputs.  Non-dict values
        (lists included) are shared by reference.

        Args:
            dict1: Base dictionary.
            dict2: Dictionary whose entries take precedence.

        Returns:
            The merged dictionary.
        """
        merged = dict1.copy()
        # Detach nested dicts inherited from dict1 so the result never aliases it.
        for key, value in dict1.items():
            if isinstance(value, dict):
                merged[key] = ToolUtils.deep_merge_dict(value, {})
        for key, value in dict2.items():
            if not isinstance(value, dict):
                merged[key] = value
                continue
            current = merged.get(key)
            if isinstance(current, dict):
                merged[key] = ToolUtils.deep_merge_dict(current, value)
            else:
                # Copy instead of storing dict2's nested dict directly.
                merged[key] = ToolUtils.deep_merge_dict(value, {})
        return merged

    @staticmethod
    def safe_get_nested(data: Union[Dict, List], path: str, default: Any = None) -> Any:
        """Look up a value in nested dicts/lists without raising.

        Args:
            data: The root dict or list.
            path: Dot-separated path, e.g. 'user.profile.name'.  A segment
                applied to a list is converted with ``int()`` and used as an
                index.
            default: Value returned when the path cannot be resolved.

        Returns:
            The value found at ``path``, or ``default`` when a key or index is
            missing, a segment is not a valid index, or an intermediate value
            is neither a dict nor a list.
        """
        node = data
        try:
            for segment in path.split('.'):
                if isinstance(node, dict):
                    node = node[segment]
                elif isinstance(node, list):
                    node = node[int(segment)]
                else:
                    return default
        except (KeyError, IndexError, TypeError, ValueError):
            return default
        return node

    @staticmethod
    def chunk_list(lst: List[Any], chunk_size: int) -> List[List[Any]]:
        """Split ``lst`` into consecutive slices of at most ``chunk_size`` items.

        Args:
            lst: The list to split.
            chunk_size: Maximum number of items per chunk; must be positive.

        Returns:
            List of chunks; only the last chunk may be shorter than
            ``chunk_size``.  An empty input yields an empty list.

        Raises:
            ValueError: If ``chunk_size`` is zero or negative.
        """
        if chunk_size <= 0:
            # A negative step would make range() empty and silently drop every item.
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")
        return [lst[start:start + chunk_size] for start in range(0, len(lst), chunk_size)]

    @staticmethod
    def flatten_dict(d: Dict[str, Any], parent_key: str = '', sep: str = '.') -> Dict[str, Any]:
        """Flatten nested dictionaries into a single level.

        Args:
            d: The dictionary to flatten.
            parent_key: Prefix prepended (with ``sep``) to every key; used by
                the recursive calls.
            sep: Separator placed between the parts of a flattened key.

        Returns:
            A new dictionary mapping joined keys such as 'a.b.c' to the
            non-dict values.  Empty nested dictionaries contribute no entries.
        """
        flat: Dict[str, Any] = {}
        for key, value in d.items():
            full_key = f"{parent_key}{sep}{key}" if parent_key else key
            if isinstance(value, dict):
                flat.update(ToolUtils.flatten_dict(value, full_key, sep=sep))
            else:
                flat[full_key] = value
        return flat
# Convenience functions
def generate_uuid() -> str:
    """Return a new UUID string (shortcut for ``ToolUtils.generate_uuid``)."""
    new_uuid = ToolUtils.generate_uuid()
    return new_uuid
def generate_trace_id() -> str:
    """Return a new trace ID (shortcut for ``ToolUtils.generate_trace_id``)."""
    trace_id = ToolUtils.generate_trace_id()
    return trace_id
def clean_text(text: str) -> str:
    """Return ``text`` cleaned by ``ToolUtils.clean_text`` (shortcut)."""
    cleaned = ToolUtils.clean_text(text)
    return cleaned
def format_file_size(size_bytes: int) -> str:
    """Return ``size_bytes`` formatted by ``ToolUtils.format_file_size`` (shortcut)."""
    formatted = ToolUtils.format_file_size(size_bytes)
    return formatted
# Exported classes and functions
__all__ = [
    "DateTimeEncoder",
    "ToolUtils",
    "generate_uuid",
    "generate_trace_id",
    "clean_text",
    "format_file_size",
]
|