| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383 |
- """
- 模型适配器模块
- 为不同的模型提供商提供统一的接口适配,实现:
- 1. 输入参数转换(OpenAI 格式 -> 提供商格式)
- 2. 输出结果转换(提供商格式 -> OpenAI 格式)
- 3. 自动识别提供商并应用相应的适配规则
- """
- from typing import Dict, Any, Optional, List
- from enum import Enum
class ModelProvider(str, Enum):
    """Identifiers for the model providers this module can adapt to.

    The enum mixes in ``str``, so every member compares equal to its plain
    string value (for example ``ModelProvider.OPENAI == "openai"``).
    """

    OPENAI = "openai"
    SILICONFLOW = "siliconflow"  # SiliconFlow
    DASHSCOPE = "dashscope"  # Alibaba Cloud Bailian (DashScope)
    ZHIPU = "zhipu"  # Zhipu AI
    MOONSHOT = "moonshot"  # Moonshot AI
    DEEPSEEK = "deepseek"  # DeepSeek
    GENERIC = "generic"  # any other OpenAI-compatible endpoint
class BaseAdapter:
    """Base adapter that holds provider auto-detection."""

    @staticmethod
    def detect_provider(base_url: str, model_name: str) -> ModelProvider:
        """Guess the provider from the API base URL.

        Detection is a case-insensitive substring match on ``base_url``; the
        first matching rule wins, so the rule order below is significant.

        Args:
            base_url: API base URL. ``None`` or an empty string is tolerated
                and yields ``ModelProvider.GENERIC``.
            model_name: Model name. Accepted for interface compatibility but
                not consulted: the same model family can be served through
                OpenAI-compatible proxies, so guessing the wire format from
                the model name would be unsafe.

        Returns:
            The detected provider, or ``ModelProvider.GENERIC`` when no rule
            matches.
        """
        # Guard against a missing URL instead of failing on ``None.lower()``.
        url = (base_url or "").lower()

        # Ordered (substring, provider) rules.
        rules = (
            ("siliconflow", ModelProvider.SILICONFLOW),
            ("dashscope", ModelProvider.DASHSCOPE),
            ("aliyun", ModelProvider.DASHSCOPE),
            ("zhipuai", ModelProvider.ZHIPU),
            # Zhipu's public API is served from open.bigmodel.cn, which the
            # "zhipuai" substring alone never matches.
            ("bigmodel", ModelProvider.ZHIPU),
            ("moonshot", ModelProvider.MOONSHOT),
            ("deepseek", ModelProvider.DEEPSEEK),
            ("openai", ModelProvider.OPENAI),
        )
        for needle, provider in rules:
            if needle in url:
                return provider
        return ModelProvider.GENERIC
class ChatAdapter:
    """Adapter for chat-completion endpoints."""

    @staticmethod
    def adapt_request(provider: ModelProvider, openai_request: Dict[str, Any]) -> Dict[str, Any]:
        """Convert an OpenAI-format chat request into the provider's format.

        Args:
            provider: Target provider.
            openai_request: Request payload in OpenAI format.

        Returns:
            For DashScope, a shallow copy of ``openai_request`` (no fields are
            changed yet). For every other provider, ``openai_request`` itself,
            because those providers are treated as OpenAI-compatible.
        """
        if provider == ModelProvider.DASHSCOPE:
            # Extension point for DashScope-specific parameters; working on a
            # copy keeps any future tweaks from mutating the caller's dict.
            dashscope_request = openai_request.copy()
            return dashscope_request

        # Every other provider receives the request unchanged.
        return openai_request

    @staticmethod
    def adapt_response(provider: ModelProvider, provider_response: Dict[str, Any]) -> Dict[str, Any]:
        """Convert a provider chat response into OpenAI format.

        Args:
            provider: Provider that produced the response (currently unused).
            provider_response: Response payload as returned by the provider.

        Returns:
            ``provider_response`` unchanged; responses are assumed to already
            be in OpenAI format.
        """
        return provider_response
class ImageAdapter:
    """Adapter for image-generation endpoints."""

    @staticmethod
    def adapt_request(provider: ModelProvider, openai_request: Dict[str, Any]) -> Dict[str, Any]:
        """Adapt an OpenAI-format image-generation request.

        OpenAI standard parameters:
            - model: model name
            - prompt: prompt text
            - n: number of images (1-10)
            - size: "1024x1024", "1792x1024", "1024x1792"
            - quality: "standard" | "hd"
            - style: "vivid" | "natural"
            - response_format: "url" | "b64_json"

        Args:
            provider: Target provider.
            openai_request: Request payload in OpenAI format.

        Returns:
            For DashScope, a new dict shaped as ``model`` / ``input`` /
            ``parameters`` with the size rewritten from ``WxH`` to ``W*H``.
            For every other provider, ``openai_request`` itself.
        """
        if provider == ModelProvider.DASHSCOPE:
            # A missing *or explicitly null* size falls back to the default;
            # previously ``"size": None`` crashed on ``None.replace``.
            size = openai_request.get("size") or "1024*1024"
            count = openai_request.get("n", 1)
            if count is None:
                count = 1
            return {
                "model": openai_request.get("model"),
                "input": {
                    "prompt": openai_request.get("prompt"),
                },
                "parameters": {
                    "n": count,
                    # DashScope separates width and height with "*".
                    "size": size.replace("x", "*"),
                },
            }

        # SiliconFlow and the remaining providers accept the OpenAI payload.
        return openai_request

    @staticmethod
    def adapt_response(provider: ModelProvider, provider_response: Dict[str, Any]) -> Dict[str, Any]:
        """Adapt an image-generation response to OpenAI format.

        OpenAI standard response::

            {
                "created": 1234567890,
                "data": [
                    {"url": "https://...", "revised_prompt": "..."}
                ]
            }

        Args:
            provider: Provider that produced the response.
            provider_response: Response payload as returned by the provider.

        Returns:
            For DashScope, a new OpenAI-shaped dict built from
            ``output.results``; ``created`` is the Unix time at which the
            response was adapted. For every other provider,
            ``provider_response`` itself.
        """
        if provider == ModelProvider.DASHSCOPE:
            import time  # local import: only this branch needs it

            # Tolerate absent or null sections (e.g. a task with no results).
            output = provider_response.get("output") or {}
            results = output.get("results") or []

            return {
                # ``created`` must be an integer Unix timestamp; the
                # provider's ``request_id`` is an identifier, not a time.
                "created": int(time.time()),
                "data": [
                    {"url": item.get("url"), "revised_prompt": None}
                    for item in results
                ],
            }

        # Already in OpenAI format.
        return provider_response
class AudioAdapter:
    """Adapter for audio endpoints (text-to-speech and speech-to-text)."""

    @staticmethod
    def adapt_tts_request(provider: ModelProvider, openai_request: Dict[str, Any]) -> Dict[str, Any]:
        """Adapt an OpenAI-format text-to-speech request.

        OpenAI standard parameters:
            - model: "tts-1" | "tts-1-hd"
            - input: text to synthesize
            - voice: "alloy" | "echo" | "fable" | "onyx" | "nova" | "shimmer"
            - response_format: "mp3" | "opus" | "aac" | "flac" | "wav" | "pcm"
            - speed: 0.25 - 4.0

        Args:
            provider: Target provider.
            openai_request: Request payload in OpenAI format.

        Returns:
            For DashScope, a new ``model`` / ``input`` / ``parameters`` dict
            (the layout used for CosyVoice). For every other provider,
            ``openai_request`` itself.
        """
        if provider == ModelProvider.DASHSCOPE:
            return {
                "model": openai_request.get("model"),
                "input": {
                    "text": openai_request.get("input"),
                },
                "parameters": {
                    # Voice names are passed through unmapped; only a missing
                    # key falls back to the default voice.
                    "voice": openai_request.get("voice", "longxiaochun"),
                    "format": openai_request.get("response_format", "mp3"),
                    "rate": openai_request.get("speed", 1.0),
                },
            }

        # SiliconFlow and the remaining providers accept the OpenAI payload.
        return openai_request

    @staticmethod
    def adapt_stt_request(provider: ModelProvider, openai_request: Dict[str, Any]) -> Dict[str, Any]:
        """Adapt an OpenAI-format speech-to-text request.

        OpenAI standard parameters:
            - file: audio file
            - model: "whisper-1"
            - language: ISO-639-1 language code
            - prompt: optional prompt text
            - response_format: "json" | "text" | "srt" | "verbose_json" | "vtt"
            - temperature: 0-1

        Args:
            provider: Target provider.
            openai_request: Request payload in OpenAI format.

        Returns:
            For DashScope, a new ``model`` / ``input`` / ``parameters`` dict
            (the layout used for Paraformer). For every other provider,
            ``openai_request`` itself.
        """
        if provider == ModelProvider.DASHSCOPE:
            return {
                "model": openai_request.get("model"),
                "input": {
                    "audio": openai_request.get("file"),
                },
                "parameters": {
                    "language": openai_request.get("language", "zh"),
                    "format": openai_request.get("response_format", "json"),
                },
            }

        return openai_request

    @staticmethod
    def is_audio_response(response_content_type: str) -> bool:
        """Tell whether a response carries audio data.

        Args:
            response_content_type: Value of the ``Content-Type`` header.
                ``None`` or an empty string (header absent) is tolerated.

        Returns:
            ``True`` when the content type contains ``audio/`` or
            ``application/octet-stream`` (case-insensitive); ``False``
            otherwise, including when no content type is available.
        """
        # A missing header previously raised AttributeError on ``.lower()``.
        if not response_content_type:
            return False

        content_type = response_content_type.lower()
        audio_markers = ("audio/", "application/octet-stream")
        return any(marker in content_type for marker in audio_markers)
class VideoAdapter:
    """Adapter for video-generation endpoints."""

    @staticmethod
    def adapt_request(provider: ModelProvider, openai_request: Dict[str, Any]) -> Dict[str, Any]:
        """Adapt a video-generation request.

        Parameters read from the request:
            - model: model name
            - prompt: prompt text
            - size: "720P" | "1080P"
            - duration: video length in seconds

        Args:
            provider: Target provider.
            openai_request: Incoming request payload.

        Returns:
            For DashScope, a new ``model`` / ``input`` / ``parameters`` dict
            in which ``size`` becomes ``resolution``. For every other
            provider, ``openai_request`` itself.
        """
        if provider == ModelProvider.DASHSCOPE:
            prompt_section = {"prompt": openai_request.get("prompt")}
            parameter_section = {
                "resolution": openai_request.get("size", "720P"),
                "duration": openai_request.get("duration", 5),
            }
            return {
                "model": openai_request.get("model"),
                "input": prompt_section,
                "parameters": parameter_section,
            }

        # SiliconFlow and the remaining providers get the payload unchanged.
        return openai_request
class EmbeddingAdapter:
    """Adapter for text-embedding endpoints."""

    @staticmethod
    def adapt_request(provider: ModelProvider, openai_request: Dict[str, Any]) -> Dict[str, Any]:
        """Adapt an OpenAI-format embedding request.

        OpenAI standard parameters:
            - model: model name
            - input: a string or a list of strings
            - encoding_format: "float" | "base64"
            - dimensions: output dimensionality

        Args:
            provider: Target provider.
            openai_request: Request payload in OpenAI format.

        Returns:
            For DashScope, a new ``model`` / ``input`` / ``parameters`` dict
            whose ``input.texts`` is always a list when the input was a
            string. For every other provider, ``openai_request`` itself.
        """
        if provider != ModelProvider.DASHSCOPE:
            return openai_request

        texts = openai_request.get("input")
        if isinstance(texts, str):
            # A single string is wrapped into a one-element list.
            texts = [texts]

        adapted = {
            "model": openai_request.get("model"),
            "input": {
                "texts": texts,
            },
            "parameters": {
                # Hard-coded text type, kept from the original behavior.
                "text_type": "query",
            },
        }

        # Forward the dimensionality only when a truthy value was supplied.
        dimensions = openai_request.get("dimensions")
        if dimensions:
            adapted["parameters"]["dimension"] = dimensions

        return adapted

    @staticmethod
    def adapt_response(provider: ModelProvider, provider_response: Dict[str, Any]) -> Dict[str, Any]:
        """Adapt an embedding response to OpenAI format.

        OpenAI standard response::

            {
                "object": "list",
                "data": [
                    {"object": "embedding", "embedding": [...], "index": 0}
                ],
                "model": "...",
                "usage": {"prompt_tokens": 8, "total_tokens": 8}
            }

        Args:
            provider: Provider that produced the response.
            provider_response: Response payload as returned by the provider.

        Returns:
            For DashScope, a new OpenAI-shaped dict built from
            ``output.embeddings`` and ``usage``. For every other provider,
            ``provider_response`` itself.
        """
        if provider != ModelProvider.DASHSCOPE:
            return provider_response

        # Tolerate absent or null sections instead of failing on ``None``.
        output = provider_response.get("output") or {}
        embeddings = output.get("embeddings") or []
        usage = provider_response.get("usage") or {}

        total_tokens = usage.get("total_tokens", 0)
        # When the provider reports no ``input_tokens``, fall back to the
        # total: an embedding call has no completion tokens, so the two match.
        prompt_tokens = usage.get("input_tokens", total_tokens)

        return {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "embedding": item.get("embedding", []),
                    # Prefer the provider's own position marker when present;
                    # otherwise use the position in the returned list.
                    "index": item.get("text_index", position),
                }
                for position, item in enumerate(embeddings)
            ],
            "model": provider_response.get("model", ""),
            "usage": {
                "prompt_tokens": prompt_tokens,
                "total_tokens": total_tokens,
            },
        }
# Adapter factory
def get_adapter(endpoint_type: str):
    """Look up the adapter class for an endpoint type.

    Args:
        endpoint_type: Endpoint type; one of "chat", "image", "audio_tts",
            "audio_stt", "video" or "embedding".

    Returns:
        The matching adapter class, or ``BaseAdapter`` when the endpoint type
        is not recognized. Both audio endpoint types share ``AudioAdapter``.
    """
    registry = {
        "chat": ChatAdapter,
        "image": ImageAdapter,
        "audio_tts": AudioAdapter,
        "audio_stt": AudioAdapter,
        "video": VideoAdapter,
        "embedding": EmbeddingAdapter,
    }
    if endpoint_type in registry:
        return registry[endpoint_type]
    return BaseAdapter
|