model_adapters.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383
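"""
Model adapter module.

Provides a unified interface across different model providers:
1. Request conversion (OpenAI format -> provider format)
2. Response conversion (provider format -> OpenAI format)
3. Automatic provider detection and selection of the matching adaptation rules
"""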
  8. from typing import Dict, Any, Optional, List
  9. from enum import Enum
  10. class ModelProvider(str, Enum):
  11. """模型提供商枚举"""
  12. OPENAI = "openai"
  13. SILICONFLOW = "siliconflow" # 硅基流动
  14. DASHSCOPE = "dashscope" # 阿里云百炼
  15. ZHIPU = "zhipu" # 智谱AI
  16. MOONSHOT = "moonshot" # 月之暗面
  17. DEEPSEEK = "deepseek" # DeepSeek
  18. GENERIC = "generic" # 通用 OpenAI 兼容
  19. class BaseAdapter:
  20. """基础适配器"""
  21. @staticmethod
  22. def detect_provider(base_url: str, model_name: str) -> ModelProvider:
  23. """
  24. 根据 base_url 和 model_name 自动检测提供商
  25. Args:
  26. base_url: API 基础 URL
  27. model_name: 模型名称
  28. Returns:
  29. ModelProvider: 检测到的提供商
  30. """
  31. base_url_lower = base_url.lower()
  32. if "siliconflow" in base_url_lower:
  33. return ModelProvider.SILICONFLOW
  34. elif "dashscope" in base_url_lower or "aliyun" in base_url_lower:
  35. return ModelProvider.DASHSCOPE
  36. elif "zhipuai" in base_url_lower:
  37. return ModelProvider.ZHIPU
  38. elif "moonshot" in base_url_lower:
  39. return ModelProvider.MOONSHOT
  40. elif "deepseek" in base_url_lower:
  41. return ModelProvider.DEEPSEEK
  42. elif "openai" in base_url_lower:
  43. return ModelProvider.OPENAI
  44. else:
  45. return ModelProvider.GENERIC
  46. class ChatAdapter:
  47. """对话接口适配器"""
  48. @staticmethod
  49. def adapt_request(provider: ModelProvider, openai_request: Dict[str, Any]) -> Dict[str, Any]:
  50. """
  51. 将 OpenAI 格式的请求转换为提供商格式
  52. Args:
  53. provider: 提供商类型
  54. openai_request: OpenAI 格式的请求
  55. Returns:
  56. 适配后的请求
  57. """
  58. # 大多数提供商都兼容 OpenAI 格式,直接返回
  59. if provider in [ModelProvider.GENERIC, ModelProvider.OPENAI, ModelProvider.SILICONFLOW,
  60. ModelProvider.MOONSHOT, ModelProvider.DEEPSEEK]:
  61. return openai_request
  62. # 特殊处理
  63. if provider == ModelProvider.DASHSCOPE:
  64. # 阿里云百炼的特殊参数
  65. adapted = openai_request.copy()
  66. # 可以在这里添加特殊处理
  67. return adapted
  68. return openai_request
  69. @staticmethod
  70. def adapt_response(provider: ModelProvider, provider_response: Dict[str, Any]) -> Dict[str, Any]:
  71. """
  72. 将提供商格式的响应转换为 OpenAI 格式
  73. Args:
  74. provider: 提供商类型
  75. provider_response: 提供商格式的响应
  76. Returns:
  77. OpenAI 格式的响应
  78. """
  79. # 大多数提供商已经返回 OpenAI 格式
  80. return provider_response
  81. class ImageAdapter:
  82. """图像生成接口适配器"""
  83. @staticmethod
  84. def adapt_request(provider: ModelProvider, openai_request: Dict[str, Any]) -> Dict[str, Any]:
  85. """
  86. 适配图像生成请求
  87. OpenAI 标准参数:
  88. - model: 模型名称
  89. - prompt: 提示词
  90. - n: 生成数量 (1-10)
  91. - size: 尺寸 "1024x1024", "1792x1024", "1024x1792"
  92. - quality: "standard" | "hd"
  93. - style: "vivid" | "natural"
  94. - response_format: "url" | "b64_json"
  95. """
  96. if provider == ModelProvider.SILICONFLOW:
  97. # 硅基流动兼容 OpenAI 格式
  98. return openai_request
  99. if provider == ModelProvider.DASHSCOPE:
  100. # 阿里云通义万相的参数格式不同
  101. adapted = {
  102. "model": openai_request.get("model"),
  103. "input": {
  104. "prompt": openai_request.get("prompt")
  105. },
  106. "parameters": {
  107. "n": openai_request.get("n", 1),
  108. "size": openai_request.get("size", "1024*1024").replace("x", "*")
  109. }
  110. }
  111. return adapted
  112. return openai_request
  113. @staticmethod
  114. def adapt_response(provider: ModelProvider, provider_response: Dict[str, Any]) -> Dict[str, Any]:
  115. """
  116. 适配图像生成响应
  117. OpenAI 标准响应:
  118. {
  119. "created": 1234567890,
  120. "data": [
  121. {"url": "https://...", "revised_prompt": "..."}
  122. ]
  123. }
  124. """
  125. if provider == ModelProvider.DASHSCOPE:
  126. # 阿里云格式转换
  127. output = provider_response.get("output", {})
  128. results = output.get("results", [])
  129. return {
  130. "created": provider_response.get("request_id", 0),
  131. "data": [
  132. {"url": item.get("url"), "revised_prompt": None}
  133. for item in results
  134. ]
  135. }
  136. # 默认已经是 OpenAI 格式
  137. return provider_response
  138. class AudioAdapter:
  139. """音频接口适配器"""
  140. @staticmethod
  141. def adapt_tts_request(provider: ModelProvider, openai_request: Dict[str, Any]) -> Dict[str, Any]:
  142. """
  143. 适配 TTS 请求
  144. OpenAI 标准参数:
  145. - model: "tts-1" | "tts-1-hd"
  146. - input: 文本内容
  147. - voice: "alloy" | "echo" | "fable" | "onyx" | "nova" | "shimmer"
  148. - response_format: "mp3" | "opus" | "aac" | "flac" | "wav" | "pcm"
  149. - speed: 0.25 - 4.0
  150. """
  151. if provider == ModelProvider.SILICONFLOW:
  152. # 硅基流动兼容 OpenAI 格式
  153. return openai_request
  154. if provider == ModelProvider.DASHSCOPE:
  155. # 阿里云 CosyVoice 格式
  156. adapted = {
  157. "model": openai_request.get("model"),
  158. "input": {
  159. "text": openai_request.get("input")
  160. },
  161. "parameters": {
  162. "voice": openai_request.get("voice", "longxiaochun"),
  163. "format": openai_request.get("response_format", "mp3"),
  164. "rate": openai_request.get("speed", 1.0)
  165. }
  166. }
  167. return adapted
  168. return openai_request
  169. @staticmethod
  170. def adapt_stt_request(provider: ModelProvider, openai_request: Dict[str, Any]) -> Dict[str, Any]:
  171. """
  172. 适配 STT 请求
  173. OpenAI 标准参数:
  174. - file: 音频文件
  175. - model: "whisper-1"
  176. - language: ISO-639-1 语言代码
  177. - prompt: 可选的提示文本
  178. - response_format: "json" | "text" | "srt" | "verbose_json" | "vtt"
  179. - temperature: 0-1
  180. """
  181. if provider == ModelProvider.SILICONFLOW:
  182. return openai_request
  183. if provider == ModelProvider.DASHSCOPE:
  184. # 阿里云 Paraformer 格式
  185. adapted = {
  186. "model": openai_request.get("model"),
  187. "input": {
  188. "audio": openai_request.get("file")
  189. },
  190. "parameters": {
  191. "language": openai_request.get("language", "zh"),
  192. "format": openai_request.get("response_format", "json")
  193. }
  194. }
  195. return adapted
  196. return openai_request
  197. @staticmethod
  198. def is_audio_response(response_content_type: str) -> bool:
  199. """
  200. 判断响应是否是音频数据
  201. Args:
  202. response_content_type: Content-Type 头
  203. Returns:
  204. 是否是音频数据
  205. """
  206. audio_types = ["audio/", "application/octet-stream"]
  207. return any(t in response_content_type.lower() for t in audio_types)
  208. class VideoAdapter:
  209. """视频生成接口适配器"""
  210. @staticmethod
  211. def adapt_request(provider: ModelProvider, openai_request: Dict[str, Any]) -> Dict[str, Any]:
  212. """
  213. 适配视频生成请求
  214. 参数:
  215. - model: 模型名称
  216. - prompt: 提示词
  217. - size: "720P" | "1080P"
  218. - duration: 视频时长(秒)
  219. """
  220. if provider == ModelProvider.SILICONFLOW:
  221. return openai_request
  222. if provider == ModelProvider.DASHSCOPE:
  223. # 阿里云视频生成格式
  224. adapted = {
  225. "model": openai_request.get("model"),
  226. "input": {
  227. "prompt": openai_request.get("prompt")
  228. },
  229. "parameters": {
  230. "resolution": openai_request.get("size", "720P"),
  231. "duration": openai_request.get("duration", 5)
  232. }
  233. }
  234. return adapted
  235. return openai_request
  236. class EmbeddingAdapter:
  237. """向量嵌入接口适配器"""
  238. @staticmethod
  239. def adapt_request(provider: ModelProvider, openai_request: Dict[str, Any]) -> Dict[str, Any]:
  240. """
  241. 适配 Embedding 请求
  242. OpenAI 标准参数:
  243. - model: 模型名称
  244. - input: 文本或文本数组
  245. - encoding_format: "float" | "base64"
  246. - dimensions: 输出维度
  247. """
  248. if provider == ModelProvider.DASHSCOPE:
  249. # 阿里云格式
  250. input_text = openai_request.get("input")
  251. if isinstance(input_text, str):
  252. input_text = [input_text]
  253. adapted = {
  254. "model": openai_request.get("model"),
  255. "input": {
  256. "texts": input_text
  257. },
  258. "parameters": {
  259. "text_type": "query"
  260. }
  261. }
  262. if openai_request.get("dimensions"):
  263. adapted["parameters"]["dimension"] = openai_request["dimensions"]
  264. return adapted
  265. return openai_request
  266. @staticmethod
  267. def adapt_response(provider: ModelProvider, provider_response: Dict[str, Any]) -> Dict[str, Any]:
  268. """
  269. 适配 Embedding 响应
  270. OpenAI 标准响应:
  271. {
  272. "object": "list",
  273. "data": [
  274. {"object": "embedding", "embedding": [...], "index": 0}
  275. ],
  276. "model": "...",
  277. "usage": {"prompt_tokens": 8, "total_tokens": 8}
  278. }
  279. """
  280. if provider == ModelProvider.DASHSCOPE:
  281. # 阿里云格式转换
  282. output = provider_response.get("output", {})
  283. embeddings = output.get("embeddings", [])
  284. usage = provider_response.get("usage", {})
  285. return {
  286. "object": "list",
  287. "data": [
  288. {
  289. "object": "embedding",
  290. "embedding": item.get("embedding", []),
  291. "index": idx
  292. }
  293. for idx, item in enumerate(embeddings)
  294. ],
  295. "model": provider_response.get("model", ""),
  296. "usage": {
  297. "prompt_tokens": usage.get("input_tokens", 0),
  298. "total_tokens": usage.get("total_tokens", 0)
  299. }
  300. }
  301. return provider_response
  302. # 导出适配器工厂
  303. def get_adapter(endpoint_type: str):
  304. """
  305. 获取指定类型的适配器
  306. Args:
  307. endpoint_type: 端点类型 ("chat", "image", "audio_tts", "audio_stt", "video", "embedding")
  308. Returns:
  309. 对应的适配器类
  310. """
  311. adapters = {
  312. "chat": ChatAdapter,
  313. "image": ImageAdapter,
  314. "audio_tts": AudioAdapter,
  315. "audio_stt": AudioAdapter,
  316. "video": VideoAdapter,
  317. "embedding": EmbeddingAdapter,
  318. }
  319. return adapters.get(endpoint_type, BaseAdapter)