| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433 |
- # -*- coding: utf-8 -*-
- """
- 上下文生成接口 - 极速版 (DashScope Aliyun Optimized)
- 目标平台:阿里云 DashScope (兼容模式)
- API URL: https://dashscope.aliyuncs.com/compatible-mode/v1
- 模型:qwen3-30b-a3b-instruct-2507
- """
- import uuid
- import json
- import time
- import asyncio
- import aiohttp
- from typing import Optional, List, Dict, Any, AsyncGenerator
- from pydantic import BaseModel, Field
- from fastapi import APIRouter, HTTPException
- from fastapi.responses import StreamingResponse
- from foundation.observability.logger.loggering import write_logger as logger
- from foundation.infrastructure.tracing import TraceContext, auto_trace
- from foundation.infrastructure.config.config import config_handler
- from core.base.workflow_manager import WorkflowManager
- from redis.asyncio import Redis as AsyncRedis
- # ==================== 1. 配置与路径初始化 ====================
# Router that carries every endpoint of this module; all routes live under the /sgbx prefix.
content_completion_router = APIRouter(prefix="/sgbx", tags=["施工方案编写"])
# NOTE(review): not referenced anywhere else in the visible code — confirm it is still needed.
workflow_manager = WorkflowManager(max_concurrent_docs=3, max_concurrent_reviews=5)
# ==================== 2. Global resource pool (shared connections) ====================
# Both handles start as None and are created on demand by init_global_resources().
GLOBAL_HTTP_SESSION: Optional[aiohttp.ClientSession] = None
GLOBAL_REDIS_CLIENT: Optional[AsyncRedis] = None
# Strong references to fire-and-forget tasks. The event loop keeps only weak
# references, so a task nobody holds on to may be garbage-collected mid-flight.
_BACKGROUND_TASKS: set = set()


async def init_global_resources():
    """Create the shared aiohttp session and Redis client when they are missing.

    Safe to call repeatedly: the HTTP session is (re)built only when it is
    absent or closed, and the Redis client only when it is absent. If building
    the Redis client raises, a warning is logged and ``GLOBAL_REDIS_CLIENT``
    stays ``None``.
    """
    global GLOBAL_HTTP_SESSION, GLOBAL_REDIS_CLIENT

    if GLOBAL_HTTP_SESSION is None or GLOBAL_HTTP_SESSION.closed:
        # DNS caching plus keep-alive connection reuse towards the DashScope host.
        connector = aiohttp.TCPConnector(limit=100, limit_per_host=20, ttl_dns_cache=300, force_close=False)
        GLOBAL_HTTP_SESSION = aiohttp.ClientSession(
            timeout=aiohttp.ClientTimeout(total=120, connect=10, sock_read=10),
            connector=connector,
            headers={"User-Agent": "FastAPI-DashScope-Optimized/2.0"}
        )
        logger.info("✅ 全局 HTTP 连接池已初始化 (DashScope Ready)")

    if GLOBAL_REDIS_CLIENT is None:
        try:
            # SECURITY NOTE(review): host and password are hard-coded here;
            # they should come from configuration or the environment.
            GLOBAL_REDIS_CLIENT = AsyncRedis(
                host='127.0.0.1', port=6379, password='123456', db=0,
                decode_responses=True, socket_connect_timeout=1,
                socket_keepalive=True, max_connections=50
            )
            # Warm the connection in the background; keep a reference until the
            # task finishes so it cannot disappear before it has run.
            ping_task = asyncio.create_task(_background_ping())
            _BACKGROUND_TASKS.add(ping_task)
            ping_task.add_done_callback(_BACKGROUND_TASKS.discard)
            logger.info("✅ 全局 Redis 连接池已初始化")
        except Exception as e:
            logger.warning(f"⚠️ Redis 初始化失败: {e}")
            GLOBAL_REDIS_CLIENT = None
async def _background_ping():
    """Ping Redis once so the first real request finds a warm connection.

    Best effort: a failed ping is logged at debug level and otherwise ignored.
    Only ``Exception`` is caught, so task cancellation still propagates.
    """
    if not GLOBAL_REDIS_CLIENT:
        return
    try:
        await GLOBAL_REDIS_CLIENT.ping()
    except Exception as exc:
        logger.debug(f"Redis warm-up ping failed: {exc}")
async def get_http_session():
    """Return the shared aiohttp session, initialising it first when needed."""
    current = GLOBAL_HTTP_SESSION
    if current is not None and not current.closed:
        return current
    # Missing or closed: (re)build the global resources, then hand out the new session.
    await init_global_resources()
    return GLOBAL_HTTP_SESSION
async def get_redis_client():
    """Return the shared Redis client, initialising it first when it is absent.

    The result can still be ``None`` when initialisation failed.
    """
    if GLOBAL_REDIS_CLIENT is not None:
        return GLOBAL_REDIS_CLIENT
    await init_global_resources()
    return GLOBAL_REDIS_CLIENT
- # ==================== 3. 文件操作工具 ====================
- # ==================== 4. 自定义 API 配置 (阿里云 DashScope) ====================
class CustomAPIConfig:
    """Static settings for the Aliyun DashScope OpenAI-compatible chat endpoint.

    All accessors are static methods; the class is never instantiated.
    """

    # Compatible-mode base URL. The chat endpoint must carry the
    # /chat/completions suffix.
    DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    DASHSCOPE_CHAT_URL = f"{DASHSCOPE_BASE_URL}/chat/completions"

    # Environment variable that, when set to a non-blank value, overrides the
    # built-in key below.
    API_KEY_ENV_VAR = "DASHSCOPE_API_KEY"

    # SECURITY NOTE(review): a live credential is committed in source. It is
    # kept only as a fallback so existing deployments keep working; rotate it
    # and supply the key through the environment instead.
    DASHSCOPE_API_KEY = "sk-ae805c991b6a4a8da3a09351c34963a5"

    # Model used when the configuration does not name one.
    DEFAULT_MODEL_NAME = "qwen3-30b-a3b-instruct-2507"

    @staticmethod
    def get_api_url() -> str:
        """Return the full chat-completions URL."""
        return CustomAPIConfig.DASHSCOPE_CHAT_URL

    @staticmethod
    def get_api_key() -> str:
        """Return the API key: the environment override if present, else the built-in one."""
        import os  # local import keeps this block self-contained

        from_env = os.environ.get(CustomAPIConfig.API_KEY_ENV_VAR, "").strip()
        return from_env or CustomAPIConfig.DASHSCOPE_API_KEY

    @staticmethod
    def get_model_name() -> str:
        """Return the configured model name, or the default when none is configured."""
        configured_model = config_handler.get("custom_api", "MODEL_NAME", "")
        return configured_model if configured_model else CustomAPIConfig.DEFAULT_MODEL_NAME

    @staticmethod
    def is_enabled() -> bool:
        """Return True when both an API key and an API URL are available."""
        return bool(CustomAPIConfig.get_api_key()) and bool(CustomAPIConfig.get_api_url())
- # ==================== 5. 极速流式调用 (核心优化) ====================
def _parse_sse_line(raw_line: bytes, trace_id: str = "") -> tuple[bool, str]:
    """Interpret one complete line of the server-sent-event stream.

    Returns ``(done, content)``: ``done`` is True for the ``[DONE]`` sentinel,
    ``content`` is the delta text carried by the line, or ``""`` when the line
    carries none (non-data line, malformed JSON, in-stream error, empty delta).
    """
    line = raw_line.decode("utf-8", errors="ignore").strip()
    # Accept both "data: x" and "data:x".
    if not line.startswith("data:"):
        return False, ""
    data = line[5:].strip()
    if data == "[DONE]":
        return True, ""

    try:
        event_data = json.loads(data)
    except json.JSONDecodeError:
        return False, ""
    if not isinstance(event_data, dict):
        return False, ""

    error_info = event_data.get("error")
    if error_info:
        # In-stream errors are logged and skipped; the stream keeps going.
        if isinstance(error_info, dict):
            err_msg = error_info.get("message", "Unknown Error")
        else:
            err_msg = str(error_info)
        logger.error(f"[{trace_id}] 流式数据中包含错误: {err_msg}")
        return False, ""

    choices = event_data.get("choices") or []
    if not choices:
        return False, ""
    delta = choices[0].get("delta") or {}
    return False, delta.get("content") or ""


async def call_custom_api_stream(
    prompt: str, system_prompt: str = "", max_tokens: int = 2000,
    temperature: float = 0.7, trace_id: str = ""
) -> AsyncGenerator[tuple[str, Optional[float]], None]:
    """Stream a chat completion from DashScope and yield its text pieces.

    Args:
        prompt: User prompt; only the last 10000 characters are sent.
        system_prompt: Content of the system message.
        max_tokens: Forwarded as ``max_tokens``.
        temperature: Forwarded as ``temperature``.
        trace_id: Tag prepended to log lines.

    Yields:
        ``(content, first_token_latency)`` for each non-empty delta, where the
        latency is the number of seconds from request start to the first delta.

    Raises:
        Exception: on a non-200 response or any transport failure (logged, then re-raised).
    """
    api_url = CustomAPIConfig.get_api_url()
    model_name = CustomAPIConfig.get_model_name()
    api_key = CustomAPIConfig.get_api_key()

    logger.debug(f"[{trace_id}] 正在调用阿里云 DashScope: {model_name} @ {api_url}")

    # Keep the tail of an over-long prompt (the final instruction sits at the end).
    max_prompt_len = 10000
    if len(prompt) > max_prompt_len:
        prompt = prompt[-max_prompt_len:]
        logger.debug(f"[{trace_id}] Prompt 已截断至 {max_prompt_len} 字符")

    payload = {
        "model": model_name,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": prompt}
        ],
        "max_tokens": max_tokens,
        "temperature": temperature,
        "stream": True,
        # NOTE(review): compatible mode may or may not honour this flag — confirm.
        "incremental_output": True
    }

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    start_time = time.time()
    first_token_time: Optional[float] = None
    # Raw bytes are buffered and split on b"\n" BEFORE decoding. Decoding each
    # network chunk separately would drop any multi-byte UTF-8 character that
    # straddles a chunk boundary; 0x0A never occurs inside such a sequence.
    pending = b""

    session = await get_http_session()

    try:
        # read_bufsize=1 is kept from the original, which used it to minimise first-token delay.
        async with session.post(api_url, json=payload, headers=headers, read_bufsize=1) as response:
            if response.status != 200:
                error_text = await response.text()
                logger.error(f"[{trace_id}] API 错误 {response.status}: {error_text}")
                raise RuntimeError(f"API 错误 {response.status}: {error_text}")

            async for chunk in response.content.iter_any():
                if not chunk:
                    continue
                pending += chunk
                *complete_lines, pending = pending.split(b"\n")

                for raw_line in complete_lines:
                    done, content = _parse_sse_line(raw_line, trace_id)
                    if done:
                        return
                    if content:
                        if first_token_time is None:
                            first_token_time = time.time() - start_time
                        yield (content, first_token_time)

            # The stream may end without a trailing newline; do not lose that last line.
            if pending.strip():
                _, content = _parse_sse_line(pending, trace_id)
                if content:
                    if first_token_time is None:
                        first_token_time = time.time() - start_time
                    yield (content, first_token_time)
    except Exception as e:
        logger.error(f"[{trace_id}] API 流式请求异常: {e}")
        raise
- # ==================== 6. 数据模型 ====================
# Options that control a single content-completion request.
class CompletionConfig(BaseModel):
    # Required. validate_completion_config() expects dot-separated digits, e.g. "1.2.3".
    section_path: str = Field(..., description="章节路径")
    # Text already present in the section; defaults to empty.
    current_content: str = Field(default="", description="当前已有内容")
    # Bounded to 500..5000, default 2000.
    context_window: int = Field(default=2000, ge=500, le=5000)
    # Free-form mode string, default "continue".
    completion_mode: str = Field(default="continue", description="模式")
    # Bounded to 100..5000, default 1000.
    target_length: int = Field(default=1000, ge=100, le=5000)
    include_references: bool = Field(default=True)
    style_match: bool = Field(default=True)
    # Optional list of keyword hints.
    hint_keywords: Optional[List[str]] = Field(default=None)
# Minimal project description that a request may carry; every field has a default.
class ProjectInfoSimple(BaseModel):
    project_name: str = Field(default="施工方案")
    construct_location: Optional[str] = Field(default=None)
    engineering_type: Optional[str] = Field(default=None)
# Request body of POST /content_completion.
class ContentCompletionRequest(BaseModel):
    # validate_request() requires at least one of task_id / project_info.
    task_id: Optional[str] = Field(default=None)
    # Required; checked by validate_user_id().
    user_id: str = Field(...)
    project_info: Optional[ProjectInfoSimple] = Field(default=None)
    # Required completion options.
    completion_config: CompletionConfig = Field(...)
    model_name: Optional[str] = Field(default=None)
    # Unknown fields in the payload are rejected instead of being ignored.
    class Config: extra = "forbid"
# Envelope returned by the non-streaming endpoints: code, message and an optional payload.
class ContentCompletionResponse(BaseModel):
    code: int
    message: str
    data: Optional[Dict[str, Any]] = None
- # ==================== 7. 业务逻辑辅助 ====================
# System prompt passed to the model by generate_content_stream().
# NOTE(review): the prompt limits the output to 100 characters ("不超过100字") while
# CompletionConfig.target_length accepts 100..5000 — confirm which limit is intended.
CONTENT_COMPLETION_SYSTEM_PROMPT = "你是一位专业的施工方案编写专家。请直接输出生成的内容文本,不要添加任何解释、标注或格式标记。要求生成的内容不超过100字。"
def build_content_completion_prompt(
    project_info,
    section_path,
    section_title,
    current_content,
    completion_mode,
    target_length,
    include_references,
    style_match,
    hint_keywords,
    context_before="",
    context_after="",
):
    """Assemble the user prompt as newline-joined, tagged lines.

    The project, section and mode lines are always present. The preceding-text,
    current-text and following-text lines appear only when their value is
    non-empty, and the surrounding context is clipped to 500 characters on the
    side nearest the section. ``include_references``, ``style_match`` and
    ``hint_keywords`` are accepted but not used.
    """
    project_name = project_info.get('project_name', '未知')
    lines = [
        f"【项目】{project_name}",
        f"【章节】{section_title} ({section_path})",
        f"【模式】{completion_mode} (目标:{target_length})",
    ]

    # Optional sections, in reading order: before, current, after.
    optional_sections = (
        (context_before, f"【前文】...{context_before[-500:]}"),
        (current_content, f"【当前】{current_content}"),
        (context_after, f"【后文】{context_after[:500]}..."),
    )
    lines.extend(rendered for value, rendered in optional_sections if value)

    lines.append("【指令】请根据上述信息继续生成专业内容,直接输出正文:")
    return "\n".join(lines)
def extract_chunk_content(chunk: Any) -> str:
    """Return the text carried by a stream chunk of unknown shape.

    Strings pass through unchanged; an object exposing ``content`` yields that
    attribute as a string (empty when falsy); a dict yields its ``'content'``
    entry as a string; anything else is stringified.
    """
    if isinstance(chunk, str):
        return chunk

    if hasattr(chunk, 'content'):
        if chunk.content:
            return str(chunk.content)
        return ""

    if isinstance(chunk, dict):
        return str(chunk.get('content', ''))

    return str(chunk)
def validate_user_id(user_id: str):
    """Reject any user id outside the fixed allow-list with HTTP 403."""
    if user_id in {'user-001', 'user-002', 'user-003'}:
        return
    raise HTTPException(status_code=403, detail={"code": "INVALID_USER", "message": "用户标识无效"})
def validate_completion_config(config: CompletionConfig):
    """Raise HTTP 400 unless ``section_path`` is non-empty, dot-separated digit groups."""
    path = config.section_path
    segments = path.split(".") if path else []
    # An empty path, or any segment that is not purely digits, is malformed.
    if not segments or any(not segment.isdigit() for segment in segments):
        raise HTTPException(status_code=400, detail={"code": "INVALID_PATH", "message": "章节路径格式错误"})
def validate_request(request: ContentCompletionRequest):
    """Raise HTTP 400 when the request carries neither a task id nor project info."""
    if request.task_id or request.project_info:
        return
    raise HTTPException(status_code=400, detail={"code": "MISSING_INFO", "message": "缺少任务 ID 或项目信息"})
def format_sse_event(event_type: str, data: str) -> str:
    """Render one server-sent event: an ``event:`` line, a ``data:`` line and a blank line."""
    frame = [f"event: {event_type}", f"data: {data}", "", ""]
    return "\n".join(frame)
- # ==================== 8. 核心流式生成逻辑 ====================
async def generate_content_stream(callback_task_id, source_task_id, user_id, request, redis_client):
    """Yield the server-sent events of one completion request.

    Event sequence: ``connected``, ``generating``, zero or more ``chunk``
    events, then ``completed`` — or ``cancelled`` when the Redis key
    ``terminate:<callback_task_id>`` exists, or ``error`` when anything raises.

    Args:
        callback_task_id: Id of this stream; used in logs, events and the cancel key.
        source_task_id: Accepted but not used here.
        user_id: Accepted but not used here.
        request: The validated request; supplies project info and completion config.
        redis_client: Async Redis client, or a falsy value to disable cancellation checks.
    """
    async def is_cancelled() -> bool:
        # Best effort: without Redis, or when Redis fails, the stream is not cancelled.
        if not redis_client:
            return False
        try:
            return await redis_client.exists(f"terminate:{callback_task_id}") > 0
        except Exception:
            return False

    stream_start_time = time.time()
    first_token_latency: Optional[float] = None
    full_content_parts: List[str] = []
    chunk_count = 0

    try:
        yield format_sse_event("connected", json.dumps({
            "callback_task_id": callback_task_id, "status": "connected", "timestamp": int(time.time())
        }, ensure_ascii=False))

        project_info = request.project_info.dict() if request.project_info else {}
        section_title = f"章节 {request.completion_config.section_path}"

        user_prompt = build_content_completion_prompt(
            project_info=project_info,
            section_path=request.completion_config.section_path,
            section_title=section_title,
            current_content=request.completion_config.current_content,
            completion_mode=request.completion_config.completion_mode,
            target_length=request.completion_config.target_length,
            include_references=request.completion_config.include_references,
            style_match=request.completion_config.style_match,
            hint_keywords=request.completion_config.hint_keywords
        )

        yield format_sse_event("generating", json.dumps({
            "status": "generating",
            "message": f"正在调用阿里云 Qwen3 ({CustomAPIConfig.get_model_name()})...",
            "timestamp": int(time.time())
        }, ensure_ascii=False))

        # Run the generation.
        if CustomAPIConfig.is_enabled():
            logger.info(f"[{callback_task_id}] 使用阿里云 DashScope API (模型:{CustomAPIConfig.get_model_name()})")
            async for content, ftl in call_custom_api_stream(
                prompt=user_prompt,
                system_prompt=CONTENT_COMPLETION_SYSTEM_PROMPT,
                max_tokens=min(request.completion_config.target_length, 4000),
                temperature=0.7,
                trace_id=callback_task_id
            ):
                if await is_cancelled():
                    yield format_sse_event("cancelled", json.dumps({"status": "cancelled"}, ensure_ascii=False))
                    return

                if content:
                    full_content_parts.append(content)
                    chunk_count += 1

                    if first_token_latency is None:
                        first_token_latency = ftl if ftl is not None else (time.time() - stream_start_time)
                        logger.info(f"[{callback_task_id}] ⚡ 首字延迟: {first_token_latency:.3f}s (Model: {CustomAPIConfig.get_model_name()})")

                    yield format_sse_event("chunk", json.dumps({
                        "chunk": content,
                        "first_token_latency": round(first_token_latency, 3),
                        "timestamp": int(time.time())
                    }, ensure_ascii=False))
        else:
            # No usable API configuration: surface it as an error event.
            logger.warning(f"[{callback_task_id}] API 配置失效,回退到默认模型 (不应发生)")
            raise Exception("API 配置未生效,请检查 CustomAPIConfig")

        # Final statistics.
        total_duration = time.time() - stream_start_time
        full_content = "".join(full_content_parts)
        # The latency stays None when the model produced no content at all;
        # report 0.0 instead of failing on the ":.3f" format below.
        reported_latency = first_token_latency if first_token_latency is not None else 0.0

        logger.info(f"[{callback_task_id}] ✅ 完成 | 首字: {reported_latency:.3f}s | 总耗时: {total_duration:.3f}s | 字数: {len(full_content)}")
        yield format_sse_event("completed", json.dumps({
            "callback_task_id": callback_task_id,
            "status": "completed",
            "metrics": {
                "first_token_latency": round(reported_latency, 3),
                "total_duration": round(total_duration, 3),
                "char_count": len(full_content),
                "chunk_count": chunk_count,
                "model_used": CustomAPIConfig.get_model_name()
            },
            "full_content": full_content,
            "timestamp": int(time.time())
        }, ensure_ascii=False))
    except Exception as e:
        logger.error(f"[{callback_task_id}] ❌ 异常: {str(e)}", exc_info=True)
        yield format_sse_event("error", json.dumps({"status": "error", "message": str(e)}, ensure_ascii=False))
- # ==================== 9. API 路由 ====================
# POST /content_completion: validate the request, then answer with an SSE stream.
# HTTPException raised by the validators passes through unchanged; any other
# failure before the stream starts becomes HTTP 500.
@content_completion_router.post("/content_completion")
@auto_trace(generate_if_missing=True)
async def content_completion(request: ContentCompletionRequest):
    callback_task_id = f"ctx_{uuid.uuid4().hex[:12]}"
    TraceContext.set_trace_id(callback_task_id)

    receive_time = time.time()

    # Headers that keep proxies and browsers from caching or buffering the stream.
    sse_headers = {
        "Cache-Control": "no-cache, no-store, must-revalidate",
        "Pragma": "no-cache",
        "Expires": "0",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
        "Content-Type": "text/event-stream; charset=utf-8",
        "Access-Control-Allow-Origin": "*"
    }

    try:
        validate_user_id(request.user_id)
        validate_completion_config(request.completion_config)
        validate_request(request)

        redis_client = await get_redis_client()

        logger.info(f"[{callback_task_id}] 请求接收 (预处理耗时: {(time.time()-receive_time)*1000:.1f}ms)")
        event_stream = generate_content_stream(
            callback_task_id, request.task_id, request.user_id, request, redis_client
        )
        return StreamingResponse(event_stream, media_type="text/event-stream", headers=sse_headers)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"[{callback_task_id}] 全局异常: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))
# GET /content_completion_health: static liveness payload plus the active model name.
@content_completion_router.get("/content_completion_health")
async def health_check():
    report = {"status": "healthy", "provider": "Aliyun DashScope"}
    report["current_model"] = CustomAPIConfig.get_model_name()
    report["api_url_prefix"] = "https://dashscope.aliyuncs.com/compatible-mode/v1"
    return report
# GET /content_completion_modes: list the completion modes with their display names.
@content_completion_router.get("/content_completion_modes", response_model=ContentCompletionResponse)
async def get_modes():
    catalogue = (
        ("continue", "续写"),
        ("expand", "扩写"),
        ("polish", "润色"),
        ("complete", "补全"),
    )
    modes = [{"mode": key, "name": label} for key, label in catalogue]
    return ContentCompletionResponse(code=200, message="success", data={"modes": modes})
# GET /content_completion_api_status: report whether the API is usable and which model is active.
@content_completion_router.get("/content_completion_api_status", response_model=ContentCompletionResponse)
async def get_api_status():
    status_payload = {
        "enabled": CustomAPIConfig.is_enabled(),
        "provider": "Aliyun DashScope",
        "model": CustomAPIConfig.get_model_name()
    }
    return ContentCompletionResponse(code=200, message="success", data=status_payload)
|