|
|
@@ -0,0 +1,372 @@
|
|
|
+"""
|
|
|
+LLM API客户端工具类
|
|
|
+支持异步并发调用多个LLM API请求
|
|
|
+"""
|
|
|
+
|
|
|
+from __future__ import annotations
|
|
|
+
|
|
|
+import asyncio
|
|
|
+import json
|
|
|
+from typing import Any, Dict, List, Optional
|
|
|
+from pathlib import Path
|
|
|
+
|
|
|
+try:
|
|
|
+ import aiohttp
|
|
|
+ HAS_AIOHTTP = True
|
|
|
+except ImportError:
|
|
|
+ HAS_AIOHTTP = False
|
|
|
+
|
|
|
+try:
|
|
|
+ import requests
|
|
|
+ HAS_REQUESTS = True
|
|
|
+except ImportError:
|
|
|
+ HAS_REQUESTS = False
|
|
|
+
|
|
|
+from ..config.provider import default_config_provider
|
|
|
+
|
|
|
+
|
|
|
class LLMClient:
    """LLM API client with support for concurrent asynchronous calls.

    Provider settings (server URL, model id, API key) and request defaults
    are loaded from ``config/llm_api.yaml``. When ``aiohttp`` is available,
    batch calls run concurrently; otherwise the client falls back to
    sequential synchronous calls through ``requests``.
    """

    # Model types that have a <PREFIX>_SERVER_URL / <PREFIX>_MODEL_ID /
    # <PREFIX>_API_KEY section in llm_api.yaml (prefix = upper-cased type).
    _SUPPORTED_MODEL_TYPES = ("qwen", "deepseek", "doubao", "gemini")

    def __init__(self, config_provider=None):
        """Initialize the LLM client.

        Args:
            config_provider: Configuration provider; ``default_config_provider``
                is used when ``None``.
        """
        self._cfg = config_provider or default_config_provider
        self._load_config()

    def _load_config(self):
        """Load LLM API configuration from ``config/llm_api.yaml``.

        Raises:
            ValueError: If ``MODEL_TYPE`` is not one of the supported types.
        """
        llm_api_path = Path(__file__).parent.parent / "config" / "llm_api.yaml"
        import yaml  # local import: PyYAML only needed when a client is built

        with open(llm_api_path, "r", encoding="utf-8") as f:
            llm_config = yaml.safe_load(f) or {}

        # Selected model type, e.g. "qwen" / "deepseek" / "doubao" / "gemini".
        self.model_type = llm_config.get("MODEL_TYPE", "qwen").lower()
        if self.model_type not in self._SUPPORTED_MODEL_TYPES:
            raise ValueError(f"不支持的模型类型: {self.model_type}")

        # All providers share the same key layout, differing only in prefix
        # (QWEN_SERVER_URL, DEEPSEEK_SERVER_URL, ...), so one lookup covers
        # every supported type.
        model_config = llm_config.get(self.model_type, {})
        prefix = self.model_type.upper()
        self.api_url = model_config.get(f"{prefix}_SERVER_URL", "").rstrip("/")
        self.model_id = model_config.get(f"{prefix}_MODEL_ID", "")
        self.api_key = model_config.get(f"{prefix}_API_KEY", "")
        self.base_url = f"{self.api_url}/chat/completions"

        # Shared request/runtime settings.
        keywords_config = llm_config.get("keywords", {})
        self.timeout = keywords_config.get("timeout", 30)
        self.max_retries = keywords_config.get("max_retries", 2)
        self.concurrent_workers = keywords_config.get("concurrent_workers", 20)
        self.stream = keywords_config.get("stream", False)

        request_payload = keywords_config.get("request_payload", {})
        self.temperature = request_payload.get("temperature", 0.3)
        self.max_tokens = request_payload.get("max_tokens", 1024)

    def _request_headers(self) -> Dict[str, str]:
        """Build the HTTP headers shared by the sync and async call paths."""
        return {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }

    def _request_body(self, messages: List[Dict[str, str]]) -> Dict[str, Any]:
        """Build the chat-completions request body for *messages*."""
        return {
            "model": self.model_id,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "stream": self.stream,
        }

    @staticmethod
    def _extract_json_content(content: str) -> Dict[str, Any]:
        """Parse a model reply as JSON.

        Strips a surrounding markdown code fence (```json ... ``` or
        ``` ... ```) when present. Returns ``{"raw_content": <text>}``
        when the (unfenced) text is not valid JSON.
        """
        if "```json" in content:
            start = content.find("```json") + 7
            end = content.find("```", start)
            content = content[start:end].strip()
        elif "```" in content:
            start = content.find("```") + 3
            end = content.find("```", start)
            content = content[start:end].strip()
        try:
            return json.loads(content)
        except json.JSONDecodeError:
            return {"raw_content": content}

    @classmethod
    def _parse_response(cls, response: Dict[str, Any]) -> Optional[Dict[str, Any]]:
        """Extract and JSON-parse the first choice's message content.

        Returns ``None`` when the response carries no choices.
        """
        choices = response.get("choices") or []
        if not choices:
            return None
        content = choices[0].get("message", {}).get("content", "")
        return cls._extract_json_content(content)

    async def _call_api_async(self, session: aiohttp.ClientSession, messages: List[Dict[str, str]]) -> Dict[str, Any]:
        """Call the LLM API asynchronously, retrying transient failures.

        Args:
            session: Shared aiohttp session.
            messages: Chat messages for one request.

        Returns:
            The decoded JSON API response.

        Raises:
            Exception: On non-200 status or timeout after exhausting retries.
        """
        headers = self._request_headers()
        payload = self._request_body(messages)

        for attempt in range(self.max_retries):
            try:
                async with session.post(
                    self.base_url,
                    json=payload,
                    headers=headers,
                    timeout=aiohttp.ClientTimeout(total=self.timeout),
                ) as response:
                    if response.status == 200:
                        return await response.json()
                    error_text = await response.text()
                    if attempt < self.max_retries - 1:
                        # Linear backoff: 1s, 2s, 3s, ...
                        await asyncio.sleep(1 * (attempt + 1))
                        continue
                    raise Exception(f"API调用失败,状态码: {response.status}, 错误: {error_text}")
            except asyncio.TimeoutError:
                if attempt < self.max_retries - 1:
                    await asyncio.sleep(1 * (attempt + 1))
                    continue
                raise Exception(f"API调用超时(超过{self.timeout}秒)")
            except Exception:
                if attempt < self.max_retries - 1:
                    await asyncio.sleep(1 * (attempt + 1))
                    continue
                raise

        # Reached only when max_retries is 0 (the loop body never ran).
        raise Exception("API调用失败,已达到最大重试次数")

    def _call_api_sync(self, messages: List[Dict[str, str]]) -> Dict[str, Any]:
        """Call the LLM API synchronously (fallback when aiohttp is absent).

        Args:
            messages: Chat messages for one request.

        Returns:
            The decoded JSON API response.

        Raises:
            ImportError: If the ``requests`` library is not installed.
            Exception: On non-200 status or timeout after exhausting retries.
        """
        if not HAS_REQUESTS:
            raise ImportError("需要安装 aiohttp 或 requests 库才能使用LLM API客户端")

        import time  # local import: only needed for retry backoff here

        headers = self._request_headers()
        payload = self._request_body(messages)

        for attempt in range(self.max_retries):
            try:
                response = requests.post(
                    self.base_url,
                    json=payload,
                    headers=headers,
                    timeout=self.timeout,
                )
                if response.status_code == 200:
                    return response.json()
                if attempt < self.max_retries - 1:
                    # Linear backoff: 1s, 2s, 3s, ...
                    time.sleep(1 * (attempt + 1))
                    continue
                raise Exception(f"API调用失败,状态码: {response.status_code}, 错误: {response.text}")
            except requests.Timeout:
                if attempt < self.max_retries - 1:
                    time.sleep(1 * (attempt + 1))
                    continue
                raise Exception(f"API调用超时(超过{self.timeout}秒)")
            except Exception:
                if attempt < self.max_retries - 1:
                    time.sleep(1 * (attempt + 1))
                    continue
                raise

        # Reached only when max_retries is 0 (the loop body never ran).
        raise Exception("API调用失败,已达到最大重试次数")

    async def _process_single_request(self, session: aiohttp.ClientSession, messages: List[Dict[str, str]]) -> Optional[Dict[str, Any]]:
        """Run one async request and parse its reply.

        Args:
            session: Shared aiohttp session.
            messages: Chat messages for one request.

        Returns:
            The parsed JSON result, or ``None`` if the call failed.
        """
        try:
            response = await self._call_api_async(session, messages)
            return self._parse_response(response)
        except Exception as e:
            print(f" LLM API调用错误: {e}")
            return None

    async def batch_call_async(self, requests: List[List[Dict[str, str]]]) -> List[Optional[Dict[str, Any]]]:
        """Call the LLM API for a batch of requests concurrently.

        Args:
            requests: One message list per request.

        Returns:
            One result (or ``None`` on failure) per request, in order.

        Raises:
            ImportError: If neither aiohttp nor requests is installed.
        """
        if not HAS_AIOHTTP:
            if not HAS_REQUESTS:
                raise ImportError("需要安装 aiohttp 或 requests 库才能使用LLM API客户端")
            # Degraded mode: sequential synchronous calls inside the async API.
            print(" 警告: 未安装aiohttp,在异步环境中使用同步调用(性能较差)")
            return self._batch_call_sync_fallback(requests)

        # Bound in-flight requests so a large batch doesn't open
        # concurrent_workers+ sockets at once.
        semaphore = asyncio.Semaphore(self.concurrent_workers)

        async def bounded_request(session, messages):
            async with semaphore:
                return await self._process_single_request(session, messages)

        async with aiohttp.ClientSession() as session:
            tasks = [bounded_request(session, req) for req in requests]
            results = await asyncio.gather(*tasks, return_exceptions=True)

        # Map exceptions surfaced by gather() to None so the output list
        # stays aligned with the input list.
        processed_results: List[Optional[Dict[str, Any]]] = []
        for result in results:
            if isinstance(result, Exception):
                print(f" LLM API调用异常: {result}")
                processed_results.append(None)
            else:
                processed_results.append(result)
        return processed_results

    def batch_call(self, requests: List[List[Dict[str, str]]]) -> List[Optional[Dict[str, Any]]]:
        """Synchronous batch call (compatibility wrapper).

        Args:
            requests: One message list per request.

        Returns:
            One result (or ``None`` on failure) per request, in order.

        NOTE(review): relies on the global event loop managed by
        workflow_manager.py; ``asyncio.get_event_loop()`` is deprecated
        outside a running loop, and ``run_until_complete`` raises
        RuntimeError on a running loop — both cases fall back to the
        sequential synchronous path.
        """
        if not HAS_AIOHTTP:
            return self._batch_call_sync_fallback(requests)
        try:
            loop = asyncio.get_event_loop()
            return loop.run_until_complete(self.batch_call_async(requests))
        except RuntimeError:
            # No usable event loop (or it is already running).
            return self._batch_call_sync_fallback(requests)

    def _batch_call_sync_fallback(self, requests: List[List[Dict[str, str]]]) -> List[Optional[Dict[str, Any]]]:
        """Sequential synchronous batch call via ``requests``.

        Raises:
            ImportError: If the ``requests`` library is not installed.
        """
        if not HAS_REQUESTS:
            raise ImportError("需要安装 requests 库才能使用同步调用模式")

        results: List[Optional[Dict[str, Any]]] = []
        for req in requests:
            try:
                response = self._call_api_sync(req)
                results.append(self._parse_response(response))
            except Exception as e:
                print(f" LLM API调用错误: {e}")
                results.append(None)
        return results
|
|
|
+
|