| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232 |
- """
- 模型数据ORM定义(新表结构)
- 对应 models_new、model_price_new、crawler_sync_log 三张表
- """
- from enum import IntEnum
- from sqlalchemy import Boolean, Column, Integer, String, Text, DateTime, Index, Numeric, ForeignKey
- from sqlalchemy.dialects.postgresql import ARRAY, JSONB
- from sqlalchemy.sql import func
- from sqlalchemy.orm import relationship
- from app.database import Base
class ModelCategory(IntEnum):
    """Integer codes identifying the kinds of model a record can be tagged as.

    Being an ``IntEnum``, every member compares equal to its plain integer
    value, so members can be converted with ``int()`` and compared against
    stored integers directly.
    """

    # Text-oriented models
    LLM = 0          # large language model
    MULTIMODAL = 1   # multimodal model

    # Speech
    TTS = 2          # text-to-speech
    STT = 3          # speech-to-text

    # Visual generation / editing
    IMAGE_GEN = 4    # image generation
    VIDEO_GEN = 5    # video generation
    IMAGE_EDIT = 6   # image editing

    # Retrieval helpers
    EMBEDDING = 7    # embedding model
    RERANK = 8       # rerank model
def infer_categories(input_modalities: list, output_modalities: list, model_code: str = "", display_tags: list = None) -> list:
    """Infer the category codes of a model from its metadata.

    All matching is case-insensitive. The rules are additive (a model may end
    up in several categories) and are evaluated in this order:

    1. VIDEO_GEN  - output contains ``video`` or the code contains one of
       ``v2v`` / ``i2v`` / ``t2v`` / ``s2v``.
    2. IMAGE_GEN / IMAGE_EDIT - output contains ``image`` but not ``video``
       and the code does not contain ``mt-image`` (image translation);
       text input gives IMAGE_GEN, image input gives IMAGE_EDIT.
    3. TTS        - audio output with text input, no video output, not omni.
    4. STT        - audio input with text output, no video/image output,
       not omni.
    5. EMBEDDING  - code contains ``embed`` or a tag equals ``向量`` /
       ``embedding``.
    6. RERANK     - code contains ``rerank`` or a tag equals ``重排`` /
       ``rerank``.
    7. MULTIMODAL - image or video input with text output, unless an
       image/video generation category was already assigned.
    8. LLM        - fallback when nothing else matched.

    "Omni" means the input has audio together with image or video; such
    conversational models are deliberately kept out of TTS/STT.

    Args:
        input_modalities: Input modality names (e.g. ``"Text"``); may be None.
        output_modalities: Output modality names; may be None.
        model_code: Model identifier searched for keywords; may be None/empty.
        display_tags: Display tags, compared by exact (lower-cased) match;
            may be None.

    Returns:
        Sorted list of plain ``int`` category codes (``ModelCategory`` values);
        never empty.
    """

    def _normalize(values) -> set:
        # Entries that are not strings (e.g. None coming from a NULL array
        # element) are skipped instead of crashing on ``.lower()``;
        # surrounding whitespace is ignored.
        return {v.strip().lower() for v in (values or []) if isinstance(v, str)}

    inp = _normalize(input_modalities)
    out = _normalize(output_modalities)
    tags = _normalize(display_tags)
    code = (model_code or "").lower()

    cats = set()

    # 1. Video generation: video output, or a video keyword in the model code.
    if 'video' in out or any(k in code for k in ('v2v', 'i2v', 't2v', 's2v')):
        cats.add(ModelCategory.VIDEO_GEN)

    # 2. Image translation models have image input and output but are not
    #    image-to-image generators, so they skip the image categories.
    is_image_translation = 'mt-image' in code

    # 3. Image generation / editing: image output, no video output,
    #    not an image translation model.
    if 'image' in out and 'video' not in out and not is_image_translation:
        if 'text' in inp:
            cats.add(ModelCategory.IMAGE_GEN)
        if 'image' in inp:
            cats.add(ModelCategory.IMAGE_EDIT)

    # Omni conversational models accept audio together with image or video;
    # they must not be classified as plain TTS/STT.
    is_omni = 'audio' in inp and ('image' in inp or 'video' in inp)

    # 4. TTS: text in, audio out.
    if 'audio' in out and 'text' in inp and 'video' not in out and not is_omni:
        cats.add(ModelCategory.TTS)

    # 5. STT: audio in, text out, with no visual output.
    if 'audio' in inp and 'text' in out and 'video' not in out and 'image' not in out and not is_omni:
        cats.add(ModelCategory.STT)

    # 6. Embedding: 'embed' also covers 'embedding' as a substring.
    if 'embed' in code or '向量' in tags or 'embedding' in tags:
        cats.add(ModelCategory.EMBEDDING)

    # 7. Rerank: 'rerank' also covers 'reranker' as a substring.
    if 'rerank' in code or '重排' in tags or 'rerank' in tags:
        cats.add(ModelCategory.RERANK)

    # 8. Multimodal understanding: visual input with text output, unless the
    #    model was already classified as an image/video generator.
    generation_cats = {
        ModelCategory.IMAGE_GEN, ModelCategory.IMAGE_EDIT, ModelCategory.VIDEO_GEN
    }
    if ('image' in inp or 'video' in inp) and 'text' in out and not cats & generation_cats:
        cats.add(ModelCategory.MULTIMODAL)

    # 9. Fallback: nothing matched, treat as a plain text LLM.
    if not cats:
        cats.add(ModelCategory.LLM)

    return sorted(int(c) for c in cats)
class ModelNew(Base):
    """ORM mapping for the ``aigcspace.models_new`` main model table."""

    __tablename__ = "models_new"

    id = Column(Integer, primary_key=True, autoincrement=True)
    model_code = Column(String(200), nullable=False, unique=True, comment="模型唯一标识")

    # ===== Crawler-populated fields =====
    description = Column(Text, nullable=True)
    display_tags = Column(ARRAY(Text), nullable=True)
    input_modalities = Column(ARRAY(Text), nullable=True)
    output_modalities = Column(ARRAY(Text), nullable=True)
    features = Column(JSONB, nullable=True)
    rate_limits = Column(JSONB, nullable=True)
    tool_call_prices = Column(JSONB, nullable=True)
    raw_prices = Column(JSONB, nullable=True)  # raw price structure (including discounts)
    source_url = Column(Text, nullable=True)
    crawled_at = Column(DateTime, nullable=True)

    # ===== Categories (array, a model may have several) =====
    # ``default=list`` is a callable, so every INSERT gets its own fresh empty
    # list rather than all rows sharing one mutable list object.
    categories = Column(ARRAY(Integer), nullable=False, default=list)

    # ===== Admin-configured fields =====
    display_name = Column(String(255), nullable=True)
    img = Column(Text, nullable=True)
    supplier = Column(String(100), nullable=False, default="Qwen")
    tag1 = Column(String(100), nullable=True)
    tag2 = Column(String(100), nullable=True)
    keywords = Column(Text, nullable=True)
    custom_description = Column(Text, nullable=True)
    is_featured = Column(Boolean, nullable=False, default=False)
    is_show_enabled = Column(Boolean, nullable=False, default=True)
    is_api_enabled = Column(Boolean, nullable=False, default=False)
    is_search = Column(Boolean, nullable=False, default=False)
    is_thinking = Column(Boolean, nullable=False, default=False)
    source_keys = Column(ARRAY(Text), nullable=True)
    normalized_keys = Column(ARRAY(Text), nullable=True)

    # ===== Crawler extension fields =====
    group_name = Column(String(100), nullable=True, comment="模型分组名称,来自爬虫 group_name 字段")
    encrypted_api_key = Column(Text, nullable=True, comment="加密后的 API Key,来自爬虫 api_key 字段")

    # ===== Local model fields =====
    is_local = Column(Boolean, nullable=False, default=False)
    user_id = Column(String(50), ForeignKey("aigcspace.users.id"), nullable=True)
    base_url = Column(String(500), nullable=True)
    local_api_key = Column(String(500), nullable=True)
    visibility = Column(String(20), nullable=True, default="user")

    # Timestamps are filled in by the database (``now()``); ``updated_at`` is
    # also refreshed on every UPDATE issued through the ORM.
    created_at = Column(DateTime, nullable=False, server_default=func.now())
    updated_at = Column(DateTime, nullable=False, server_default=func.now(), onupdate=func.now())

    # Read-only collection of the active price rows that share this model's
    # ``model_code``, ordered by label. There is no database-level foreign key
    # here; the join column is marked with ``foreign()`` in the expression.
    prices = relationship(
        "ModelPriceNew",
        primaryjoin="and_(ModelNew.model_code == foreign(ModelPriceNew.model_code), ModelPriceNew.is_active == True)",
        lazy="select",
        viewonly=True,
        order_by="ModelPriceNew.label"
    )

    __table_args__ = (
        # GIN index on the categories array column.
        Index("idx_models_new_categories", "categories", postgresql_using="gin"),
        Index("idx_models_new_supplier", "supplier"),
        Index("idx_models_new_is_show_enabled", "is_show_enabled"),
        Index("idx_models_new_is_api_enabled", "is_api_enabled"),
        Index("idx_models_new_is_local", "is_local"),
        Index("idx_models_new_crawled_at", "crawled_at"),
        {"schema": "aigcspace", "comment": "模型主表"},
    )

    def has_category(self, cat: ModelCategory) -> bool:
        """Return True if this model's ``categories`` contains ``cat``.

        A missing (``None``) or empty ``categories`` value yields False.
        """
        return int(cat) in (self.categories or [])

    def __repr__(self):
        """Return a short debugging representation with id and model_code."""
        return f"<ModelNew(id={self.id}, model_code='{self.model_code}')>"
class ModelPriceNew(Base):
    """ORM mapping for the ``aigcspace.model_price_new`` price table."""

    __tablename__ = "model_price_new"

    # --- Identity ---------------------------------------------------------
    id = Column(Integer, primary_key=True, autoincrement=True)
    model_code = Column(String(200), nullable=False)
    label = Column(String(200), nullable=False)

    # --- Pricing tier bounds (all optional) -------------------------------
    tier_min = Column(Numeric(20, 2), nullable=True)
    tier_max = Column(Numeric(20, 2), nullable=True)
    tier_unit = Column(String(50), nullable=True)

    # --- Prices before and after discount ---------------------------------
    input_price_original = Column(Numeric(20, 8), nullable=False, default=0)
    output_price_original = Column(Numeric(20, 8), nullable=False, default=0)
    discount_rate = Column(Numeric(5, 4), nullable=False, default=1)
    discount_label = Column(String(20), nullable=True)  # e.g. "1折", "5折"
    input_price_discounted = Column(Numeric(20, 8), nullable=False, default=0)
    output_price_discounted = Column(Numeric(20, 8), nullable=False, default=0)

    # --- Units and display ------------------------------------------------
    currency = Column(String(10), nullable=False, default="CNY")
    unit = Column(String(100), nullable=False)
    display_multiplier = Column(Integer, nullable=False, default=1)

    # --- Provenance and lifecycle -----------------------------------------
    source_url = Column(Text, nullable=True)
    crawled_at = Column(DateTime, nullable=False)
    is_active = Column(Boolean, nullable=False, default=True)
    created_at = Column(DateTime, nullable=False, server_default=func.now())
    updated_at = Column(DateTime, nullable=False, server_default=func.now(), onupdate=func.now())

    __table_args__ = (
        Index("idx_model_price_model_code", "model_code"),
        Index("idx_model_price_label", "label"),
        # Partial index restricted to rows where is_active is true.
        Index("idx_model_price_active", "model_code", "is_active", postgresql_where="is_active = true"),
        {"schema": "aigcspace", "comment": "价格表"},
    )

    def __repr__(self):
        """Return a short debugging representation with id, model_code and label."""
        template = "<ModelPriceNew(id={}, model_code='{}', label='{}')>"
        return template.format(self.id, self.model_code, self.label)
class CrawlerSyncLog(Base):
    """ORM mapping for ``aigcspace.crawler_sync_log``: crawler sync version records."""

    __tablename__ = "crawler_sync_log"

    id = Column(Integer, primary_key=True, autoincrement=True)
    crawler_version = Column(Integer, nullable=False)
    # Set by the database to now() when the row is inserted.
    synced_at = Column(DateTime, nullable=False, server_default=func.now())

    # Row counts recorded for the sync; both default to 0.
    model_count = Column(Integer, nullable=False, default=0)
    price_count = Column(Integer, nullable=False, default=0)

    # Outcome of the sync; defaults to "success", with an optional error text.
    status = Column(String(20), nullable=False, default="success")
    error_message = Column(Text, nullable=True)

    __table_args__ = (
        Index("idx_crawler_sync_log_version", "crawler_version"),
        {"schema": "aigcspace", "comment": "爬虫同步版本记录"},
    )

    def __repr__(self):
        """Return a short debugging representation with id, version and status."""
        template = "<CrawlerSyncLog(id={}, version={}, status='{}')>"
        return template.format(self.id, self.crawler_version, self.status)
|