# -*- coding: utf-8 -*-
"""
Shared API utilities for DashScope content completion endpoints.

Extracted from outline_views.py and content_completion.py to eliminate
~200 lines of duplication.
"""
import uuid
import json
import time
import asyncio
import os
from typing import Optional, List, Dict, Any, AsyncGenerator

import aiohttp
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from foundation.observability.logger.loggering import write_logger as logger
from foundation.infrastructure.config.config import config_handler
from redis.asyncio import Redis as AsyncRedis

# ==================== Global resource pools ====================
GLOBAL_HTTP_SESSION: Optional[aiohttp.ClientSession] = None
GLOBAL_REDIS_CLIENT: Optional[AsyncRedis] = None

# Strong references to fire-and-forget tasks. The event loop only keeps weak
# references to tasks, so a task whose result is discarded may be
# garbage-collected before it finishes.
_BACKGROUND_TASKS: set = set()

# User ids accepted by validate_user_id().
_SUPPORTED_USER_IDS = frozenset({'user-001', 'user-002', 'user-003'})

# Field name that introduces the payload of a server-sent event line.
_SSE_DATA_PREFIX = "data:"


async def init_global_resources() -> None:
    """Initialise the module-level HTTP session and Redis client.

    The HTTP session is (re)created when it is missing or closed. The Redis
    client is created only when it is missing; if construction fails, a
    warning is logged and the client is left as ``None``.
    """
    global GLOBAL_HTTP_SESSION, GLOBAL_REDIS_CLIENT

    if GLOBAL_HTTP_SESSION is None or GLOBAL_HTTP_SESSION.closed:
        connector = aiohttp.TCPConnector(
            limit=100,
            limit_per_host=20,
            ttl_dns_cache=300,
            force_close=False,
        )
        GLOBAL_HTTP_SESSION = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=120, connect=10, sock_read=10),
            connector=connector,
            headers={"User-Agent": "FastAPI-DashScope-Optimized/2.0"},
        )
        logger.info("✅ 全局 HTTP 连接池已初始化 (DashScope Ready)")

    if GLOBAL_REDIS_CLIENT is None:
        try:
            # An empty configured password is passed to Redis as None.
            redis_password = config_handler.get('redis', 'REDIS_PASSWORD', '') or None
            GLOBAL_REDIS_CLIENT = AsyncRedis(
                host=config_handler.get('redis', 'REDIS_HOST', 'localhost'),
                port=int(config_handler.get('redis', 'REDIS_PORT', '6379')),
                password=redis_password,
                db=int(config_handler.get('redis', 'REDIS_DB', '0')),
                decode_responses=True,
                socket_connect_timeout=1,
                socket_keepalive=True,
                max_connections=50,
            )
            # Keep a reference until the ping completes so the task cannot
            # be collected while it is still pending.
            ping_task = asyncio.create_task(_background_ping())
            _BACKGROUND_TASKS.add(ping_task)
            ping_task.add_done_callback(_BACKGROUND_TASKS.discard)
            logger.info("✅ 全局 Redis 连接池已初始化")
        except Exception as e:
            logger.warning(f"⚠️ Redis 初始化失败: {e}")
            GLOBAL_REDIS_CLIENT = None


async def _background_ping() -> None:
    """Ping the global Redis client once; failures are logged, never raised."""
    client = GLOBAL_REDIS_CLIENT
    if not client:
        return
    try:
        await client.ping()
    except Exception as e:
        # Best-effort check: report the failure but do not propagate it.
        logger.warning(f"⚠️ Redis ping 失败: {e}")


async def get_http_session():
    """Return the global HTTP session, initialising it if missing or closed."""
    session_unusable = GLOBAL_HTTP_SESSION is None or GLOBAL_HTTP_SESSION.closed
    if session_unusable:
        await init_global_resources()
    return GLOBAL_HTTP_SESSION


async def get_redis_client():
    """Return the global Redis client, attempting initialisation if missing.

    The result is ``None`` when Redis initialisation failed.
    """
    if GLOBAL_REDIS_CLIENT is None:
        await init_global_resources()
    return GLOBAL_REDIS_CLIENT


# ==================== CustomAPIConfig ====================
class CustomAPIConfig:
    """Accessors for the DashScope endpoint URL, API key and model name."""

    # Base URL is read from config once, when the class body is executed.
    DASHSCOPE_BASE_URL = config_handler.get(
        "custom_api",
        "DASHSCOPE_BASE_URL",
        "https://dashscope.aliyuncs.com/compatible-mode/v1",
    )
    DASHSCOPE_CHAT_URL = f"{DASHSCOPE_BASE_URL}/chat/completions"
    # Last-resort key, used when neither the environment nor config has one.
    DASHSCOPE_API_KEY = ""
    DEFAULT_MODEL_NAME = "qwen3-30b-a3b-instruct-2507"

    @staticmethod
    def get_api_url() -> str:
        """Return the chat-completions endpoint URL."""
        return CustomAPIConfig.DASHSCOPE_CHAT_URL

    @staticmethod
    def get_api_key() -> str:
        """Return the API key, or ``""`` when none is usable.

        Lookup order: ``DASHSCOPE_API_KEY`` environment variable, then the
        ``custom_api`` config section, then the class-level fallback. A value
        starting with ``${`` is treated as no key.
        """
        api_key = (
            os.getenv("DASHSCOPE_API_KEY")
            or config_handler.get("custom_api", "DASHSCOPE_API_KEY", "")
            or CustomAPIConfig.DASHSCOPE_API_KEY
        )
        # NOTE(review): "${" presumably marks an unexpanded config placeholder.
        return "" if api_key.startswith("${") else api_key

    @staticmethod
    def get_model_name() -> str:
        """Return the configured model name, or the default when unset."""
        configured_model = config_handler.get("custom_api", "MODEL_NAME", "")
        return configured_model if configured_model else CustomAPIConfig.DEFAULT_MODEL_NAME

    @staticmethod
    def is_enabled() -> bool:
        """Return True when both an API key and an API URL are available."""
        return bool(CustomAPIConfig.get_api_key()) and bool(CustomAPIConfig.get_api_url())


# ==================== Streaming call ====================
def _parse_sse_data_line(raw_line: bytes, trace_id: str) -> tuple[bool, str]:
    """Interpret one complete line of the event stream.

    Args:
        raw_line: One line of the response body, without its ``\\n``.
        trace_id: Identifier used to prefix log messages.

    Returns:
        ``(done, content)``. ``done`` is True for the ``[DONE]`` sentinel.
        ``content`` is the delta text of the first choice, or ``""`` when the
        line carries nothing to emit (non-data line, invalid JSON, error
        event, missing choices or empty delta).
    """
    # A complete line can be decoded safely: the newline byte never occurs
    # inside a multi-byte UTF-8 sequence.
    line = raw_line.decode('utf-8', errors='ignore').strip()
    if not line.startswith(_SSE_DATA_PREFIX):
        return False, ""

    # The space after "data:" is optional in the event-stream format.
    data = line[len(_SSE_DATA_PREFIX):].lstrip()
    if data == '[DONE]':
        return True, ""

    try:
        event_data = json.loads(data)
    except json.JSONDecodeError:
        return False, ""
    if not isinstance(event_data, dict):
        return False, ""

    if "error" in event_data:
        error_info = event_data["error"]
        if isinstance(error_info, dict):
            err_msg = error_info.get("message", "Unknown Error")
        else:
            err_msg = str(error_info)
        logger.error(f"[{trace_id}] 流式数据中包含错误: {err_msg}")
        return False, ""

    choices = event_data.get("choices") or []
    if not choices:
        return False, ""
    delta = choices[0].get("delta") or {}
    return False, delta.get("content") or ""


async def call_custom_api_stream(
    prompt: str,
    system_prompt: str = "",
    max_tokens: int = 2000,
    temperature: float = 0.7,
    trace_id: str = "",
) -> AsyncGenerator[tuple[str, Optional[float]], None]:
    """Stream a chat completion from the configured DashScope endpoint.

    Args:
        prompt: User message. Only the last 10000 characters are sent.
        system_prompt: System message.
        max_tokens: ``max_tokens`` value placed in the request payload.
        temperature: ``temperature`` value placed in the request payload.
        trace_id: Identifier used to prefix log messages.

    Yields:
        ``(content, first_token_time)`` for every non-empty content delta.
        ``first_token_time`` is the number of seconds between the start of
        the request and the first delta; the same value accompanies every
        subsequent delta.

    Raises:
        Exception: When the endpoint answers with a non-200 status, or when
            the request or stream fails. The failure is logged, then raised.
    """
    api_url = CustomAPIConfig.get_api_url()
    model_name = CustomAPIConfig.get_model_name()
    api_key = CustomAPIConfig.get_api_key()
    logger.debug(f"[{trace_id}] 正在调用阿里云 DashScope: {model_name} @ {api_url}")

    max_prompt_len = 10000
    if len(prompt) > max_prompt_len:
        # Keep the tail of the prompt.
        prompt = prompt[-max_prompt_len:]
        logger.debug(f"[{trace_id}] Prompt 已截断至 {max_prompt_len} 字符")

    payload = {
        "model": model_name,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        "max_tokens": max_tokens,
        "temperature": temperature,
        "stream": True,
        "incremental_output": True,
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    start_time = time.time()
    first_token_time: Optional[float] = None
    # Bytes are buffered and decoded per complete line. Decoding each network
    # chunk on its own would drop any multi-byte character that happens to be
    # split across two chunks.
    buffer = b""
    session = await get_http_session()

    try:
        # NOTE(review): read_bufsize=1 looks like a deliberate low-latency
        # setting; confirm before changing.
        async with session.post(api_url, json=payload, headers=headers, read_bufsize=1) as response:
            if response.status != 200:
                error_text = await response.text()
                logger.error(f"[{trace_id}] API 错误 {response.status}: {error_text}")
                raise Exception(f"API 错误 {response.status}: {error_text}")

            async for chunk in response.content.iter_any():
                if not chunk:
                    continue
                buffer += chunk
                while b'\n' in buffer:
                    raw_line, buffer = buffer.split(b'\n', 1)
                    done, content = _parse_sse_data_line(raw_line, trace_id)
                    if done:
                        return
                    if content:
                        if first_token_time is None:
                            first_token_time = time.time() - start_time
                        yield (content, first_token_time)

            # The body may end without a trailing newline; parse what is left.
            if buffer.strip():
                done, content = _parse_sse_data_line(buffer, trace_id)
                if content and not done:
                    if first_token_time is None:
                        first_token_time = time.time() - start_time
                    yield (content, first_token_time)
    except Exception as e:
        logger.error(f"[{trace_id}] API 流式请求异常: {e}")
        raise


# ==================== Generic helpers ====================
def format_sse_event(event_type: str, data: str) -> str:
    """Format one server-sent event with the given event name and data."""
    return f"event: {event_type}\ndata: {data}\n\n"


def build_content_prompt(
    project_info,
    section_path,
    section_title,
    current_content,
    completion_mode,
    target_length,
    include_references,
    style_match,
    hint_keywords,
    context_before="",
    context_after="",
):
    """Assemble the completion prompt as newline-joined labelled sections.

    ``include_references``, ``style_match`` and ``hint_keywords`` are accepted
    but not used when building the text. ``context_before`` contributes its
    last 500 characters and ``context_after`` its first 500; empty optional
    sections are omitted. A ``project_info`` of ``None`` is treated as empty.
    """
    project_name = (project_info or {}).get('project_name', '未知')
    parts = [
        f"【项目】{project_name}",
        f"【章节】{section_title} ({section_path})",
        f"【模式】{completion_mode} (目标:{target_length})",
    ]
    if context_before:
        parts.append(f"【前文】...{context_before[-500:]}")
    if current_content:
        parts.append(f"【当前】{current_content}")
    if context_after:
        parts.append(f"【后文】{context_after[:500]}...")
    parts.append("【指令】请根据上述信息继续生成专业内容,直接输出正文:")
    return "\n".join(parts)


def extract_chunk_content(chunk: Any) -> str:
    """Return the text carried by a streamed chunk.

    Strings are returned unchanged. Objects with a ``content`` attribute
    yield ``str(content)``, or ``""`` when it is falsy. Dicts yield their
    ``'content'`` entry. Anything else is passed through ``str()``.
    """
    if isinstance(chunk, str):
        return chunk
    if hasattr(chunk, 'content'):
        return str(chunk.content) if chunk.content else ""
    if isinstance(chunk, dict):
        return str(chunk.get('content', ''))
    return str(chunk)


def validate_user_id(user_id: str):
    """Raise HTTP 403 (``INVALID_USER``) unless ``user_id`` is supported."""
    if user_id not in _SUPPORTED_USER_IDS:
        raise HTTPException(
            status_code=403,
            detail={"code": "INVALID_USER", "message": "用户标识无效"},
        )


def validate_completion_config(config):
    """Raise HTTP 400 (``INVALID_PATH``) for a malformed ``section_path``.

    The path must be non-empty and consist of dot-separated digit groups.
    """
    section_path = config.section_path
    path_is_valid = bool(section_path) and all(
        part.isdigit() for part in section_path.split(".")
    )
    if not path_is_valid:
        raise HTTPException(
            status_code=400,
            detail={"code": "INVALID_PATH", "message": "章节路径格式错误"},
        )