model.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. """
  2. 模型数据ORM定义(新表结构)
  3. 对应 models_new、model_price_new、crawler_sync_log 三张表
  4. """
import re
from enum import IntEnum
from typing import Optional

from sqlalchemy import Boolean, Column, Integer, String, Text, DateTime, Index, Numeric, ForeignKey
from sqlalchemy.dialects.postgresql import ARRAY, JSONB
from sqlalchemy.sql import func
from sqlalchemy.orm import relationship

from app.database import Base
  11. class ModelCategory(IntEnum):
  12. """模型分类枚举"""
  13. LLM = 0
  14. MULTIMODAL = 1
  15. TTS = 2
  16. STT = 3
  17. IMAGE_GEN = 4
  18. VIDEO_GEN = 5
  19. IMAGE_EDIT = 6
  20. EMBEDDING = 7
  21. RERANK = 8
  22. def infer_categories(input_modalities: list, output_modalities: list, model_code: str = "", display_tags: list = None) -> list:
  23. """
  24. 根据 input/output_modalities、model_code、display_tags 推断模型分类数组。
  25. 规则优先级:
  26. 1. model_code 关键词(最可靠)
  27. 2. output_modalities 主输出类型
  28. 3. input_modalities 辅助判断
  29. 4. display_tags 辅助修正
  30. """
  31. inp = set(m.lower() for m in (input_modalities or []))
  32. out = set(m.lower() for m in (output_modalities or []))
  33. code = (model_code or "").lower()
  34. tags = set(t.lower() for t in (display_tags or []))
  35. cats = set()
  36. # ── 1. 视频生成(输出 Video,或 model_code 含视频关键词)──────────────
  37. if 'video' in out or 'v2v' in code or 'i2v' in code or 't2v' in code or 's2v' in code:
  38. cats.add(ModelCategory.VIDEO_GEN)
  39. # ── 2. 图像翻译(model_code 含 mt-image)
  40. # 这类模型 input/output 都是 Image,但不是图生图,跳过图像分类
  41. is_image_translation = 'mt-image' in code
  42. # ── 3. 图像生成 / 图像编辑(输出 Image,且不是视频模型,且不是图像翻译)
  43. if 'image' in out and 'video' not in out and not is_image_translation:
  44. if 'text' in inp:
  45. cats.add(ModelCategory.IMAGE_GEN)
  46. if 'image' in inp:
  47. cats.add(ModelCategory.IMAGE_EDIT)
  48. # ── 4. TTS(输出 Audio,输入 Text,且不是视频模型,且不是全模态对话模型)
  49. # 全模态对话模型特征:输入同时有 Text+Image+Audio(或 Video)
  50. is_omni = 'audio' in inp and ('image' in inp or 'video' in inp)
  51. if 'audio' in out and 'text' in inp and 'video' not in out and not is_omni:
  52. cats.add(ModelCategory.TTS)
  53. # ── 5. STT(输入 Audio,输出 Text,且不是视频模型,且不是全模态对话模型)
  54. if 'audio' in inp and 'text' in out and 'video' not in out and 'image' not in out and not is_omni:
  55. cats.add(ModelCategory.STT)
  56. # ── 6. Embedding(model_code 或 tags 含关键词)──────────────────────────
  57. if any(k in code for k in ('embed', 'embedding')) or '向量' in tags or 'embedding' in tags:
  58. cats.add(ModelCategory.EMBEDDING)
  59. # ── 7. Rerank(model_code 或 tags 含关键词)─────────────────────────────
  60. if any(k in code for k in ('rerank', 'reranker')) or '重排' in tags or 'rerank' in tags:
  61. cats.add(ModelCategory.RERANK)
  62. # ── 8. 多模态(输入有 Image 或 Video,输出有 Text,且不是图像生成/视频生成)
  63. if ('image' in inp or 'video' in inp) and 'text' in out and not cats.intersection({
  64. ModelCategory.IMAGE_GEN, ModelCategory.IMAGE_EDIT, ModelCategory.VIDEO_GEN
  65. }):
  66. cats.add(ModelCategory.MULTIMODAL)
  67. # ── 9. 兜底:纯文本 → LLM ────────────────────────────────────────────────
  68. if not cats:
  69. cats.add(ModelCategory.LLM)
  70. return sorted(int(c) for c in cats)
  71. class ModelNew(Base):
  72. """models_new 主表"""
  73. __tablename__ = "models_new"
  74. id = Column(Integer, primary_key=True, autoincrement=True)
  75. model_code = Column(String(200), nullable=False, unique=True, comment="模型唯一标识")
  76. # ===== 爬虫字段 =====
  77. description = Column(Text, nullable=True)
  78. display_tags = Column(ARRAY(Text), nullable=True)
  79. input_modalities = Column(ARRAY(Text), nullable=True)
  80. output_modalities = Column(ARRAY(Text), nullable=True)
  81. features = Column(JSONB, nullable=True)
  82. rate_limits = Column(JSONB, nullable=True)
  83. tool_call_prices = Column(JSONB, nullable=True)
  84. raw_prices = Column(JSONB, nullable=True) # 原始价格结构(含折扣)
  85. source_url = Column(Text, nullable=True)
  86. crawled_at = Column(DateTime, nullable=True)
  87. # ===== 分类(多分类数组) =====
  88. categories = Column(ARRAY(Integer), nullable=False, default=[])
  89. # ===== 管理员配置字段 =====
  90. display_name = Column(String(255), nullable=True)
  91. img = Column(Text, nullable=True)
  92. supplier = Column(String(100), nullable=False, default="Qwen")
  93. tag1 = Column(String(100), nullable=True)
  94. tag2 = Column(String(100), nullable=True)
  95. keywords = Column(Text, nullable=True)
  96. custom_description = Column(Text, nullable=True)
  97. is_featured = Column(Boolean, nullable=False, default=False)
  98. is_show_enabled = Column(Boolean, nullable=False, default=True)
  99. is_api_enabled = Column(Boolean, nullable=False, default=False)
  100. is_search = Column(Boolean, nullable=False, default=False)
  101. is_thinking = Column(Boolean, nullable=False, default=False)
  102. source_keys = Column(ARRAY(Text), nullable=True)
  103. normalized_keys = Column(ARRAY(Text), nullable=True)
  104. # ===== 爬虫扩展字段 =====
  105. group_name = Column(String(100), nullable=True, comment="模型分组名称,来自爬虫 group_name 字段")
  106. encrypted_api_key = Column(Text, nullable=True, comment="加密后的 API Key,来自爬虫 api_key 字段")
  107. # ===== 本地模型字段 =====
  108. is_local = Column(Boolean, nullable=False, default=False)
  109. user_id = Column(String(50), ForeignKey("aigcspace.users.id"), nullable=True)
  110. base_url = Column(String(500), nullable=True)
  111. local_api_key = Column(String(500), nullable=True)
  112. visibility = Column(String(20), nullable=True, default="user")
  113. created_at = Column(DateTime, nullable=False, server_default=func.now())
  114. updated_at = Column(DateTime, nullable=False, server_default=func.now(), onupdate=func.now())
  115. prices = relationship(
  116. "ModelPriceNew",
  117. primaryjoin="and_(ModelNew.model_code == foreign(ModelPriceNew.model_code), ModelPriceNew.is_active == True)",
  118. lazy="select",
  119. viewonly=True,
  120. order_by="ModelPriceNew.label"
  121. )
  122. __table_args__ = (
  123. Index("idx_models_new_categories", "categories", postgresql_using="gin"),
  124. Index("idx_models_new_supplier", "supplier"),
  125. Index("idx_models_new_is_show_enabled", "is_show_enabled"),
  126. Index("idx_models_new_is_api_enabled", "is_api_enabled"),
  127. Index("idx_models_new_is_local", "is_local"),
  128. Index("idx_models_new_crawled_at", "crawled_at"),
  129. {"schema": "aigcspace", "comment": "模型主表"},
  130. )
  131. def has_category(self, cat: ModelCategory) -> bool:
  132. """判断模型是否属于某个分类"""
  133. return int(cat) in (self.categories or [])
  134. def __repr__(self):
  135. return f"<ModelNew(id={self.id}, model_code='{self.model_code}')>"
  136. class ModelPriceNew(Base):
  137. """model_price_new 价格表"""
  138. __tablename__ = "model_price_new"
  139. id = Column(Integer, primary_key=True, autoincrement=True)
  140. model_code = Column(String(200), nullable=False)
  141. label = Column(String(200), nullable=False)
  142. tier_min = Column(Numeric(20, 2), nullable=True)
  143. tier_max = Column(Numeric(20, 2), nullable=True)
  144. tier_unit = Column(String(50), nullable=True)
  145. input_price_original = Column(Numeric(20, 8), nullable=False, default=0)
  146. output_price_original = Column(Numeric(20, 8), nullable=False, default=0)
  147. discount_rate = Column(Numeric(5, 4), nullable=False, default=1)
  148. discount_label = Column(String(20), nullable=True) # 如 "1折"、"5折"
  149. input_price_discounted = Column(Numeric(20, 8), nullable=False, default=0)
  150. output_price_discounted = Column(Numeric(20, 8), nullable=False, default=0)
  151. currency = Column(String(10), nullable=False, default="CNY")
  152. unit = Column(String(100), nullable=False)
  153. display_multiplier = Column(Integer, nullable=False, default=1)
  154. source_url = Column(Text, nullable=True)
  155. crawled_at = Column(DateTime, nullable=False)
  156. is_active = Column(Boolean, nullable=False, default=True)
  157. created_at = Column(DateTime, nullable=False, server_default=func.now())
  158. updated_at = Column(DateTime, nullable=False, server_default=func.now(), onupdate=func.now())
  159. __table_args__ = (
  160. Index("idx_model_price_model_code", "model_code"),
  161. Index("idx_model_price_label", "label"),
  162. Index("idx_model_price_active", "model_code", "is_active", postgresql_where="is_active = true"),
  163. {"schema": "aigcspace", "comment": "价格表"},
  164. )
  165. def __repr__(self):
  166. return f"<ModelPriceNew(id={self.id}, model_code='{self.model_code}', label='{self.label}')>"
  167. class CrawlerSyncLog(Base):
  168. """crawler_sync_log 爬虫同步版本记录"""
  169. __tablename__ = "crawler_sync_log"
  170. id = Column(Integer, primary_key=True, autoincrement=True)
  171. crawler_version = Column(Integer, nullable=False)
  172. synced_at = Column(DateTime, nullable=False, server_default=func.now())
  173. model_count = Column(Integer, nullable=False, default=0)
  174. price_count = Column(Integer, nullable=False, default=0)
  175. status = Column(String(20), nullable=False, default="success")
  176. error_message = Column(Text, nullable=True)
  177. __table_args__ = (
  178. Index("idx_crawler_sync_log_version", "crawler_version"),
  179. {"schema": "aigcspace", "comment": "爬虫同步版本记录"},
  180. )
  181. def __repr__(self):
  182. return f"<CrawlerSyncLog(id={self.id}, version={self.crawler_version}, status='{self.status}')>"