model_service.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. """
  2. 模型服务层(新表结构)
  3. """
  4. from typing import Optional, List
  5. from sqlalchemy import or_, case
  6. from sqlalchemy.orm import Session, selectinload
  7. from app.models.model import ModelNew, ModelPriceNew, ModelCategory
  8. from app.schemas.model_schema import (
  9. ModelListResponse,
  10. ModelPriceNewResponse,
  11. PaginatedResponse,
  12. KeywordsResponse,
  13. ModelPricingResponse,
  14. )
  15. CATEGORY_NAME_TO_VALUE = {
  16. # 前端 ModelSquare 分类 Tab 名称 → 数据库 categories 枚举值
  17. # 语言模型包含 LLM(0) 和 MULTIMODAL(1),多模态模型也可用于对话
  18. "语言模型": [0, 1],
  19. "多模态模型": [1],
  20. "TTS模型": [2],
  21. "STT模型": [3],
  22. "生图模型": [4],
  23. "生视频模型": [5],
  24. "图像编辑": [6],
  25. "Embedding": [7],
  26. "Rerank": [8],
  27. # 兼容旧的短名称
  28. "文本": [0],
  29. "图像": [1],
  30. "视频": [5],
  31. "语音": [2],
  32. }
  33. class ModelService:
  34. def __init__(self, db: Session):
  35. self.db = db
  36. def _build_list_response(self, item: ModelNew) -> ModelListResponse:
  37. prices = [
  38. ModelPriceNewResponse(
  39. id=p.id,
  40. label=p.label,
  41. tier_min=p.tier_min,
  42. tier_max=p.tier_max,
  43. tier_unit=p.tier_unit,
  44. input_price_original=p.input_price_original,
  45. output_price_original=p.output_price_original,
  46. discount_rate=p.discount_rate,
  47. discount_label=p.discount_label,
  48. input_price_discounted=p.input_price_discounted,
  49. output_price_discounted=p.output_price_discounted,
  50. currency=p.currency,
  51. unit=p.unit,
  52. display_multiplier=p.display_multiplier,
  53. )
  54. for p in (item.prices or [])
  55. ]
  56. return ModelListResponse(
  57. id=item.id,
  58. model_code=item.model_code,
  59. display_name=item.display_name,
  60. img=item.img,
  61. tag1=item.tag1,
  62. tag2=item.tag2,
  63. description=item.description,
  64. custom_description=item.custom_description,
  65. keywords=item.keywords,
  66. display_tags=item.display_tags,
  67. categories=item.categories or [],
  68. category=item.categories[0] if item.categories else 0,
  69. supplier=item.supplier,
  70. is_featured=item.is_featured,
  71. is_search=item.is_search,
  72. is_thinking=item.is_thinking,
  73. is_api_enabled=item.is_api_enabled,
  74. is_show_enabled=item.is_show_enabled,
  75. created_at=item.created_at,
  76. updated_at=item.updated_at,
  77. # 前端兼容字段
  78. title=item.model_code,
  79. name=item.display_name or item.model_code,
  80. keyword=item.keywords,
  81. prices=prices,
  82. )
  83. def get_models(
  84. self,
  85. page: int = 1,
  86. page_size: int = 20,
  87. keyword: Optional[str] = None,
  88. category: Optional[str] = None,
  89. supplier: Optional[str] = None,
  90. group_name: Optional[str] = None,
  91. filter_keyword: Optional[str] = None,
  92. filter_tag: Optional[str] = None,
  93. is_api_enabled: Optional[bool] = None,
  94. ) -> PaginatedResponse[ModelListResponse]:
  95. query = self.db.query(ModelNew).options(
  96. selectinload(ModelNew.prices)
  97. ).filter(ModelNew.is_show_enabled == True)
  98. if keyword and keyword.strip():
  99. s = f"%{keyword}%"
  100. query = query.filter(
  101. or_(
  102. ModelNew.model_code.ilike(s),
  103. ModelNew.display_name.ilike(s),
  104. ModelNew.description.ilike(s),
  105. ModelNew.custom_description.ilike(s),
  106. ModelNew.keywords.ilike(s),
  107. )
  108. )
  109. if category and category.strip():
  110. vals = CATEGORY_NAME_TO_VALUE.get(category.strip())
  111. if vals is not None:
  112. from sqlalchemy import cast
  113. from sqlalchemy.dialects.postgresql import ARRAY, INTEGER
  114. if len(vals) == 1:
  115. query = query.filter(
  116. ModelNew.categories.contains(cast([vals[0]], ARRAY(INTEGER)))
  117. )
  118. else:
  119. # 多个分类值用 OR 连接(如语言模型包含 LLM 和 MULTIMODAL)
  120. query = query.filter(
  121. or_(*[
  122. ModelNew.categories.contains(cast([v], ARRAY(INTEGER)))
  123. for v in vals
  124. ])
  125. )
  126. if supplier and supplier.strip():
  127. query = query.filter(ModelNew.supplier == supplier.strip())
  128. if group_name and group_name.strip():
  129. query = query.filter(ModelNew.group_name == group_name.strip())
  130. if filter_keyword and filter_keyword.strip():
  131. query = query.filter(ModelNew.keywords == filter_keyword.strip())
  132. if filter_tag and filter_tag.strip():
  133. tag = filter_tag.strip()
  134. query = query.filter(or_(ModelNew.tag1 == tag, ModelNew.tag2 == tag))
  135. if is_api_enabled is not None:
  136. query = query.filter(ModelNew.is_api_enabled == is_api_enabled)
  137. query = query.order_by(
  138. case((ModelNew.model_code == "qwen3-max", 0), else_=1),
  139. ModelNew.is_api_enabled.desc(),
  140. ModelNew.created_at.desc(),
  141. )
  142. total = query.count()
  143. items = query.offset((page - 1) * page_size).limit(page_size).all()
  144. return PaginatedResponse(
  145. total=total,
  146. page=page,
  147. page_size=page_size,
  148. items=[self._build_list_response(i) for i in items],
  149. )
  150. def get_model_by_id(self, model_id: int) -> ModelListResponse:
  151. from fastapi import HTTPException
  152. model = self.db.query(ModelNew).options(
  153. selectinload(ModelNew.prices)
  154. ).filter(ModelNew.id == model_id).first()
  155. if not model:
  156. raise HTTPException(status_code=404, detail="Model not found")
  157. return self._build_list_response(model)
  158. def get_keywords(self) -> KeywordsResponse:
  159. rows = self.db.query(ModelNew.keywords).filter(
  160. ModelNew.keywords.isnot(None),
  161. ModelNew.keywords != "",
  162. ModelNew.is_show_enabled == True,
  163. ).distinct().order_by(ModelNew.keywords).all()
  164. return KeywordsResponse(keywords=[r[0] for r in rows if r[0]])
  165. def get_group_names(self) -> List[str]:
  166. """获取所有模型分组名称(去重排序)"""
  167. from sqlalchemy import distinct
  168. rows = self.db.query(distinct(ModelNew.group_name)).filter(
  169. ModelNew.group_name.isnot(None),
  170. ModelNew.group_name != "",
  171. ModelNew.is_show_enabled == True,
  172. ModelNew.is_local == False,
  173. ).order_by(ModelNew.group_name).all()
  174. return [r[0] for r in rows if r[0]]
  175. def get_featured_models(self, limit: int = 3) -> List[ModelListResponse]:
  176. items = self.db.query(ModelNew).options(
  177. selectinload(ModelNew.prices)
  178. ).filter(
  179. ModelNew.is_featured == True,
  180. ModelNew.is_show_enabled == True,
  181. ).limit(limit).all()
  182. return [self._build_list_response(i) for i in items]
  183. def get_pricing(self, model_code: str) -> ModelPricingResponse:
  184. """获取模型定价信息,支持 source_keys / normalized_keys 查找"""
  185. import re
  186. def normalize(v: str) -> str:
  187. v = v.strip()
  188. v = re.sub(r"[\u200B-\u200D\uFEFF]", "", v)
  189. v = re.sub(r"\s*-\s*", "-", v)
  190. return re.sub(r"\s+", " ", v)
  191. normalized = normalize(model_code)
  192. # 先精确匹配 model_code
  193. model = self.db.query(ModelNew).filter(
  194. ModelNew.model_code == model_code
  195. ).first()
  196. # 再查 source_keys / normalized_keys 数组
  197. if not model:
  198. from sqlalchemy import func, cast
  199. from sqlalchemy.dialects.postgresql import ARRAY, TEXT
  200. model = self.db.query(ModelNew).filter(
  201. or_(
  202. ModelNew.source_keys.any(model_code),
  203. ModelNew.normalized_keys.any(normalized),
  204. )
  205. ).first()
  206. prices = []
  207. if model:
  208. price_rows = self.db.query(ModelPriceNew).filter(
  209. ModelPriceNew.model_code == model.model_code,
  210. ModelPriceNew.is_active == True,
  211. ).order_by(ModelPriceNew.label).all()
  212. prices = [
  213. ModelPriceNewResponse(
  214. id=p.id,
  215. label=p.label,
  216. tier_min=p.tier_min,
  217. tier_max=p.tier_max,
  218. tier_unit=p.tier_unit,
  219. input_price_original=p.input_price_original,
  220. output_price_original=p.output_price_original,
  221. discount_rate=p.discount_rate,
  222. discount_label=p.discount_label,
  223. input_price_discounted=p.input_price_discounted,
  224. output_price_discounted=p.output_price_discounted,
  225. currency=p.currency,
  226. unit=p.unit,
  227. display_multiplier=p.display_multiplier,
  228. )
  229. for p in price_rows
  230. ]
  231. return ModelPricingResponse(
  232. model_code=model.model_code if model else model_code,
  233. display_name=model.display_name if model else None,
  234. description=model.description if model else None,
  235. custom_description=model.custom_description if model else None,
  236. features=model.features if model else None,
  237. rate_limits=model.rate_limits if model else None,
  238. tool_call_prices=model.tool_call_prices if model else None,
  239. display_tags=model.display_tags if model else None,
  240. input_modalities=model.input_modalities if model else None,
  241. output_modalities=model.output_modalities if model else None,
  242. categories=model.categories if model else [],
  243. source_keys=model.source_keys if model else None,
  244. normalized_keys=model.normalized_keys if model else None,
  245. is_api_enabled=model.is_api_enabled if model else False,
  246. prices=prices,
  247. )