| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310 |
- # -*- coding: utf-8 -*-
- """
- 上下文生成接口 - 通过统一模型调用框架 (generate_model_client)
- """
- import uuid
- import json
- import time
- import asyncio
- from typing import Optional, List, Dict, Any, AsyncGenerator
- from pydantic import BaseModel, Field
- from fastapi import APIRouter, HTTPException
- from fastapi.responses import StreamingResponse
- from foundation.observability.logger.loggering import write_logger as logger
- from foundation.infrastructure.tracing import TraceContext, auto_trace
- from foundation.ai.agent.generate.model_generate import generate_model_client
- from core.base.workflow_manager import workflow_manager
- from redis.asyncio import Redis as AsyncRedis
# ==================== 1. Router and configuration ====================
# Passed as ``function_name`` to generate_model_client for every LLM call in this module.
CONTENT_COMPLETION_FUNCTION = "write_content_generate"
content_completion_router = APIRouter(tags=["施工方案编写"], prefix="/sgbx")

# ==================== 2. Global resource pool ====================
# Shared async Redis client; stays None until init_global_resources() succeeds.
GLOBAL_REDIS_CLIENT: Optional[AsyncRedis] = None
# Strong references to fire-and-forget tasks. The event loop keeps only weak
# references to tasks, so an unreferenced task can be garbage-collected before
# it finishes; each task removes itself from this set when done.
_BACKGROUND_TASKS: set[asyncio.Task] = set()


async def init_global_resources():
    """Create the shared async Redis client once; no-op if it already exists.

    On success a background ``_background_ping`` task is scheduled. If
    constructing the client raises, the error is logged and
    ``GLOBAL_REDIS_CLIENT`` is left as ``None`` — callers must handle that.
    (redis-py is believed to connect lazily, so connection failures probably
    surface later rather than here — verify.)
    """
    global GLOBAL_REDIS_CLIENT
    if GLOBAL_REDIS_CLIENT is not None:
        return
    try:
        # NOTE(review): host and password are hard-coded; move them to
        # configuration before running anywhere but a local dev machine.
        GLOBAL_REDIS_CLIENT = AsyncRedis(
            host='127.0.0.1', port=6379, password='123456', db=0,
            decode_responses=True, socket_connect_timeout=1,
            socket_keepalive=True, max_connections=50
        )
        ping_task = asyncio.create_task(_background_ping())
        _BACKGROUND_TASKS.add(ping_task)
        ping_task.add_done_callback(_BACKGROUND_TASKS.discard)
        logger.info("✅ 全局 Redis 连接池已初始化")
    except Exception as e:
        logger.warning(f"⚠️ Redis 初始化失败: {e}")
        GLOBAL_REDIS_CLIENT = None
async def _background_ping():
    """Best-effort warm-up ping of the shared Redis client.

    Does nothing if the client is not initialized. Errors are logged and
    otherwise ignored. Catching ``Exception`` (not a bare ``except``) lets task
    cancellation propagate normally.
    """
    client = GLOBAL_REDIS_CLIENT
    if client is None:
        return
    try:
        await client.ping()
    except Exception as e:
        logger.warning(f"⚠️ Redis 预热 ping 失败: {e}")
async def get_redis_client():
    """Return the shared Redis client, initializing it lazily on first use.

    May still return ``None`` when initialization failed.
    """
    existing = GLOBAL_REDIS_CLIENT
    if existing is not None:
        return existing
    await init_global_resources()
    return GLOBAL_REDIS_CLIENT
# ==================== 3. Streaming call (via the unified model framework) ====================
async def call_custom_api_stream(
    prompt: str, system_prompt: str = "", max_tokens: int = 2000,
    temperature: float = 0.7, trace_id: str = ""
) -> AsyncGenerator[tuple[str, Optional[float]], None]:
    """Stream LLM output produced by ``generate_model_client``.

    ``get_model_generate_stream`` is consumed as a synchronous iterable.
    Creating it and every ``next()`` call run in a worker thread via
    ``asyncio.to_thread``, so waiting on the upstream model does not block the
    event loop (contextvars are copied into the worker thread).

    Args:
        prompt: User prompt. If longer than 10000 characters, only the last
            10000 are kept (presumably so the trailing instruction survives).
        system_prompt: System prompt forwarded unchanged.
        max_tokens: Accepted for API compatibility but NOT forwarded — the
            model-client call below takes no such argument.
        temperature: Likewise accepted but not forwarded.
        trace_id: Forwarded as the model-client trace id; also the log prefix.

    Yields:
        ``(text, first_token_time)``. ``text`` is the chunk normalized to
        ``str`` by ``extract_chunk_content``. ``first_token_time`` is the
        number of seconds from the start of streaming to the first chunk, and
        is fixed once set.

    Raises:
        Any exception from the model client, re-raised after logging.
    """
    max_prompt_len = 10000
    if len(prompt) > max_prompt_len:
        prompt = prompt[-max_prompt_len:]
        logger.debug(f"[{trace_id}] Prompt 已截断至 {max_prompt_len} 字符")
    start_time = time.time()
    first_token_time: Optional[float] = None
    exhausted = object()  # sentinel returned by next() once the stream ends
    try:
        stream = await asyncio.to_thread(
            generate_model_client.get_model_generate_stream,
            trace_id=trace_id,
            system_prompt=system_prompt,
            user_prompt=prompt,
            function_name=CONTENT_COMPLETION_FUNCTION,
        )
        iterator = iter(stream)
        while True:
            chunk = await asyncio.to_thread(next, iterator, exhausted)
            if chunk is exhausted:
                break
            if first_token_time is None:
                first_token_time = time.time() - start_time
            yield (extract_chunk_content(chunk), first_token_time)
    except Exception as e:
        logger.error(f"[{trace_id}] 流式请求异常: {e}")
        raise
# ==================== 6. Data models ====================
class CompletionConfig(BaseModel):
    """Settings for a single content-completion request."""

    # Dot-separated section number such as "2.1.3"; format is checked by validate_completion_config.
    section_path: str = Field(..., description="章节路径")
    # Existing text of the section ("" when starting from scratch).
    current_content: str = Field("", description="当前已有内容")
    # NOTE(review): not read by any code visible in this module.
    context_window: int = Field(2000, ge=500, le=5000)
    # Mode id; free-form here (the /content_completion_modes endpoint lists four ids).
    completion_mode: str = Field("continue", description="模式")
    # Requested output length; the stream caller also uses min(target_length, 4000) as max_tokens.
    target_length: int = Field(1000, ge=100, le=5000)
    include_references: bool = Field(True)
    style_match: bool = Field(True)
    hint_keywords: Optional[List[str]] = Field(None)
class ProjectInfoSimple(BaseModel):
    """Minimal project metadata; a request needs either this or a task_id."""

    project_name: str = Field("施工方案")
    # Optional details; None when the caller does not supply them.
    construct_location: Optional[str] = Field(None)
    engineering_type: Optional[str] = Field(None)
class ContentCompletionRequest(BaseModel):
    """Request body for the content-completion endpoint; unknown fields are rejected."""

    task_id: Optional[str] = Field(None)
    user_id: str = Field(...)
    project_info: Optional[ProjectInfoSimple] = Field(None)
    completion_config: CompletionConfig = Field(...)
    # NOTE(review): accepted but not read by any code visible in this module.
    model_name: Optional[str] = Field(None)

    class Config:
        # Extra keys in the request body cause a validation error.
        extra = "forbid"
class ContentCompletionResponse(BaseModel):
    """JSON envelope for the non-streaming endpoints: code, message and optional payload."""

    code: int
    message: str
    # Endpoint-specific payload; omitted (None) when there is nothing to return.
    data: Optional[Dict[str, Any]] = Field(None)
# ==================== 7. Business-logic helpers ====================
# System prompt sent with every completion request.
# NOTE(review): the 100-character cap below conflicts with target_length
# (100-5000) that the user prompt passes along — confirm which limit is intended.
CONTENT_COMPLETION_SYSTEM_PROMPT = (
    "你是一位专业的施工方案编写专家。"
    "请直接输出生成的内容文本,不要添加任何解释、标注或格式标记。"
    "要求生成的内容不超过100字。"
)
def build_content_completion_prompt(project_info, section_path, section_title, current_content, completion_mode, target_length, include_references, style_match, hint_keywords, context_before="", context_after=""):
    """Assemble the user prompt sent to the LLM for one completion request.

    Args:
        project_info: Mapping with ``project_name`` and optionally
            ``construct_location`` / ``engineering_type``. ``None`` is treated
            as an empty mapping.
        section_path: Dotted section number, shown after the title.
        section_title: Human-readable section title.
        current_content: Existing section text ("" for none).
        completion_mode: Mode id (e.g. "continue"), shown verbatim.
        target_length: Requested length, shown verbatim.
        include_references: If true, ask the model that it may cite relevant
            codes and standards.
        style_match: If true and ``current_content`` is non-empty, ask the
            model to keep that text's writing style.
        hint_keywords: Optional keywords to work in; blank entries are skipped.
        context_before: Preceding text; only its last 500 characters are used.
        context_after: Following text; only its first 500 characters are used.

    Returns:
        Newline-separated ``【label】value`` lines, ending with the instruction line.
    """
    info = project_info or {}
    parts = [f"【项目】{info.get('project_name', '未知')}"]
    # Optional project details are emitted only when actually provided.
    location = info.get("construct_location")
    if location:
        parts.append(f"【地点】{location}")
    engineering_type = info.get("engineering_type")
    if engineering_type:
        parts.append(f"【工程类型】{engineering_type}")
    parts.append(f"【章节】{section_title} ({section_path})")
    parts.append(f"【模式】{completion_mode} (目标:{target_length})")

    if context_before:
        parts.append(f"【前文】...{context_before[-500:]}")
    if current_content:
        parts.append(f"【当前】{current_content}")
    if context_after:
        parts.append(f"【后文】{context_after[:500]}...")

    keywords = [kw.strip() for kw in (hint_keywords or []) if kw and kw.strip()]
    if keywords:
        parts.append(f"【关键词】{'、'.join(keywords)}")

    requirements = []
    if include_references:
        requirements.append("可引用相关规范、标准条文")
    if style_match and current_content:
        requirements.append("保持与【当前】内容一致的写作风格")
    if requirements:
        parts.append(f"【要求】{'；'.join(requirements)}")

    parts.append("【指令】请根据上述信息继续生成专业内容,直接输出正文:")
    return "\n".join(parts)
def extract_chunk_content(chunk: Any) -> str:
    """Normalize one streamed model chunk to plain text.

    Handles, in order: ``None`` (returns ""), ``str`` (returned as-is),
    objects with a ``content`` attribute (falsy content gives ""), dicts with
    a ``content`` key (missing or ``None`` gives ""), and otherwise ``str(chunk)``.
    """
    if chunk is None:
        # Without this check a None chunk would become the literal text "None".
        return ""
    if isinstance(chunk, str):
        return chunk
    if hasattr(chunk, 'content'):
        return str(chunk.content) if chunk.content else ""
    if isinstance(chunk, dict):
        value = chunk.get('content')
        return "" if value is None else str(value)
    return str(chunk)
# Users allowed to call the completion endpoint.
_SUPPORTED_USER_IDS = frozenset({'user-001', 'user-002', 'user-003'})


def validate_user_id(user_id: str):
    """Return None for a known user; otherwise raise HTTP 403 (INVALID_USER)."""
    if user_id in _SUPPORTED_USER_IDS:
        return
    raise HTTPException(status_code=403, detail={"code": "INVALID_USER", "message": "用户标识无效"})
def validate_completion_config(config: CompletionConfig):
    """Require ``section_path`` to be dot-separated ASCII integers, e.g. "1" or "2.3.1".

    ``str.isdigit`` alone also accepts non-ASCII digit characters (superscripts
    such as "²", digits from other scripts), so each segment must also be ASCII.

    Raises:
        HTTPException: 400 with code INVALID_PATH when the path is empty or malformed.
    """
    path = config.section_path
    segments = path.split(".") if path else []
    if not segments or not all(seg.isascii() and seg.isdigit() for seg in segments):
        raise HTTPException(status_code=400, detail={"code": "INVALID_PATH", "message": "章节路径格式错误"})
def validate_request(request: ContentCompletionRequest):
    """Require at least one of ``task_id`` or ``project_info``; otherwise raise HTTP 400 (MISSING_INFO)."""
    has_source = bool(request.task_id) or bool(request.project_info)
    if has_source:
        return
    raise HTTPException(status_code=400, detail={"code": "MISSING_INFO", "message": "缺少任务 ID 或项目信息"})
def format_sse_event(event_type: str, data: str) -> str:
    """Serialize one Server-Sent Events frame with an ``event`` and a ``data`` field.

    A single SSE ``data:`` line cannot contain a line break. A multi-line
    payload is therefore emitted as one ``data:`` line per line, splitting on
    CRLF, CR and LF as the SSE format defines; clients re-join the lines with
    a newline. Single-line payloads, such as ``json.dumps`` output without
    indentation, produce exactly one ``data:`` line.
    """
    import re

    data_lines = "".join(f"data: {line}\n" for line in re.split(r"\r\n|\r|\n", data))
    return f"event: {event_type}\n{data_lines}\n"
# ==================== 8. Core streaming generation ====================
async def generate_content_stream(callback_task_id, source_task_id, user_id, request, redis_client):
    """Produce the SSE frames for one content-completion request.

    Emits ``connected``, then ``generating``, then zero or more ``chunk``
    frames, and finally exactly one of ``completed``, ``cancelled`` or ``error``.

    Args:
        callback_task_id: Per-request id. Used as the trace id, as the log
            prefix, and in the Redis key ``terminate:{callback_task_id}``
            whose existence cancels generation.
        source_task_id: Caller-supplied task id (not used in this function).
        user_id: Requesting user (not used in this function).
        request: The validated ``ContentCompletionRequest``.
        redis_client: Async Redis client for the cancellation check, or a
            falsy value to disable cancellation.

    Yields:
        Pre-formatted SSE frames (``str``).
    """
    config = request.completion_config

    async def is_cancelled() -> bool:
        """Return True if the terminate key for this request exists in Redis."""
        if not redis_client:
            return False
        try:
            return await redis_client.exists(f"terminate:{callback_task_id}") > 0
        except Exception as e:
            # Redis trouble must not abort generation, so treat it as "not
            # cancelled". Catching Exception (not a bare except) lets task
            # cancellation propagate.
            logger.debug(f"[{callback_task_id}] 取消状态检查失败: {e}")
            return False

    stream_start_time = time.time()
    first_token_latency: Optional[float] = None
    full_content_parts: List[str] = []
    chunk_count = 0
    try:
        yield format_sse_event("connected", json.dumps({
            "callback_task_id": callback_task_id, "status": "connected", "timestamp": int(time.time())
        }, ensure_ascii=False))

        project_info = request.project_info.dict() if request.project_info else {}
        user_prompt = build_content_completion_prompt(
            project_info=project_info,
            section_path=config.section_path,
            section_title=f"章节 {config.section_path}",
            current_content=config.current_content,
            completion_mode=config.completion_mode,
            target_length=config.target_length,
            include_references=config.include_references,
            style_match=config.style_match,
            hint_keywords=config.hint_keywords
        )
        yield format_sse_event("generating", json.dumps({
            "status": "generating",
            "message": "正在调用 LLM 模型 (write_content_generate)...",
            "timestamp": int(time.time())
        }, ensure_ascii=False))

        async for content, ftl in call_custom_api_stream(
            prompt=user_prompt,
            system_prompt=CONTENT_COMPLETION_SYSTEM_PROMPT,
            max_tokens=min(config.target_length, 4000),
            temperature=0.7,
            trace_id=callback_task_id
        ):
            if await is_cancelled():
                yield format_sse_event("cancelled", json.dumps({"status": "cancelled"}, ensure_ascii=False))
                return
            # Normalize in case the model client yields non-str chunks; both
            # "".join below and the JSON payload need text.
            text = extract_chunk_content(content)
            if not text:
                continue
            full_content_parts.append(text)
            chunk_count += 1
            if first_token_latency is None:
                first_token_latency = ftl if ftl is not None else (time.time() - stream_start_time)
                logger.info(f"[{callback_task_id}] ⚡ 首字延迟: {first_token_latency:.3f}s")
            yield format_sse_event("chunk", json.dumps({
                "chunk": text,
                "first_token_latency": round(first_token_latency, 3),
                "timestamp": int(time.time())
            }, ensure_ascii=False))

        # Completion statistics.
        total_duration = time.time() - stream_start_time
        full_content = "".join(full_content_parts)
        # first_token_latency is still None when the model produced no text;
        # formatting None with ":.3f" would raise TypeError, so fall back to 0.0.
        latency = first_token_latency if first_token_latency is not None else 0.0

        logger.info(f"[{callback_task_id}] ✅ 完成 | 首字: {latency:.3f}s | 总耗时: {total_duration:.3f}s | 字数: {len(full_content)}")
        yield format_sse_event("completed", json.dumps({
            "callback_task_id": callback_task_id,
            "status": "completed",
            "metrics": {
                "first_token_latency": round(latency, 3),
                "total_duration": round(total_duration, 3),
                "char_count": len(full_content),
                "chunk_count": chunk_count,
                "model_used": CONTENT_COMPLETION_FUNCTION
            },
            "full_content": full_content,
            "timestamp": int(time.time())
        }, ensure_ascii=False))
    except Exception as e:
        logger.error(f"[{callback_task_id}] ❌ 异常: {str(e)}", exc_info=True)
        yield format_sse_event("error", json.dumps({"status": "error", "message": str(e)}, ensure_ascii=False))
# ==================== 9. API routes ====================
@content_completion_router.post("/content_completion")
@auto_trace(generate_if_missing=True)
async def content_completion(request: ContentCompletionRequest):
    """Validate the request and return the completion as a ``text/event-stream`` response.

    Raises:
        HTTPException: 403/400 from the validators (re-raised unchanged), or
            500 for any other error raised before the stream is returned.
    """
    callback_task_id = f"ctx_{uuid.uuid4().hex[:12]}"
    # Bind the trace context to this request's generated id.
    TraceContext.set_trace_id(callback_task_id)

    receive_time = time.time()

    try:
        validate_user_id(request.user_id)
        validate_completion_config(request.completion_config)
        validate_request(request)

        # May be None when Redis is unavailable; cancellation checks are then skipped.
        redis_client = await get_redis_client()

        logger.info(f"[{callback_task_id}] 请求接收 (预处理耗时: {(time.time()-receive_time)*1000:.1f}ms)")
        return StreamingResponse(
            generate_content_stream(callback_task_id, request.task_id, request.user_id, request, redis_client),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache, no-store, must-revalidate",
                "Pragma": "no-cache",
                "Expires": "0",
                "Connection": "keep-alive",
                # Disables proxy buffering (e.g. nginx) so chunks reach the client immediately.
                "X-Accel-Buffering": "no",
                "Content-Type": "text/event-stream; charset=utf-8",
                "Access-Control-Allow-Origin": "*"
            }
        )

    except HTTPException:
        raise
    except Exception as e:
        # Log the traceback and chain the cause so it is not lost behind the 500.
        logger.error(f"[{callback_task_id}] 全局异常: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
@content_completion_router.get("/content_completion_health")
async def health_check():
    """Static liveness probe reporting the provider and model configuration.

    NOTE(review): every value is a literal or module constant; nothing is
    checked against the model client, and the DashScope URL may not match the
    "Shutian" provider — confirm.
    """
    return dict(
        status="healthy",
        provider="Shutian",
        current_model=CONTENT_COMPLETION_FUNCTION,
        api_url_prefix="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )
@content_completion_router.get("/content_completion_modes", response_model=ContentCompletionResponse)
async def get_modes():
    """List the supported completion modes as ``{"mode": id, "name": label}`` entries."""
    mode_table = (
        ("continue", "续写"),
        ("expand", "扩写"),
        ("polish", "润色"),
        ("complete", "补全"),
    )
    modes = [{"mode": mode_id, "name": label} for mode_id, label in mode_table]
    return ContentCompletionResponse(code=200, message="success", data={"modes": modes})
@content_completion_router.get("/content_completion_api_status", response_model=ContentCompletionResponse)
async def get_api_status():
    """Report that the completion API is enabled, with its provider and model name."""
    status = {
        "enabled": True,
        "provider": "Shutian",
        "model": CONTENT_COMPLETION_FUNCTION,
    }
    return ContentCompletionResponse(data=status, message="success", code=200)
|