""" API调用日志服务层 提供API调用日志记录、计费和统计查询功能 需求: 10.1, 10.2, 10.3, 10.4, 10.5, 11.1, 11.2, 11.3, 11.4, 12.1, 12.2, 12.4 """ import logging import uuid from datetime import datetime, date, timedelta from decimal import Decimal from typing import List, Optional, Tuple from sqlalchemy.orm import Session from sqlalchemy import func, desc, and_ from sqlalchemy.exc import IntegrityError from app.models.api_call_log import ApiCallLog from app.models.user import User from app.models.platform_api_key import PlatformApiKey from app.schemas.platform_stats import ( StatsResponse, TrendItem, ModelDistItem, CallLogResponse ) logger = logging.getLogger(__name__) class ApiCallLogService: """API调用日志服务类""" def __init__(self, db: Session): self.db = db def create_log( self, user_id: str, api_key_id: int, model_id: Optional[int], model_name: str, is_local: bool, input_tokens: int, output_tokens: int, bill: Decimal, status: str = "success", error_message: Optional[str] = None, request_ip: Optional[str] = None ) -> ApiCallLog: """ 创建调用日志 需求 10.3: 记录每次API调用的详细信息 需求 10.4: 记录调用时间、模型、Token用量、费用 Args: user_id: 用户ID api_key_id: API Key ID model_id: 模型ID model_name: 模型名称 is_local: 是否为本地模型 input_tokens: 输入Token数 output_tokens: 输出Token数 bill: 费用金额 status: 调用状态 error_message: 错误信息 request_ip: 请求IP Returns: ApiCallLog: 创建的日志记录 """ log = ApiCallLog( user_id=user_id, api_key_id=api_key_id, model_id=model_id, model_name=model_name, is_local=is_local, input_tokens=input_tokens, output_tokens=output_tokens, bill=bill, status=status, error_message=error_message, request_ip=request_ip ) self.db.add(log) self.db.commit() self.db.refresh(log) return log def get_user_stats( self, user_id: str, trend_days: int = 7, key_type: Optional[str] = None ) -> StatsResponse: """ 获取用户调用统计 需求 11.1: 显示今日调用次数、本月调用次数、总调用次数 需求 11.2: 显示今日消费金额、本月消费金额 需求 11.3: 显示调用趋势图 需求 11.4: 显示各模型调用占比饼图 Args: user_id: 用户ID trend_days: 趋势数据天数 key_type: 密钥类型: public 或 local Returns: StatsResponse: 统计响应 """ today = date.today() month_start = today.replace(day=1) # 构建基础查询条件 base_filter = [ApiCallLog.user_id == user_id] if key_type: # 关联查询API Key类型 base_filter.append( ApiCallLog.api_key_id.in_( self.db.query(PlatformApiKey.id).filter( PlatformApiKey.user_id == user_id, PlatformApiKey.key_type == key_type ) ) ) # 今日统计 today_stats = self.db.query( func.count(ApiCallLog.id).label('count'), func.coalesce(func.sum(ApiCallLog.bill), 0).label('cost') ).filter( *base_filter, func.date(ApiCallLog.created_at) == today ).first() # 本月统计 month_stats = self.db.query( func.count(ApiCallLog.id).label('count'), func.coalesce(func.sum(ApiCallLog.bill), 0).label('cost') ).filter( *base_filter, func.date(ApiCallLog.created_at) >= month_start ).first() # 总计统计 total_stats = self.db.query( func.count(ApiCallLog.id).label('count') ).filter( *base_filter ).first() # 趋势数据 trend_start = today - timedelta(days=trend_days - 1) trend_data = self._get_trend_data(user_id, trend_start, today, key_type) # 模型分布 model_distribution = self._get_model_distribution(user_id, key_type) return StatsResponse( today_calls=today_stats.count or 0, month_calls=month_stats.count or 0, total_calls=total_stats.count or 0, today_cost=Decimal(str(today_stats.cost or 0)), month_cost=Decimal(str(month_stats.cost or 0)), trend_data=trend_data, model_distribution=model_distribution ) def _get_trend_data( self, user_id: str, start_date: date, end_date: date, key_type: Optional[str] = None ) -> List[TrendItem]: """获取趋势数据""" # 构建基础查询条件 base_filter = [ ApiCallLog.user_id == user_id, func.date(ApiCallLog.created_at) >= start_date, func.date(ApiCallLog.created_at) <= end_date ] if key_type: # 关联查询API Key类型 base_filter.append( ApiCallLog.api_key_id.in_( self.db.query(PlatformApiKey.id).filter( PlatformApiKey.user_id == user_id, PlatformApiKey.key_type == key_type ) ) ) results = self.db.query( func.date(ApiCallLog.created_at).label('date'), func.count(ApiCallLog.id).label('count'), func.coalesce(func.sum(ApiCallLog.bill), 0).label('cost') ).filter( *base_filter ).group_by( func.date(ApiCallLog.created_at) ).order_by( func.date(ApiCallLog.created_at) ).all() # 构建日期到数据的映射 data_map = {str(r.date): (r.count, r.cost) for r in results} # 填充所有日期 trend_data = [] current = start_date while current <= end_date: date_str = str(current) count, cost = data_map.get(date_str, (0, 0)) trend_data.append(TrendItem( date=date_str, count=count, cost=Decimal(str(cost)) )) current += timedelta(days=1) return trend_data def _get_model_distribution(self, user_id: str, key_type: Optional[str] = None) -> List[ModelDistItem]: """获取模型分布数据""" # 构建基础查询条件 base_filter = [ApiCallLog.user_id == user_id] if key_type: # 关联查询API Key类型 base_filter.append( ApiCallLog.api_key_id.in_( self.db.query(PlatformApiKey.id).filter( PlatformApiKey.user_id == user_id, PlatformApiKey.key_type == key_type ) ) ) results = self.db.query( ApiCallLog.model_name, func.count(ApiCallLog.id).label('count') ).filter( *base_filter ).group_by( ApiCallLog.model_name ).order_by( desc(func.count(ApiCallLog.id)) ).limit(10).all() total = sum(r.count for r in results) if total == 0: return [] return [ ModelDistItem( model_name=r.model_name, count=r.count, percentage=round(r.count / total * 100, 2) ) for r in results ] def get_call_logs( self, user_id: str, start_date: Optional[date] = None, end_date: Optional[date] = None, model_id: Optional[int] = None, api_key_id: Optional[int] = None, key_type: Optional[str] = None, page: int = 1, page_size: int = 20 ) -> Tuple[List[CallLogResponse], int]: """ 获取调用日志列表 需求 12.1: 支持按时间范围、模型、API Key进行筛选 需求 12.2: 显示调用时间、模型名称、Token用量、费用金额、使用的API Key 需求 12.4: 支持分页查询 Args: user_id: 用户ID start_date: 开始日期 end_date: 结束日期 model_id: 模型ID api_key_id: API Key ID key_type: 密钥类型: public 或 local page: 页码 page_size: 每页数量 Returns: Tuple[List[CallLogResponse], int]: (日志列表, 总数) """ query = self.db.query(ApiCallLog).filter(ApiCallLog.user_id == user_id) # 应用筛选条件 if start_date: query = query.filter(func.date(ApiCallLog.created_at) >= start_date) if end_date: query = query.filter(func.date(ApiCallLog.created_at) <= end_date) if model_id: query = query.filter(ApiCallLog.model_id == model_id) if api_key_id: query = query.filter(ApiCallLog.api_key_id == api_key_id) if key_type: # 关联查询API Key类型 query = query.filter( ApiCallLog.api_key_id.in_( self.db.query(PlatformApiKey.id).filter( PlatformApiKey.user_id == user_id, PlatformApiKey.key_type == key_type ) ) ) # 获取总数 total = query.count() # 分页查询 logs = query.order_by(desc(ApiCallLog.created_at)).offset( (page - 1) * page_size ).limit(page_size).all() # 获取API Key前缀映射 api_key_ids = [log.api_key_id for log in logs if log.api_key_id] api_key_map = {} if api_key_ids: api_keys = self.db.query(PlatformApiKey).filter( PlatformApiKey.id.in_(api_key_ids) ).all() api_key_map = {k.id: k.api_key_prefix for k in api_keys} from app.models.model import ModelNew, ModelPriceNew from sqlalchemy.orm import selectinload model_ids = [log.model_id for log in logs if log.model_id] price_map = {} if model_ids: # 用 selectinload 一次性加载价格,消灭 N+1 models = self.db.query(ModelNew).options( selectinload(ModelNew.prices) ).filter(ModelNew.id.in_(model_ids)).all() for m in models: active_price = next((p for p in (m.prices or []) if p.is_active), None) if active_price: price_map[m.id] = active_price # 构建响应 result = [] for log in logs: price_info = price_map.get(log.model_id) if log.model_id else None result.append(CallLogResponse( id=log.id, model_name=log.model_name, is_local=log.is_local, input_tokens=log.input_tokens, output_tokens=log.output_tokens, bill=Decimal(str(log.bill)), status=log.status, api_key_prefix=api_key_map.get(log.api_key_id, "N/A"), created_at=log.created_at, input_price=Decimal(str(price_info.input_price_discounted)) if price_info else None, output_price=Decimal(str(price_info.output_price_discounted)) if price_info else None )) return result, total