shared_api_utils.py 9.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260
  1. # -*- coding: utf-8 -*-
  2. """
  3. Shared API utilities for DashScope content completion endpoints.
  4. Extracted from outline_views.py and content_completion.py to eliminate ~200 lines of duplication.
  5. """
  6. import uuid
  7. import json
  8. import time
  9. import asyncio
  10. import os
  11. from typing import Optional, List, Dict, Any, AsyncGenerator
  12. import aiohttp
  13. from fastapi.responses import StreamingResponse
  14. from foundation.observability.logger.loggering import write_logger as logger
  15. from foundation.infrastructure.config.config import config_handler
  16. from redis.asyncio import Redis as AsyncRedis
  17. # ==================== 全局资源池 ====================
  18. GLOBAL_HTTP_SESSION: Optional[aiohttp.ClientSession] = None
  19. GLOBAL_REDIS_CLIENT: Optional[AsyncRedis] = None
  20. async def init_global_resources():
  21. """初始化全局连接池"""
  22. global GLOBAL_HTTP_SESSION, GLOBAL_REDIS_CLIENT
  23. if GLOBAL_HTTP_SESSION is None or GLOBAL_HTTP_SESSION.closed:
  24. connector = aiohttp.TCPConnector(limit=100, limit_per_host=20, ttl_dns_cache=300, force_close=False)
  25. GLOBAL_HTTP_SESSION = aiohttp.ClientSession(
  26. timeout=aiohttp.ClientTimeout(total=120, connect=10, sock_read=10),
  27. connector=connector,
  28. headers={"User-Agent": "FastAPI-DashScope-Optimized/2.0"}
  29. )
  30. logger.info("✅ 全局 HTTP 连接池已初始化 (DashScope Ready)")
  31. if GLOBAL_REDIS_CLIENT is None:
  32. try:
  33. redis_password = config_handler.get('redis', 'REDIS_PASSWORD', '') or None
  34. GLOBAL_REDIS_CLIENT = AsyncRedis(
  35. host=config_handler.get('redis', 'REDIS_HOST', 'localhost'),
  36. port=int(config_handler.get('redis', 'REDIS_PORT', '6379')),
  37. password=redis_password,
  38. db=int(config_handler.get('redis', 'REDIS_DB', '0')),
  39. decode_responses=True, socket_connect_timeout=1,
  40. socket_keepalive=True, max_connections=50
  41. )
  42. asyncio.create_task(_background_ping())
  43. logger.info("✅ 全局 Redis 连接池已初始化")
  44. except Exception as e:
  45. logger.warning(f"⚠️ Redis 初始化失败: {e}")
  46. GLOBAL_REDIS_CLIENT = None
  47. async def _background_ping():
  48. if GLOBAL_REDIS_CLIENT:
  49. try:
  50. await GLOBAL_REDIS_CLIENT.ping()
  51. except Exception:
  52. pass
  53. async def get_http_session():
  54. if GLOBAL_HTTP_SESSION is None or GLOBAL_HTTP_SESSION.closed:
  55. await init_global_resources()
  56. return GLOBAL_HTTP_SESSION
  57. async def get_redis_client():
  58. if GLOBAL_REDIS_CLIENT is None:
  59. await init_global_resources()
  60. return GLOBAL_REDIS_CLIENT
  61. # ==================== CustomAPIConfig ====================
  62. class CustomAPIConfig:
  63. DASHSCOPE_BASE_URL = config_handler.get(
  64. "custom_api",
  65. "DASHSCOPE_BASE_URL",
  66. "https://dashscope.aliyuncs.com/compatible-mode/v1",
  67. )
  68. DASHSCOPE_CHAT_URL = f"{DASHSCOPE_BASE_URL}/chat/completions"
  69. DASHSCOPE_API_KEY = ""
  70. DEFAULT_MODEL_NAME = "qwen3-30b-a3b-instruct-2507"
  71. @staticmethod
  72. def get_api_url() -> str:
  73. return CustomAPIConfig.DASHSCOPE_CHAT_URL
  74. @staticmethod
  75. def get_api_key() -> str:
  76. api_key = (
  77. os.getenv("DASHSCOPE_API_KEY")
  78. or config_handler.get("custom_api", "DASHSCOPE_API_KEY", "")
  79. or CustomAPIConfig.DASHSCOPE_API_KEY
  80. )
  81. return "" if api_key.startswith("${") else api_key
  82. @staticmethod
  83. def get_model_name() -> str:
  84. configured_model = config_handler.get("custom_api", "MODEL_NAME", "")
  85. return configured_model if configured_model else CustomAPIConfig.DEFAULT_MODEL_NAME
  86. @staticmethod
  87. def is_enabled() -> bool:
  88. return bool(CustomAPIConfig.get_api_key()) and bool(CustomAPIConfig.get_api_url())
  89. # ==================== 极速流式调用 ====================
  90. async def call_custom_api_stream(
  91. prompt: str,
  92. system_prompt: str = "",
  93. max_tokens: int = 2000,
  94. temperature: float = 0.7,
  95. trace_id: str = "",
  96. ) -> AsyncGenerator[tuple[str, Optional[float]], None]:
  97. api_url = CustomAPIConfig.get_api_url()
  98. model_name = CustomAPIConfig.get_model_name()
  99. api_key = CustomAPIConfig.get_api_key()
  100. logger.debug(f"[{trace_id}] 正在调用阿里云 DashScope: {model_name} @ {api_url}")
  101. max_prompt_len = 10000
  102. if len(prompt) > max_prompt_len:
  103. prompt = prompt[-max_prompt_len:]
  104. logger.debug(f"[{trace_id}] Prompt 已截断至 {max_prompt_len} 字符")
  105. payload = {
  106. "model": model_name,
  107. "messages": [
  108. {"role": "system", "content": system_prompt},
  109. {"role": "user", "content": prompt},
  110. ],
  111. "max_tokens": max_tokens,
  112. "temperature": temperature,
  113. "stream": True,
  114. "incremental_output": True,
  115. }
  116. headers = {
  117. "Content-Type": "application/json",
  118. "Authorization": f"Bearer {api_key}",
  119. }
  120. start_time = time.time()
  121. first_token_time: Optional[float] = None
  122. buffer = ""
  123. session = await get_http_session()
  124. try:
  125. async with session.post(api_url, json=payload, headers=headers, read_bufsize=1) as response:
  126. if response.status != 200:
  127. error_text = await response.text()
  128. logger.error(f"[{trace_id}] API 错误 {response.status}: {error_text}")
  129. raise Exception(f"API 错误 {response.status}: {error_text}")
  130. async for chunk in response.content.iter_any():
  131. if not chunk:
  132. continue
  133. try:
  134. text = chunk.decode('utf-8', errors='ignore')
  135. if not text:
  136. continue
  137. buffer += text
  138. while '\n' in buffer:
  139. line, buffer = buffer.split('\n', 1)
  140. line = line.strip()
  141. if line.startswith('data: '):
  142. data = line[6:]
  143. if data == '[DONE]':
  144. return
  145. try:
  146. event_data = json.loads(data)
  147. if "error" in event_data:
  148. err_msg = event_data["error"].get("message", "Unknown Error")
  149. logger.error(f"[{trace_id}] 流式数据中包含错误: {err_msg}")
  150. continue
  151. choices = event_data.get("choices", [])
  152. if choices:
  153. delta = choices[0].get("delta", {})
  154. content = delta.get("content", "")
  155. if content:
  156. if first_token_time is None:
  157. first_token_time = time.time() - start_time
  158. yield (content, first_token_time)
  159. except json.JSONDecodeError:
  160. continue
  161. except UnicodeDecodeError:
  162. continue
  163. except Exception as e:
  164. logger.error(f"[{trace_id}] API 流式请求异常: {e}")
  165. raise
  166. # ==================== 通用辅助函数 ====================
  167. def format_sse_event(event_type: str, data: str) -> str:
  168. return f"event: {event_type}\ndata: {data}\n\n"
  169. def build_content_prompt(
  170. project_info,
  171. section_path,
  172. section_title,
  173. current_content,
  174. completion_mode,
  175. target_length,
  176. include_references,
  177. style_match,
  178. hint_keywords,
  179. context_before="",
  180. context_after="",
  181. ):
  182. parts = []
  183. parts.append(f"【项目】{project_info.get('project_name', '未知')}")
  184. parts.append(f"【章节】{section_title} ({section_path})")
  185. parts.append(f"【模式】{completion_mode} (目标:{target_length})")
  186. if context_before:
  187. parts.append(f"【前文】...{context_before[-500:]}")
  188. if current_content:
  189. parts.append(f"【当前】{current_content}")
  190. if context_after:
  191. parts.append(f"【后文】{context_after[:500]}...")
  192. parts.append("【指令】请根据上述信息继续生成专业内容,直接输出正文:")
  193. return "\n".join(parts)
  194. def extract_chunk_content(chunk: Any) -> str:
  195. if isinstance(chunk, str):
  196. return chunk
  197. if hasattr(chunk, 'content'):
  198. return str(chunk.content) if chunk.content else ""
  199. if isinstance(chunk, dict):
  200. return str(chunk.get('content', ''))
  201. return str(chunk)
  202. def validate_user_id(user_id: str):
  203. supported_users = {'user-001', 'user-002', 'user-003'}
  204. if user_id not in supported_users:
  205. from fastapi import HTTPException
  206. raise HTTPException(status_code=403, detail={"code": "INVALID_USER", "message": "用户标识无效"})
  207. def validate_completion_config(config):
  208. if not config.section_path or not all(p.isdigit() for p in config.section_path.split(".")):
  209. from fastapi import HTTPException
  210. raise HTTPException(status_code=400, detail={"code": "INVALID_PATH", "message": "章节路径格式错误"})