| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365 |
- """
- API调用日志服务层
- 提供API调用日志记录、计费和统计查询功能
- 需求: 10.1, 10.2, 10.3, 10.4, 10.5, 11.1, 11.2, 11.3, 11.4, 12.1, 12.2, 12.4
- """
- import logging
- import uuid
- from datetime import datetime, date, timedelta
- from decimal import Decimal
- from typing import List, Optional, Tuple
- from sqlalchemy.orm import Session
- from sqlalchemy import func, desc, and_
- from sqlalchemy.exc import IntegrityError
- from app.models.api_call_log import ApiCallLog
- from app.models.user import User
- from app.models.platform_api_key import PlatformApiKey
- from app.schemas.platform_stats import (
- StatsResponse,
- TrendItem,
- ModelDistItem,
- CallLogResponse
- )
# Module-level logger, named after this module so handlers/levels can be
# configured per module through the standard logging hierarchy.
logger: logging.Logger = logging.getLogger(__name__)
class ApiCallLogService:
    """Service for API call logs: recording calls, billing statistics and log queries.

    All database work goes through the SQLAlchemy session passed to the
    constructor.
    """

    def __init__(self, db: Session):
        self.db = db

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    def _key_type_filter(self, user_id: str, key_type: Optional[str]) -> list:
        """Return extra WHERE criteria limiting logs to the user's keys of ``key_type``.

        Returns an empty list when ``key_type`` is falsy, so the result can be
        unpacked into ``.filter(*criteria)`` unconditionally.
        """
        if not key_type:
            return []
        key_ids = self.db.query(PlatformApiKey.id).filter(
            PlatformApiKey.user_id == user_id,
            PlatformApiKey.key_type == key_type
        )
        return [ApiCallLog.api_key_id.in_(key_ids)]

    @staticmethod
    def _to_decimal(value) -> Optional[Decimal]:
        """Convert a numeric DB value to ``Decimal`` via ``str``; ``None`` stays ``None``.

        Going through ``str`` avoids binary-float artifacts such as
        ``Decimal(0.1) == Decimal('0.1000000000000000055...')``.
        """
        return None if value is None else Decimal(str(value))

    @staticmethod
    def _page_window(page: int, page_size: int) -> Tuple[int, int]:
        """Translate a 1-based ``(page, page_size)`` pair into ``(offset, limit)``.

        Both inputs are clamped to at least 1 so a bad request can never
        produce a negative OFFSET/LIMIT, which databases reject.
        """
        page = max(page, 1)
        page_size = max(page_size, 1)
        return (page - 1) * page_size, page_size

    # ------------------------------------------------------------------
    # Recording
    # ------------------------------------------------------------------
    def create_log(
        self,
        user_id: str,
        api_key_id: int,
        model_id: Optional[int],
        model_name: str,
        is_local: bool,
        input_tokens: int,
        output_tokens: int,
        bill: Decimal,
        status: str = "success",
        error_message: Optional[str] = None,
        request_ip: Optional[str] = None
    ) -> ApiCallLog:
        """Persist one API call log row and return it.

        Requirement 10.3: record the details of every API call.
        Requirement 10.4: record call time, model, token usage and cost.

        Args:
            user_id: ID of the calling user.
            api_key_id: ID of the platform API key used for the call.
            model_id: ID of the model, if known.
            model_name: Name of the model.
            is_local: Whether the model is a local model.
            input_tokens: Number of input tokens.
            output_tokens: Number of output tokens.
            bill: Amount charged for the call.
            status: Call status.
            error_message: Error message, if any.
            request_ip: IP address of the request, if known.

        Returns:
            ApiCallLog: the committed row, refreshed from the database.

        Raises:
            Whatever ``Session.commit`` raises (e.g. ``IntegrityError``); the
            session is rolled back before the exception propagates.
        """
        log = ApiCallLog(
            user_id=user_id,
            api_key_id=api_key_id,
            model_id=model_id,
            model_name=model_name,
            is_local=is_local,
            input_tokens=input_tokens,
            output_tokens=output_tokens,
            bill=bill,
            status=status,
            error_message=error_message,
            request_ip=request_ip
        )
        self.db.add(log)
        try:
            self.db.commit()
        except Exception:
            # A failed commit leaves the session in an invalid transaction
            # state; roll back so the caller can keep using the session, then
            # propagate the original error unchanged.
            self.db.rollback()
            logger.exception("Failed to persist API call log for user %s", user_id)
            raise
        self.db.refresh(log)
        return log

    # ------------------------------------------------------------------
    # Statistics
    # ------------------------------------------------------------------
    def get_user_stats(
        self,
        user_id: str,
        trend_days: int = 7,
        key_type: Optional[str] = None
    ) -> StatsResponse:
        """Return call/cost statistics for a user.

        Requirement 11.1: call counts for today, this month, and all time.
        Requirement 11.2: cost for today and this month.
        Requirement 11.3: daily call trend.
        Requirement 11.4: per-model call share.

        Args:
            user_id: ID of the user.
            trend_days: Number of days (ending today, inclusive) in the trend
                series; values <= 0 yield an empty trend.
            key_type: Optional key type filter ("public" or "local").

        Returns:
            StatsResponse: aggregated statistics.
        """
        today = date.today()
        month_start = today.replace(day=1)

        base_filter = [ApiCallLog.user_id == user_id]
        base_filter.extend(self._key_type_filter(user_id, key_type))
        call_date = func.date(ApiCallLog.created_at)

        # Today: count and cost of calls whose calendar date is today.
        today_stats = self.db.query(
            func.count(ApiCallLog.id).label('count'),
            func.coalesce(func.sum(ApiCallLog.bill), 0).label('cost')
        ).filter(
            *base_filter,
            call_date == today
        ).first()

        # This month: calls on or after the first day of the current month.
        month_stats = self.db.query(
            func.count(ApiCallLog.id).label('count'),
            func.coalesce(func.sum(ApiCallLog.bill), 0).label('cost')
        ).filter(
            *base_filter,
            call_date >= month_start
        ).first()

        # All time: count only.
        total_stats = self.db.query(
            func.count(ApiCallLog.id).label('count')
        ).filter(
            *base_filter
        ).first()

        trend_start = today - timedelta(days=trend_days - 1)
        trend_data = self._get_trend_data(user_id, trend_start, today, key_type)
        model_distribution = self._get_model_distribution(user_id, key_type)

        return StatsResponse(
            today_calls=today_stats.count or 0,
            month_calls=month_stats.count or 0,
            total_calls=total_stats.count or 0,
            today_cost=Decimal(str(today_stats.cost or 0)),
            month_cost=Decimal(str(month_stats.cost or 0)),
            trend_data=trend_data,
            model_distribution=model_distribution
        )

    def _get_trend_data(
        self,
        user_id: str,
        start_date: date,
        end_date: date,
        key_type: Optional[str] = None
    ) -> List[TrendItem]:
        """Return one TrendItem per day in ``[start_date, end_date]``.

        Days with no calls are included with count 0 and cost 0, so the series
        has no gaps. Returns an empty list if ``start_date > end_date``.
        """
        call_date = func.date(ApiCallLog.created_at)
        filters = [
            ApiCallLog.user_id == user_id,
            call_date >= start_date,
            call_date <= end_date,
            *self._key_type_filter(user_id, key_type),
        ]

        results = self.db.query(
            call_date.label('date'),
            func.count(ApiCallLog.id).label('count'),
            func.coalesce(func.sum(ApiCallLog.bill), 0).label('cost')
        ).filter(
            *filters
        ).group_by(
            call_date
        ).order_by(
            call_date
        ).all()

        # Key by str(): depending on the backend, func.date may come back as
        # a date object or as a 'YYYY-MM-DD' string; str() normalises both.
        data_map = {str(r.date): (r.count, r.cost) for r in results}

        trend_data = []
        for offset in range((end_date - start_date).days + 1):
            date_str = str(start_date + timedelta(days=offset))
            count, cost = data_map.get(date_str, (0, 0))
            trend_data.append(TrendItem(
                date=date_str,
                count=count,
                cost=Decimal(str(cost))
            ))
        return trend_data

    def _get_model_distribution(
        self,
        user_id: str,
        key_type: Optional[str] = None,
        limit: int = 10
    ) -> List[ModelDistItem]:
        """Return the ``limit`` most-called models with their share of all calls.

        ``percentage`` is each model's share of ALL of the user's calls (not
        just of the returned top models), so shares stay truthful when the
        user has used more than ``limit`` models.
        """
        filters = [ApiCallLog.user_id == user_id]
        filters.extend(self._key_type_filter(user_id, key_type))

        total = self.db.query(
            func.count(ApiCallLog.id)
        ).filter(
            *filters
        ).scalar() or 0
        if total == 0:
            return []

        results = self.db.query(
            ApiCallLog.model_name,
            func.count(ApiCallLog.id).label('count')
        ).filter(
            *filters
        ).group_by(
            ApiCallLog.model_name
        ).order_by(
            desc(func.count(ApiCallLog.id))
        ).limit(limit).all()

        return [
            ModelDistItem(
                model_name=r.model_name,
                count=r.count,
                percentage=round(r.count / total * 100, 2)
            )
            for r in results
        ]

    # ------------------------------------------------------------------
    # Log listing
    # ------------------------------------------------------------------
    def _api_key_prefix_map(self, api_key_ids: list) -> dict:
        """Map PlatformApiKey id -> ``api_key_prefix`` for the given ids (single query)."""
        if not api_key_ids:
            return {}
        api_keys = self.db.query(PlatformApiKey).filter(
            PlatformApiKey.id.in_(api_key_ids)
        ).all()
        return {k.id: k.api_key_prefix for k in api_keys}

    def _active_price_map(self, model_ids: list) -> dict:
        """Map ModelNew id -> its first price row with ``is_active`` set.

        Models without an active price are omitted.
        """
        if not model_ids:
            return {}
        # Imported lazily as in the original code, presumably to avoid a
        # circular import at module load time — TODO confirm.
        from app.models.model import ModelNew
        from sqlalchemy.orm import selectinload

        # selectinload fetches all prices in one extra query instead of one per model.
        models = self.db.query(ModelNew).options(
            selectinload(ModelNew.prices)
        ).filter(ModelNew.id.in_(model_ids)).all()

        price_map = {}
        for m in models:
            active_price = next((p for p in (m.prices or []) if p.is_active), None)
            if active_price is not None:
                price_map[m.id] = active_price
        return price_map

    def get_call_logs(
        self,
        user_id: str,
        start_date: Optional[date] = None,
        end_date: Optional[date] = None,
        model_id: Optional[int] = None,
        api_key_id: Optional[int] = None,
        key_type: Optional[str] = None,
        page: int = 1,
        page_size: int = 20
    ) -> Tuple[List[CallLogResponse], int]:
        """Return one page of the user's call logs, newest first, plus the total count.

        Requirement 12.1: filter by date range, model and API key.
        Requirement 12.2: show call time, model name, token usage, cost and API key.
        Requirement 12.4: paginated query.

        Args:
            user_id: ID of the user.
            start_date: Inclusive lower bound on the call date.
            end_date: Inclusive upper bound on the call date.
            model_id: Only logs for this model (ignored when falsy).
            api_key_id: Only logs for this API key (ignored when falsy).
            key_type: Optional key type filter ("public" or "local").
            page: 1-based page number; values < 1 are treated as 1.
            page_size: Rows per page; values < 1 are treated as 1.

        Returns:
            Tuple[List[CallLogResponse], int]: (logs on this page, total matching logs).
        """
        filters = [ApiCallLog.user_id == user_id]
        if start_date:
            filters.append(func.date(ApiCallLog.created_at) >= start_date)
        if end_date:
            filters.append(func.date(ApiCallLog.created_at) <= end_date)
        if model_id:
            filters.append(ApiCallLog.model_id == model_id)
        if api_key_id:
            filters.append(ApiCallLog.api_key_id == api_key_id)
        filters.extend(self._key_type_filter(user_id, key_type))

        query = self.db.query(ApiCallLog).filter(*filters)
        total = query.count()

        offset, limit = self._page_window(page, page_size)
        # id is a tie-breaker: without it, rows sharing a created_at value can
        # be skipped or repeated across pages.
        logs = query.order_by(
            desc(ApiCallLog.created_at), desc(ApiCallLog.id)
        ).offset(offset).limit(limit).all()

        api_key_map = self._api_key_prefix_map(
            list({log.api_key_id for log in logs if log.api_key_id})
        )
        price_map = self._active_price_map(
            list({log.model_id for log in logs if log.model_id})
        )

        result = []
        for log in logs:
            price_info = price_map.get(log.model_id) if log.model_id else None
            result.append(CallLogResponse(
                id=log.id,
                model_name=log.model_name,
                is_local=log.is_local,
                input_tokens=log.input_tokens,
                output_tokens=log.output_tokens,
                bill=Decimal(str(log.bill)),
                status=log.status,
                api_key_prefix=api_key_map.get(log.api_key_id, "N/A"),
                created_at=log.created_at,
                input_price=(
                    self._to_decimal(price_info.input_price_discounted)
                    if price_info is not None else None
                ),
                output_price=(
                    self._to_decimal(price_info.output_price_discounted)
                    if price_info is not None else None
                )
            ))

        return result, total
|