| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260 |
- # -*- coding: utf-8 -*-
- """
- Shared API utilities for DashScope content completion endpoints.
- Extracted from outline_views.py and content_completion.py to eliminate ~200 lines of duplication.
- """
- import uuid
- import json
- import time
- import asyncio
- import os
- from typing import Optional, List, Dict, Any, AsyncGenerator
- import aiohttp
- from fastapi.responses import StreamingResponse
- from foundation.observability.logger.loggering import write_logger as logger
- from foundation.infrastructure.config.config import config_handler
- from redis.asyncio import Redis as AsyncRedis
- # ==================== 全局资源池 ====================
GLOBAL_HTTP_SESSION: Optional[aiohttp.ClientSession] = None
GLOBAL_REDIS_CLIENT: Optional[AsyncRedis] = None
# Strong references to fire-and-forget tasks. The event loop keeps only weak
# references to tasks, so a task nobody references can be garbage-collected
# before it finishes; each task removes itself from this set when done.
_BACKGROUND_TASKS: set[asyncio.Task] = set()


async def init_global_resources():
    """Create the process-wide HTTP session and Redis client when needed.

    Safe to call repeatedly: the HTTP session is (re)created only when it is
    missing or closed, and the Redis client only when it is missing. If building
    the Redis client raises, the error is logged and ``GLOBAL_REDIS_CLIENT``
    stays ``None`` so that the next call retries.
    """
    global GLOBAL_HTTP_SESSION, GLOBAL_REDIS_CLIENT
    # This body contains no ``await``, so concurrent callers on the same event
    # loop cannot interleave here and no lock is required.
    if GLOBAL_HTTP_SESSION is None or GLOBAL_HTTP_SESSION.closed:
        connector = aiohttp.TCPConnector(limit=100, limit_per_host=20, ttl_dns_cache=300, force_close=False)
        GLOBAL_HTTP_SESSION = aiohttp.ClientSession(
            # total caps the whole request (including reading a streamed body);
            # sock_read caps the wait for each individual read from the peer.
            timeout=aiohttp.ClientTimeout(total=120, connect=10, sock_read=10),
            connector=connector,
            headers={"User-Agent": "FastAPI-DashScope-Optimized/2.0"}
        )
        logger.info("✅ 全局 HTTP 连接池已初始化 (DashScope Ready)")
    if GLOBAL_REDIS_CLIENT is None:
        try:
            # An empty configured password is passed as None (no AUTH).
            redis_password = config_handler.get('redis', 'REDIS_PASSWORD', '') or None
            GLOBAL_REDIS_CLIENT = AsyncRedis(
                host=config_handler.get('redis', 'REDIS_HOST', 'localhost'),
                port=int(config_handler.get('redis', 'REDIS_PORT', '6379')),
                password=redis_password,
                db=int(config_handler.get('redis', 'REDIS_DB', '0')),
                decode_responses=True, socket_connect_timeout=1,
                socket_keepalive=True, max_connections=50
            )
            # Fire-and-forget connectivity check; keep a strong reference until
            # it completes so it is not garbage-collected mid-flight.
            ping_task = asyncio.create_task(_background_ping())
            _BACKGROUND_TASKS.add(ping_task)
            ping_task.add_done_callback(_BACKGROUND_TASKS.discard)
            logger.info("✅ 全局 Redis 连接池已初始化")
        except Exception as e:
            logger.warning(f"⚠️ Redis 初始化失败: {e}")
            GLOBAL_REDIS_CLIENT = None
async def _background_ping():
    """Best-effort connectivity check for the shared Redis client.

    Nothing awaits this task's result, so a failure must not raise; it is
    logged instead so that an unreachable Redis does not go unnoticed.
    """
    client = GLOBAL_REDIS_CLIENT
    if client is None:
        return
    try:
        await client.ping()
    except Exception as e:
        logger.warning(f"⚠️ Redis 预热 ping 失败: {e}")
async def get_http_session():
    """Return the shared aiohttp session, (re)creating it if absent or closed."""
    session = GLOBAL_HTTP_SESSION
    if session is not None and not session.closed:
        return session
    await init_global_resources()
    return GLOBAL_HTTP_SESSION
async def get_redis_client():
    """Return the shared async Redis client, initializing resources on first use.

    May still return None when Redis initialization failed; callers must
    handle that case.
    """
    if GLOBAL_REDIS_CLIENT is not None:
        return GLOBAL_REDIS_CLIENT
    await init_global_resources()
    return GLOBAL_REDIS_CLIENT
- # ==================== CustomAPIConfig ====================
class CustomAPIConfig:
    """Settings for DashScope's OpenAI-compatible chat-completions endpoint.

    Values come from ``config_handler`` (section ``custom_api``); the API key
    may also be supplied through the ``DASHSCOPE_API_KEY`` environment variable.
    """

    # Trailing slashes are stripped so the joined chat URL never contains "//".
    DASHSCOPE_BASE_URL = config_handler.get(
        "custom_api",
        "DASHSCOPE_BASE_URL",
        "https://dashscope.aliyuncs.com/compatible-mode/v1",
    ).rstrip("/")
    DASHSCOPE_CHAT_URL = f"{DASHSCOPE_BASE_URL}/chat/completions"
    # Last-resort key, used only when neither the env var nor config has one.
    DASHSCOPE_API_KEY = ""
    DEFAULT_MODEL_NAME = "qwen3-30b-a3b-instruct-2507"

    @staticmethod
    def get_api_url() -> str:
        """Return the full chat-completions URL."""
        return CustomAPIConfig.DASHSCOPE_CHAT_URL

    @staticmethod
    def get_api_key() -> str:
        """Return the API key, or "" when unset or an unexpanded "${...}" placeholder.

        Precedence: environment variable, then config, then the class default.
        Surrounding whitespace is stripped, because a trailing newline (common
        for secrets mounted from files) would produce an invalid
        Authorization header.
        """
        api_key = (
            os.getenv("DASHSCOPE_API_KEY")
            or config_handler.get("custom_api", "DASHSCOPE_API_KEY", "")
            or CustomAPIConfig.DASHSCOPE_API_KEY
        ).strip()
        return "" if api_key.startswith("${") else api_key

    @staticmethod
    def get_model_name() -> str:
        """Return the configured model name, falling back to DEFAULT_MODEL_NAME."""
        configured_model = config_handler.get("custom_api", "MODEL_NAME", "")
        return configured_model if configured_model else CustomAPIConfig.DEFAULT_MODEL_NAME

    @staticmethod
    def is_enabled() -> bool:
        """True when both an API key and an API URL are available."""
        return bool(CustomAPIConfig.get_api_key()) and bool(CustomAPIConfig.get_api_url())
- # ==================== 极速流式调用 ====================
def _sse_data_payload(raw_line: bytes) -> Optional[str]:
    """Return the payload of one SSE ``data:`` line, or None for any other line."""
    line = raw_line.decode("utf-8", errors="ignore").strip()
    if not line.startswith("data:"):
        return None
    # The space after "data:" is optional in the SSE wire format.
    return line[5:].lstrip()


async def _iter_sse_data(byte_chunks: Any) -> AsyncGenerator[str, None]:
    """Yield each SSE ``data:`` payload from an async iterable of byte chunks, stopping at ``[DONE]``.

    Lines are split on the raw bytes *before* decoding. Network chunks can end
    in the middle of a multi-byte UTF-8 character (very common with Chinese
    output), and decoding each chunk separately would drop those characters.
    The newline byte never occurs inside a multi-byte UTF-8 sequence, so
    splitting on it first is safe.
    """
    pending = b""
    async for chunk in byte_chunks:
        if not chunk:
            continue
        pending += chunk
        *complete_lines, pending = pending.split(b"\n")
        for raw_line in complete_lines:
            data = _sse_data_payload(raw_line)
            if data is None:
                continue
            if data == "[DONE]":
                return
            yield data
    # The server may close the stream without a trailing newline; flush it.
    data = _sse_data_payload(pending)
    if data is not None and data != "[DONE]":
        yield data


def _extract_delta_content(event_data: Any, trace_id: str) -> str:
    """Return the text of ``choices[0].delta.content`` from one stream event, or ""."""
    if not isinstance(event_data, dict):
        return ""
    if "error" in event_data:
        error = event_data["error"]
        # The error field is normally an object, but tolerate a bare string.
        err_msg = error.get("message", "Unknown Error") if isinstance(error, dict) else str(error)
        logger.error(f"[{trace_id}] 流式数据中包含错误: {err_msg}")
        return ""
    choices = event_data.get("choices")
    if not isinstance(choices, list) or not choices or not isinstance(choices[0], dict):
        return ""
    # Some chunks carry "delta": null or "content": null.
    delta = choices[0].get("delta")
    if not isinstance(delta, dict):
        return ""
    content = delta.get("content")
    return content if isinstance(content, str) else ""


async def call_custom_api_stream(
    prompt: str,
    system_prompt: str = "",
    max_tokens: int = 2000,
    temperature: float = 0.7,
    trace_id: str = "",
) -> AsyncGenerator[tuple[str, Optional[float]], None]:
    """Stream a chat completion from DashScope's OpenAI-compatible endpoint.

    Args:
        prompt: User message; only its last 10000 characters are sent.
        system_prompt: System message content (sent even when empty).
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.
        trace_id: Correlation id prefixed to every log line.

    Yields:
        ``(content, first_token_time)`` for each non-empty content delta, where
        ``first_token_time`` is the number of seconds from request start until
        the first content delta arrived (the same value on every yield).

    Raises:
        Exception: when the endpoint answers with a non-200 status. Any other
            error raised during the request is logged and re-raised.
    """
    api_url = CustomAPIConfig.get_api_url()
    model_name = CustomAPIConfig.get_model_name()
    api_key = CustomAPIConfig.get_api_key()
    logger.debug(f"[{trace_id}] 正在调用阿里云 DashScope: {model_name} @ {api_url}")
    max_prompt_len = 10000
    if len(prompt) > max_prompt_len:
        # Keep the tail so a trailing instruction (such as the one
        # build_content_prompt appends) survives truncation.
        prompt = prompt[-max_prompt_len:]
        logger.debug(f"[{trace_id}] Prompt 已截断至 {max_prompt_len} 字符")
    payload = {
        "model": model_name,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt},
        ],
        "max_tokens": max_tokens,
        "temperature": temperature,
        "stream": True,
        # NOTE(review): incremental_output is a DashScope-native parameter;
        # confirm the compatible-mode endpoint honours (or ignores) it.
        "incremental_output": True,
    }
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    # Monotonic clock: elapsed-time measurement must not jump with wall-clock changes.
    start_time = time.monotonic()
    first_token_time: Optional[float] = None
    session = await get_http_session()
    try:
        # NOTE(review): read_bufsize=1 makes aiohttp's flow control pause after
        # every read; iter_any() already yields data as soon as it arrives, so
        # this likely adds overhead without lowering latency. Confirm before removing.
        async with session.post(api_url, json=payload, headers=headers, read_bufsize=1) as response:
            if response.status != 200:
                error_text = await response.text()
                logger.error(f"[{trace_id}] API 错误 {response.status}: {error_text}")
                raise Exception(f"API 错误 {response.status}: {error_text}")
            async for data in _iter_sse_data(response.content.iter_any()):
                try:
                    event_data = json.loads(data)
                except json.JSONDecodeError:
                    continue  # skip non-JSON payloads
                content = _extract_delta_content(event_data, trace_id)
                if not content:
                    continue
                if first_token_time is None:
                    first_token_time = time.monotonic() - start_time
                yield (content, first_token_time)
    except Exception as e:
        logger.error(f"[{trace_id}] API 流式请求异常: {e}")
        raise
- # ==================== 通用辅助函数 ====================
def format_sse_event(event_type: str, data: str) -> str:
    """Serialize one Server-Sent Events frame.

    Every line of *data* is emitted as its own ``data:`` field, as the SSE
    format requires. A raw newline inside a single field would end it early,
    and the client would drop or misparse the remainder. Clients rejoin
    consecutive data fields with newlines, so multi-line payloads round-trip.
    Single-line data is framed as an ``event:`` line, one ``data:`` line and a
    terminating blank line.
    """
    # Format first so non-str payloads behave as they would in an f-string.
    text = f"{data}"
    # SSE treats CRLF, CR and LF all as line terminators.
    lines = text.replace("\r\n", "\n").replace("\r", "\n").split("\n")
    data_fields = "".join(f"data: {line}\n" for line in lines)
    return f"event: {event_type}\n{data_fields}\n"
def build_content_prompt(
    project_info,
    section_path,
    section_title,
    current_content,
    completion_mode,
    target_length,
    include_references,
    style_match,
    hint_keywords,
    context_before="",
    context_after="",
):
    """Assemble the user prompt for a section content-completion request.

    The prompt is a newline-joined list of tagged lines: project name, section
    title/path, completion mode and target length, then optional surrounding
    context (last 500 chars before, current content, first 500 chars after),
    and finally a fixed instruction line.

    ``include_references``, ``style_match`` and ``hint_keywords`` are accepted
    but not currently used when building the prompt.
    """
    project_name = project_info.get('project_name', '未知')
    header_lines = [
        f"【项目】{project_name}",
        f"【章节】{section_title} ({section_path})",
        f"【模式】{completion_mode} (目标:{target_length})",
    ]
    # Each optional line is only formatted when its source value is truthy.
    context_lines = [
        f"【前文】...{context_before[-500:]}" if context_before else None,
        f"【当前】{current_content}" if current_content else None,
        f"【后文】{context_after[:500]}..." if context_after else None,
    ]
    instruction = "【指令】请根据上述信息继续生成专业内容,直接输出正文:"
    body = header_lines + [line for line in context_lines if line is not None]
    body.append(instruction)
    return "\n".join(body)
def extract_chunk_content(chunk: Any) -> str:
    """Return the text carried by one streamed chunk.

    Accepts a plain string, any object with a ``content`` attribute, or a dict
    with a ``content`` key. Attribute- and dict-style chunks are treated alike:
    missing or falsy content yields ``""``. A ``None`` chunk also yields ``""``
    rather than the literal text ``"None"``. Any other value falls back to
    ``str(chunk)``.
    """
    if chunk is None:
        return ""
    if isinstance(chunk, str):
        return chunk
    if hasattr(chunk, 'content'):
        content = chunk.content
    elif isinstance(chunk, dict):
        content = chunk.get('content')
    else:
        return str(chunk)
    return str(content) if content else ""
def validate_user_id(user_id: str):
    """Reject any user id outside the fixed allow-list with an HTTP 403.

    Returns None for an allowed id; otherwise raises
    ``fastapi.HTTPException`` with code ``INVALID_USER``.
    """
    if user_id in {'user-001', 'user-002', 'user-003'}:
        return
    from fastapi import HTTPException
    raise HTTPException(status_code=403, detail={"code": "INVALID_USER", "message": "用户标识无效"})
def validate_completion_config(config):
    """Ensure ``config.section_path`` is a dotted list of numbers such as ``"1.2.3"``.

    Raises:
        fastapi.HTTPException: 400 with code ``INVALID_PATH`` when the path is
            empty, has an empty segment, or contains anything other than the
            ASCII digits 0-9 and dots.
    """
    section_path = config.section_path
    # str.isdigit() alone also accepts characters such as "²" that int()
    # rejects, so each segment is restricted to ASCII digits.
    if not section_path or not all(
        part.isascii() and part.isdigit() for part in section_path.split(".")
    ):
        from fastapi import HTTPException
        raise HTTPException(status_code=400, detail={"code": "INVALID_PATH", "message": "章节路径格式错误"})
|