api_call_log_service.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365
  1. """
  2. API调用日志服务层
  3. 提供API调用日志记录、计费和统计查询功能
  4. 需求: 10.1, 10.2, 10.3, 10.4, 10.5, 11.1, 11.2, 11.3, 11.4, 12.1, 12.2, 12.4
  5. """
  6. import logging
  7. import uuid
  8. from datetime import datetime, date, timedelta
  9. from decimal import Decimal
  10. from typing import List, Optional, Tuple
  11. from sqlalchemy.orm import Session
  12. from sqlalchemy import func, desc, and_
  13. from sqlalchemy.exc import IntegrityError
  14. from app.models.api_call_log import ApiCallLog
  15. from app.models.user import User
  16. from app.models.platform_api_key import PlatformApiKey
  17. from app.schemas.platform_stats import (
  18. StatsResponse,
  19. TrendItem,
  20. ModelDistItem,
  21. CallLogResponse
  22. )
  23. logger = logging.getLogger(__name__)
  24. class ApiCallLogService:
  25. """API调用日志服务类"""
  26. def __init__(self, db: Session):
  27. self.db = db
  28. def create_log(
  29. self,
  30. user_id: str,
  31. api_key_id: int,
  32. model_id: Optional[int],
  33. model_name: str,
  34. is_local: bool,
  35. input_tokens: int,
  36. output_tokens: int,
  37. bill: Decimal,
  38. status: str = "success",
  39. error_message: Optional[str] = None,
  40. request_ip: Optional[str] = None
  41. ) -> ApiCallLog:
  42. """
  43. 创建调用日志
  44. 需求 10.3: 记录每次API调用的详细信息
  45. 需求 10.4: 记录调用时间、模型、Token用量、费用
  46. Args:
  47. user_id: 用户ID
  48. api_key_id: API Key ID
  49. model_id: 模型ID
  50. model_name: 模型名称
  51. is_local: 是否为本地模型
  52. input_tokens: 输入Token数
  53. output_tokens: 输出Token数
  54. bill: 费用金额
  55. status: 调用状态
  56. error_message: 错误信息
  57. request_ip: 请求IP
  58. Returns:
  59. ApiCallLog: 创建的日志记录
  60. """
  61. log = ApiCallLog(
  62. user_id=user_id,
  63. api_key_id=api_key_id,
  64. model_id=model_id,
  65. model_name=model_name,
  66. is_local=is_local,
  67. input_tokens=input_tokens,
  68. output_tokens=output_tokens,
  69. bill=bill,
  70. status=status,
  71. error_message=error_message,
  72. request_ip=request_ip
  73. )
  74. self.db.add(log)
  75. self.db.commit()
  76. self.db.refresh(log)
  77. return log
  78. def get_user_stats(
  79. self,
  80. user_id: str,
  81. trend_days: int = 7,
  82. key_type: Optional[str] = None
  83. ) -> StatsResponse:
  84. """
  85. 获取用户调用统计
  86. 需求 11.1: 显示今日调用次数、本月调用次数、总调用次数
  87. 需求 11.2: 显示今日消费金额、本月消费金额
  88. 需求 11.3: 显示调用趋势图
  89. 需求 11.4: 显示各模型调用占比饼图
  90. Args:
  91. user_id: 用户ID
  92. trend_days: 趋势数据天数
  93. key_type: 密钥类型: public 或 local
  94. Returns:
  95. StatsResponse: 统计响应
  96. """
  97. today = date.today()
  98. month_start = today.replace(day=1)
  99. # 构建基础查询条件
  100. base_filter = [ApiCallLog.user_id == user_id]
  101. if key_type:
  102. # 关联查询API Key类型
  103. base_filter.append(
  104. ApiCallLog.api_key_id.in_(
  105. self.db.query(PlatformApiKey.id).filter(
  106. PlatformApiKey.user_id == user_id,
  107. PlatformApiKey.key_type == key_type
  108. )
  109. )
  110. )
  111. # 今日统计
  112. today_stats = self.db.query(
  113. func.count(ApiCallLog.id).label('count'),
  114. func.coalesce(func.sum(ApiCallLog.bill), 0).label('cost')
  115. ).filter(
  116. *base_filter,
  117. func.date(ApiCallLog.created_at) == today
  118. ).first()
  119. # 本月统计
  120. month_stats = self.db.query(
  121. func.count(ApiCallLog.id).label('count'),
  122. func.coalesce(func.sum(ApiCallLog.bill), 0).label('cost')
  123. ).filter(
  124. *base_filter,
  125. func.date(ApiCallLog.created_at) >= month_start
  126. ).first()
  127. # 总计统计
  128. total_stats = self.db.query(
  129. func.count(ApiCallLog.id).label('count')
  130. ).filter(
  131. *base_filter
  132. ).first()
  133. # 趋势数据
  134. trend_start = today - timedelta(days=trend_days - 1)
  135. trend_data = self._get_trend_data(user_id, trend_start, today, key_type)
  136. # 模型分布
  137. model_distribution = self._get_model_distribution(user_id, key_type)
  138. return StatsResponse(
  139. today_calls=today_stats.count or 0,
  140. month_calls=month_stats.count or 0,
  141. total_calls=total_stats.count or 0,
  142. today_cost=Decimal(str(today_stats.cost or 0)),
  143. month_cost=Decimal(str(month_stats.cost or 0)),
  144. trend_data=trend_data,
  145. model_distribution=model_distribution
  146. )
  147. def _get_trend_data(
  148. self,
  149. user_id: str,
  150. start_date: date,
  151. end_date: date,
  152. key_type: Optional[str] = None
  153. ) -> List[TrendItem]:
  154. """获取趋势数据"""
  155. # 构建基础查询条件
  156. base_filter = [
  157. ApiCallLog.user_id == user_id,
  158. func.date(ApiCallLog.created_at) >= start_date,
  159. func.date(ApiCallLog.created_at) <= end_date
  160. ]
  161. if key_type:
  162. # 关联查询API Key类型
  163. base_filter.append(
  164. ApiCallLog.api_key_id.in_(
  165. self.db.query(PlatformApiKey.id).filter(
  166. PlatformApiKey.user_id == user_id,
  167. PlatformApiKey.key_type == key_type
  168. )
  169. )
  170. )
  171. results = self.db.query(
  172. func.date(ApiCallLog.created_at).label('date'),
  173. func.count(ApiCallLog.id).label('count'),
  174. func.coalesce(func.sum(ApiCallLog.bill), 0).label('cost')
  175. ).filter(
  176. *base_filter
  177. ).group_by(
  178. func.date(ApiCallLog.created_at)
  179. ).order_by(
  180. func.date(ApiCallLog.created_at)
  181. ).all()
  182. # 构建日期到数据的映射
  183. data_map = {str(r.date): (r.count, r.cost) for r in results}
  184. # 填充所有日期
  185. trend_data = []
  186. current = start_date
  187. while current <= end_date:
  188. date_str = str(current)
  189. count, cost = data_map.get(date_str, (0, 0))
  190. trend_data.append(TrendItem(
  191. date=date_str,
  192. count=count,
  193. cost=Decimal(str(cost))
  194. ))
  195. current += timedelta(days=1)
  196. return trend_data
  197. def _get_model_distribution(self, user_id: str, key_type: Optional[str] = None) -> List[ModelDistItem]:
  198. """获取模型分布数据"""
  199. # 构建基础查询条件
  200. base_filter = [ApiCallLog.user_id == user_id]
  201. if key_type:
  202. # 关联查询API Key类型
  203. base_filter.append(
  204. ApiCallLog.api_key_id.in_(
  205. self.db.query(PlatformApiKey.id).filter(
  206. PlatformApiKey.user_id == user_id,
  207. PlatformApiKey.key_type == key_type
  208. )
  209. )
  210. )
  211. results = self.db.query(
  212. ApiCallLog.model_name,
  213. func.count(ApiCallLog.id).label('count')
  214. ).filter(
  215. *base_filter
  216. ).group_by(
  217. ApiCallLog.model_name
  218. ).order_by(
  219. desc(func.count(ApiCallLog.id))
  220. ).limit(10).all()
  221. total = sum(r.count for r in results)
  222. if total == 0:
  223. return []
  224. return [
  225. ModelDistItem(
  226. model_name=r.model_name,
  227. count=r.count,
  228. percentage=round(r.count / total * 100, 2)
  229. )
  230. for r in results
  231. ]
  232. def get_call_logs(
  233. self,
  234. user_id: str,
  235. start_date: Optional[date] = None,
  236. end_date: Optional[date] = None,
  237. model_id: Optional[int] = None,
  238. api_key_id: Optional[int] = None,
  239. key_type: Optional[str] = None,
  240. page: int = 1,
  241. page_size: int = 20
  242. ) -> Tuple[List[CallLogResponse], int]:
  243. """
  244. 获取调用日志列表
  245. 需求 12.1: 支持按时间范围、模型、API Key进行筛选
  246. 需求 12.2: 显示调用时间、模型名称、Token用量、费用金额、使用的API Key
  247. 需求 12.4: 支持分页查询
  248. Args:
  249. user_id: 用户ID
  250. start_date: 开始日期
  251. end_date: 结束日期
  252. model_id: 模型ID
  253. api_key_id: API Key ID
  254. key_type: 密钥类型: public 或 local
  255. page: 页码
  256. page_size: 每页数量
  257. Returns:
  258. Tuple[List[CallLogResponse], int]: (日志列表, 总数)
  259. """
  260. query = self.db.query(ApiCallLog).filter(ApiCallLog.user_id == user_id)
  261. # 应用筛选条件
  262. if start_date:
  263. query = query.filter(func.date(ApiCallLog.created_at) >= start_date)
  264. if end_date:
  265. query = query.filter(func.date(ApiCallLog.created_at) <= end_date)
  266. if model_id:
  267. query = query.filter(ApiCallLog.model_id == model_id)
  268. if api_key_id:
  269. query = query.filter(ApiCallLog.api_key_id == api_key_id)
  270. if key_type:
  271. # 关联查询API Key类型
  272. query = query.filter(
  273. ApiCallLog.api_key_id.in_(
  274. self.db.query(PlatformApiKey.id).filter(
  275. PlatformApiKey.user_id == user_id,
  276. PlatformApiKey.key_type == key_type
  277. )
  278. )
  279. )
  280. # 获取总数
  281. total = query.count()
  282. # 分页查询
  283. logs = query.order_by(desc(ApiCallLog.created_at)).offset(
  284. (page - 1) * page_size
  285. ).limit(page_size).all()
  286. # 获取API Key前缀映射
  287. api_key_ids = [log.api_key_id for log in logs if log.api_key_id]
  288. api_key_map = {}
  289. if api_key_ids:
  290. api_keys = self.db.query(PlatformApiKey).filter(
  291. PlatformApiKey.id.in_(api_key_ids)
  292. ).all()
  293. api_key_map = {k.id: k.api_key_prefix for k in api_keys}
  294. from app.models.model import ModelNew, ModelPriceNew
  295. from sqlalchemy.orm import selectinload
  296. model_ids = [log.model_id for log in logs if log.model_id]
  297. price_map = {}
  298. if model_ids:
  299. # 用 selectinload 一次性加载价格,消灭 N+1
  300. models = self.db.query(ModelNew).options(
  301. selectinload(ModelNew.prices)
  302. ).filter(ModelNew.id.in_(model_ids)).all()
  303. for m in models:
  304. active_price = next((p for p in (m.prices or []) if p.is_active), None)
  305. if active_price:
  306. price_map[m.id] = active_price
  307. # 构建响应
  308. result = []
  309. for log in logs:
  310. price_info = price_map.get(log.model_id) if log.model_id else None
  311. result.append(CallLogResponse(
  312. id=log.id,
  313. model_name=log.model_name,
  314. is_local=log.is_local,
  315. input_tokens=log.input_tokens,
  316. output_tokens=log.output_tokens,
  317. bill=Decimal(str(log.bill)),
  318. status=log.status,
  319. api_key_prefix=api_key_map.get(log.api_key_id, "N/A"),
  320. created_at=log.created_at,
  321. input_price=Decimal(str(price_info.input_price_discounted)) if price_info else None,
  322. output_price=Decimal(str(price_info.output_price_discounted)) if price_info else None
  323. ))
  324. return result, total