local_model.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. """
  2. Local model data transfer object (DTO) definitions
  3. """
  4. from datetime import datetime
  5. from typing import List, Optional
  6. from pydantic import BaseModel, ConfigDict, Field
  7. class LocalModelCreate(BaseModel):
  8. """创建本地模型请求"""
  9. name: str = Field(..., min_length=1, max_length=100)
  10. supplier: str = Field("Custom", min_length=1, max_length=100)
  11. base_url: str = Field(..., min_length=1, max_length=500)
  12. api_key: Optional[str] = Field(None, max_length=500)
  13. visibility: Optional[str] = Field('user')
  14. categories: Optional[List[int]] = Field([0], description="模型分类数组: 0=LLM,1=多模态,2=TTS,3=STT,4=图像生成,5=视频生成,6=图像编辑,7=Embedding,8=Rerank")
  15. class LocalModelUpdate(BaseModel):
  16. """更新本地模型请求"""
  17. name: Optional[str] = Field(None, min_length=1, max_length=100)
  18. supplier: Optional[str] = Field(None, min_length=1, max_length=100)
  19. base_url: Optional[str] = Field(None, min_length=1, max_length=500)
  20. api_key: Optional[str] = Field(None, max_length=500)
  21. visibility: Optional[str] = Field(None)
  22. categories: Optional[List[int]] = Field(None, description="模型分类数组")
  23. class LocalModelResponse(BaseModel):
  24. """本地模型响应"""
  25. id: int
  26. name: str
  27. supplier: str = Field("Custom")
  28. base_url: str = Field("")
  29. has_api_key: bool = Field(False)
  30. visibility: str = Field('user')
  31. categories: List[int] = Field([0], description="模型分类数组")
  32. category: int = Field(0, description="主分类(兼容字段,取 categories[0])")
  33. created_at: datetime
  34. updated_at: datetime
  35. model_config = ConfigDict(from_attributes=True)
  36. @classmethod
  37. def from_model(cls, model) -> "LocalModelResponse":
  38. cats = model.categories or [0]
  39. return cls(
  40. id=model.id,
  41. name=model.name,
  42. supplier=model.supplier or "Custom",
  43. base_url=model.base_url or "",
  44. has_api_key=bool(model.local_api_key),
  45. visibility=model.visibility or 'user',
  46. categories=cats,
  47. category=cats[0] if cats else 0,
  48. created_at=model.created_at,
  49. updated_at=model.updated_at
  50. )
  51. class ConnectionTestRequest(BaseModel):
  52. """连接测试请求"""
  53. base_url: str = Field(..., min_length=1, max_length=500)
  54. api_key: Optional[str] = Field(None, max_length=500)
  55. model_name: Optional[str] = Field(None, max_length=100)
  56. categories: Optional[List[int]] = Field(None)
  57. class ConnectionTestResponse(BaseModel):
  58. """连接测试响应"""
  59. success: bool
  60. message: str