admin_stats_service.py 9.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. """
  2. 管理员数据统计服务
  3. 提供多维度的数据统计和分析,包括用户统计、API调用统计、模型使用统计等
  4. """
  5. import json
  6. import hashlib
  7. from datetime import date, datetime, timedelta
  8. from typing import Optional, List, Dict, Any
  9. import logging
  10. from sqlalchemy import func
  11. from sqlalchemy.orm import Session
  12. from app.models.user import User
  13. from app.models.api_call_log import ApiCallLog
  14. logger = logging.getLogger(__name__)
  15. # 统计缓存 TTL(秒)
  16. _STATS_CACHE_TTL = 300 # 5 分钟
  17. def _stats_cache_key(method: str, **kwargs) -> str:
  18. params = json.dumps(kwargs, default=str, sort_keys=True)
  19. return f"admin_stats:{method}:{hashlib.md5(params.encode()).hexdigest()}"
  20. def _get_sync_redis():
  21. try:
  22. from app.core.redis import redis_manager
  23. return redis_manager.get_sync_client()
  24. except Exception:
  25. return None
  26. def _cache_get(key: str):
  27. r = _get_sync_redis()
  28. if not r:
  29. return None
  30. try:
  31. raw = r.get(key)
  32. return json.loads(raw) if raw else None
  33. except Exception:
  34. return None
  35. def _cache_set(key: str, value, ttl: int = _STATS_CACHE_TTL):
  36. r = _get_sync_redis()
  37. if not r:
  38. return
  39. try:
  40. r.setex(key, ttl, json.dumps(value, default=str))
  41. except Exception as e:
  42. logger.debug(f"统计缓存写入失败: {e}")
  43. def stats_cached(ttl: int = _STATS_CACHE_TTL):
  44. import functools
  45. def decorator(fn):
  46. @functools.wraps(fn)
  47. async def wrapper(self, *args, **kwargs):
  48. import inspect
  49. sig = inspect.signature(fn)
  50. bound = sig.bind(self, *args, **kwargs)
  51. bound.apply_defaults()
  52. params = {k: v for k, v in bound.arguments.items() if k != "self"}
  53. key = _stats_cache_key(fn.__name__, **params)
  54. cached = _cache_get(key)
  55. if cached is not None:
  56. return cached
  57. result = await fn(self, *args, **kwargs)
  58. _cache_set(key, result, ttl)
  59. return result
  60. return wrapper
  61. return decorator
  62. class AdminStatsService:
  63. """管理员数据统计服务类"""
  64. def __init__(self, db: Session):
  65. self.db = db
  66. @stats_cached()
  67. async def get_user_overview(
  68. self,
  69. start_date: Optional[date] = None,
  70. end_date: Optional[date] = None
  71. ) -> Dict[str, Any]:
  72. """获取用户统计概览"""
  73. new_users_query = self.db.query(func.count(User.id))
  74. if start_date:
  75. new_users_query = new_users_query.filter(func.date(User.created_at) >= start_date)
  76. if end_date:
  77. new_users_query = new_users_query.filter(func.date(User.created_at) <= end_date)
  78. new_users = new_users_query.scalar() or 0
  79. total_users = self.db.query(func.count(User.id)).scalar() or 0
  80. # 活跃用户数:有API调用记录的用户
  81. active_query = self.db.query(func.count(func.distinct(ApiCallLog.user_id)))
  82. if start_date:
  83. active_query = active_query.filter(func.date(ApiCallLog.created_at) >= start_date)
  84. if end_date:
  85. active_query = active_query.filter(func.date(ApiCallLog.created_at) <= end_date)
  86. active_users = active_query.scalar() or 0
  87. retention_rate = (active_users / new_users * 100) if new_users > 0 else 0
  88. return {
  89. "new_users": new_users,
  90. "total_users": total_users,
  91. "active_users": active_users,
  92. "retention_rate": round(retention_rate, 2)
  93. }
  94. @stats_cached()
  95. async def get_user_growth_trend(
  96. self,
  97. start_date: date,
  98. end_date: date
  99. ) -> List[Dict[str, Any]]:
  100. """获取用户增长趋势"""
  101. new_users_by_date = self.db.query(
  102. func.date(User.created_at).label('date'),
  103. func.count(User.id).label('count')
  104. ).filter(
  105. func.date(User.created_at) >= start_date,
  106. func.date(User.created_at) <= end_date
  107. ).group_by(func.date(User.created_at)).all()
  108. # 活跃用户:有API调用的用户
  109. active_users_by_date = self.db.query(
  110. func.date(ApiCallLog.created_at).label('date'),
  111. func.count(func.distinct(ApiCallLog.user_id)).label('count')
  112. ).filter(
  113. func.date(ApiCallLog.created_at) >= start_date,
  114. func.date(ApiCallLog.created_at) <= end_date
  115. ).group_by(func.date(ApiCallLog.created_at)).all()
  116. date_range = []
  117. current = start_date
  118. while current <= end_date:
  119. date_range.append(current)
  120. current += timedelta(days=1)
  121. new_users_dict = {item.date: item.count for item in new_users_by_date}
  122. active_users_dict = {item.date: item.count for item in active_users_by_date}
  123. return [
  124. {
  125. "date": d.isoformat(),
  126. "new_users": new_users_dict.get(d, 0),
  127. "active_users": active_users_dict.get(d, 0)
  128. }
  129. for d in date_range
  130. ]
  131. @stats_cached()
  132. async def get_business_overview(
  133. self,
  134. start_date: Optional[date] = None,
  135. end_date: Optional[date] = None
  136. ) -> Dict[str, Any]:
  137. """获取 API 调用统计概览"""
  138. query = self.db.query(func.count(ApiCallLog.id))
  139. if start_date:
  140. query = query.filter(func.date(ApiCallLog.created_at) >= start_date)
  141. if end_date:
  142. query = query.filter(func.date(ApiCallLog.created_at) <= end_date)
  143. total_calls = query.scalar() or 0
  144. success_query = self.db.query(func.count(ApiCallLog.id)).filter(ApiCallLog.status == 'success')
  145. if start_date:
  146. success_query = success_query.filter(func.date(ApiCallLog.created_at) >= start_date)
  147. if end_date:
  148. success_query = success_query.filter(func.date(ApiCallLog.created_at) <= end_date)
  149. success_calls = success_query.scalar() or 0
  150. return {
  151. "total_calls": total_calls,
  152. "success_calls": success_calls,
  153. "failed_calls": total_calls - success_calls,
  154. }
  155. @stats_cached()
  156. async def get_business_trend(
  157. self,
  158. start_date: date,
  159. end_date: date
  160. ) -> List[Dict[str, Any]]:
  161. """获取 API 调用趋势"""
  162. calls_by_date = self.db.query(
  163. func.date(ApiCallLog.created_at).label('date'),
  164. func.count(ApiCallLog.id).label('count')
  165. ).filter(
  166. func.date(ApiCallLog.created_at) >= start_date,
  167. func.date(ApiCallLog.created_at) <= end_date
  168. ).group_by(func.date(ApiCallLog.created_at)).all()
  169. date_range = []
  170. current = start_date
  171. while current <= end_date:
  172. date_range.append(current)
  173. current += timedelta(days=1)
  174. calls_dict = {item.date: item.count for item in calls_by_date}
  175. return [
  176. {
  177. "date": d.isoformat(),
  178. "api_calls": calls_dict.get(d, 0),
  179. }
  180. for d in date_range
  181. ]
  182. @stats_cached()
  183. async def get_model_usage_ranking(
  184. self,
  185. start_date: Optional[date] = None,
  186. end_date: Optional[date] = None,
  187. top_n: int = 10
  188. ) -> List[Dict[str, Any]]:
  189. """获取模型使用排行(基于 API 调用日志)"""
  190. query = self.db.query(
  191. ApiCallLog.model_name,
  192. func.count(ApiCallLog.id).label('count')
  193. )
  194. if start_date:
  195. query = query.filter(func.date(ApiCallLog.created_at) >= start_date)
  196. if end_date:
  197. query = query.filter(func.date(ApiCallLog.created_at) <= end_date)
  198. query = query.filter(ApiCallLog.model_name.isnot(None))
  199. query = query.group_by(ApiCallLog.model_name)
  200. query = query.order_by(func.count(ApiCallLog.id).desc())
  201. query = query.limit(top_n)
  202. results = query.all()
  203. total_count = sum(r.count for r in results)
  204. return [
  205. {
  206. "model_name": r.model_name,
  207. "usage_count": r.count,
  208. "percentage": round((r.count / total_count * 100) if total_count > 0 else 0, 2)
  209. }
  210. for r in results
  211. ]
  212. @stats_cached()
  213. async def get_dashboard_metrics(self) -> Dict[str, Any]:
  214. """获取仪表盘核心指标"""
  215. today = date.today()
  216. yesterday = today - timedelta(days=1)
  217. seven_days_ago = today - timedelta(days=7)
  218. # 今日新增用户
  219. today_new_users = self.db.query(func.count(User.id)).filter(
  220. func.date(User.created_at) == today
  221. ).scalar() or 0
  222. yesterday_new_users = self.db.query(func.count(User.id)).filter(
  223. func.date(User.created_at) == yesterday
  224. ).scalar() or 0
  225. today_new_users_growth = (
  226. ((today_new_users - yesterday_new_users) / yesterday_new_users * 100)
  227. if yesterday_new_users > 0 else 0.0
  228. )
  229. total_users = self.db.query(func.count(User.id)).scalar() or 0
  230. # 活跃用户数(近7日有API调用的用户)
  231. active_users = self.db.query(func.count(func.distinct(ApiCallLog.user_id))).filter(
  232. func.date(ApiCallLog.created_at) >= seven_days_ago
  233. ).scalar() or 0
  234. # API调用量(今日)
  235. today_api_calls = self.db.query(func.count(ApiCallLog.id)).filter(
  236. func.date(ApiCallLog.created_at) == today
  237. ).scalar() or 0
  238. return {
  239. "today_new_users": today_new_users,
  240. "today_new_users_growth": round(today_new_users_growth, 2),
  241. "total_users": total_users,
  242. "active_users": active_users,
  243. "today_api_calls": today_api_calls,
  244. }