admin_model_service.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. """
  2. 管理员模型管理服务(新表结构)
  3. """
  4. from datetime import datetime
  5. from decimal import Decimal
  6. from typing import Optional, List, Tuple
  7. from sqlalchemy.orm import Session
  8. from sqlalchemy import or_, desc
  9. from app.models.model import ModelNew, ModelPriceNew
  10. from app.schemas.admin_schema import (
  11. ModelListParams, ModelListItem, ModelDetailResponse,
  12. ModelCreateRequest, ModelUpdateRequest, ModelPriceRequest
  13. )
  14. class AdminModelService:
  15. def __init__(self, db: Session):
  16. self.db = db
  17. def list_models(self, params: ModelListParams) -> Tuple[List[ModelListItem], int]:
  18. query = self.db.query(ModelNew).filter(ModelNew.is_local == False)
  19. if params.keyword:
  20. kw = f"%{params.keyword}%"
  21. query = query.filter(or_(
  22. ModelNew.model_code.ilike(kw),
  23. ModelNew.display_name.ilike(kw),
  24. ModelNew.supplier.ilike(kw),
  25. ))
  26. if params.category is not None:
  27. query = query.filter(ModelNew.categories.any(params.category))
  28. if params.supplier:
  29. query = query.filter(ModelNew.supplier == params.supplier)
  30. if params.is_show_enabled is not None:
  31. query = query.filter(ModelNew.is_show_enabled == params.is_show_enabled)
  32. if params.is_api_enabled is not None:
  33. query = query.filter(ModelNew.is_api_enabled == params.is_api_enabled)
  34. total = query.count()
  35. from sqlalchemy.orm import selectinload
  36. models = query.order_by(desc(ModelNew.id)).options(
  37. selectinload(ModelNew.prices)
  38. ).offset((params.page - 1) * params.size).limit(params.size).all()
  39. result = [
  40. ModelListItem(
  41. id=m.id,
  42. model_code=m.model_code,
  43. display_name=m.display_name,
  44. img=m.img or "",
  45. category=m.categories[0] if m.categories else 0,
  46. categories=m.categories or [],
  47. supplier=m.supplier,
  48. description=(m.custom_description or m.description or "")[:80] or None,
  49. # 取第一条有效价格记录作摘要
  50. price_label=m.prices[0].label if m.prices else None,
  51. input_price_original=m.prices[0].input_price_original if m.prices else None,
  52. output_price_original=m.prices[0].output_price_original if m.prices else None,
  53. input_price_discounted=m.prices[0].input_price_discounted if m.prices else None,
  54. output_price_discounted=m.prices[0].output_price_discounted if m.prices else None,
  55. discount_rate=m.prices[0].discount_rate if m.prices else None,
  56. discount_label=m.prices[0].discount_label if m.prices else None,
  57. price_unit=m.prices[0].unit if m.prices else None,
  58. is_show_enabled=m.is_show_enabled,
  59. is_api_enabled=m.is_api_enabled,
  60. is_featured=m.is_featured,
  61. is_search=m.is_search,
  62. is_thinking=m.is_thinking,
  63. )
  64. for m in models
  65. ]
  66. return result, total
  67. def get_model_detail(self, model_id: int) -> Optional[ModelDetailResponse]:
  68. model = self.db.query(ModelNew).filter(ModelNew.id == model_id).first()
  69. if not model:
  70. return None
  71. return ModelDetailResponse(
  72. id=model.id,
  73. model_code=model.model_code,
  74. display_name=model.display_name,
  75. img=model.img,
  76. category=model.categories[0] if model.categories else 0,
  77. supplier=model.supplier,
  78. description=model.description,
  79. custom_description=model.custom_description,
  80. tag1=model.tag1,
  81. tag2=model.tag2,
  82. keywords=model.keywords,
  83. is_featured=model.is_featured,
  84. is_search=model.is_search,
  85. is_thinking=model.is_thinking,
  86. is_show_enabled=model.is_show_enabled,
  87. is_api_enabled=model.is_api_enabled,
  88. source_keys=model.source_keys,
  89. normalized_keys=model.normalized_keys,
  90. created_at=model.created_at,
  91. updated_at=model.updated_at,
  92. )
  93. def create_model(self, data: ModelCreateRequest) -> int:
  94. existing = self.db.query(ModelNew).filter(
  95. ModelNew.model_code == data.model_code
  96. ).first()
  97. if existing:
  98. raise ValueError("MODEL_CODE_EXISTS")
  99. model = ModelNew(
  100. model_code=data.model_code,
  101. display_name=data.display_name,
  102. img=data.img,
  103. categories=[data.category] if data.category is not None else [],
  104. supplier=data.supplier,
  105. description=data.description,
  106. custom_description=data.custom_description,
  107. tag1=data.tag1,
  108. tag2=data.tag2,
  109. keywords=data.keywords,
  110. is_featured=data.is_featured,
  111. is_search=data.is_search,
  112. is_thinking=data.is_thinking,
  113. is_show_enabled=data.is_show_enabled,
  114. is_api_enabled=data.is_api_enabled,
  115. )
  116. self.db.add(model)
  117. self.db.commit()
  118. self.db.refresh(model)
  119. return model.id
  120. def update_model(self, model_id: int, data: ModelUpdateRequest) -> bool:
  121. model = self.db.query(ModelNew).filter(ModelNew.id == model_id).first()
  122. if not model:
  123. raise ValueError("MODEL_NOT_FOUND")
  124. for field, value in data.model_dump(exclude_unset=True).items():
  125. setattr(model, field, value)
  126. self.db.commit()
  127. return True
  128. def update_model_price(self, model_id: int, data: ModelPriceRequest) -> bool:
  129. """Upsert 价格:旧记录 is_active=false,插入新记录"""
  130. model = self.db.query(ModelNew).filter(ModelNew.id == model_id).first()
  131. if not model:
  132. raise ValueError("MODEL_NOT_FOUND")
  133. now = datetime.utcnow()
  134. label = data.label or "default"
  135. # 旧记录失效
  136. self.db.query(ModelPriceNew).filter(
  137. ModelPriceNew.model_code == model.model_code,
  138. ModelPriceNew.label == label,
  139. ModelPriceNew.is_active == True,
  140. ).update({"is_active": False})
  141. input_orig = data.input_price or Decimal("0")
  142. output_orig = data.output_price or Decimal("0")
  143. rate = data.discount_rate if data.discount_rate is not None else Decimal("1")
  144. # discount_rate 范围 0~1,1=无折扣
  145. new_price = ModelPriceNew(
  146. model_code=model.model_code,
  147. label=label,
  148. tier_min=data.tier_min,
  149. tier_max=data.tier_max,
  150. tier_unit=data.tier_unit,
  151. input_price_original=input_orig,
  152. output_price_original=output_orig,
  153. discount_rate=rate,
  154. input_price_discounted=input_orig * rate,
  155. output_price_discounted=output_orig * rate,
  156. currency=data.currency or "CNY",
  157. unit=data.unit or "元/每百万tokens",
  158. display_multiplier=data.display_multiplier or 1000000,
  159. source_url=None,
  160. crawled_at=now,
  161. is_active=True,
  162. )
  163. self.db.add(new_price)
  164. self.db.commit()
  165. return True
  166. def update_model_status(self, model_id: int, field: str, value: bool) -> bool:
  167. model = self.db.query(ModelNew).filter(ModelNew.id == model_id).first()
  168. if not model:
  169. raise ValueError("MODEL_NOT_FOUND")
  170. if field not in ("is_show_enabled", "is_api_enabled", "is_featured"):
  171. raise ValueError("INVALID_FIELD")
  172. setattr(model, field, value)
  173. self.db.commit()
  174. return True