content_completion.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. # -*- coding: utf-8 -*-
  2. """
  3. 上下文生成接口 - 通过统一模型调用框架 (generate_model_client)
  4. """
  5. import uuid
  6. import json
  7. import time
  8. import asyncio
  9. from typing import Optional, List, Dict, Any, AsyncGenerator
  10. from pydantic import BaseModel, Field
  11. from fastapi import APIRouter, HTTPException
  12. from fastapi.responses import StreamingResponse
  13. from foundation.observability.logger.loggering import write_logger as logger
  14. from foundation.infrastructure.tracing import TraceContext, auto_trace
  15. from foundation.ai.agent.generate.model_generate import generate_model_client
  16. from core.base.workflow_manager import workflow_manager
  17. from redis.asyncio import Redis as AsyncRedis
  18. # ==================== 1. 配置与路径初始化 ====================
  19. content_completion_router = APIRouter(prefix="/sgbx", tags=["施工方案编写"])
  20. CONTENT_COMPLETION_FUNCTION = "write_content_generate"
  21. # ==================== 2. 全局资源池 ====================
  22. GLOBAL_REDIS_CLIENT: Optional[AsyncRedis] = None
  23. async def init_global_resources():
  24. """初始化全局连接池"""
  25. global GLOBAL_REDIS_CLIENT
  26. if GLOBAL_REDIS_CLIENT is None:
  27. try:
  28. GLOBAL_REDIS_CLIENT = AsyncRedis(
  29. host='127.0.0.1', port=6379, password='123456', db=0,
  30. decode_responses=True, socket_connect_timeout=1,
  31. socket_keepalive=True, max_connections=50
  32. )
  33. asyncio.create_task(_background_ping())
  34. logger.info("✅ 全局 Redis 连接池已初始化")
  35. except Exception as e:
  36. logger.warning(f"⚠️ Redis 初始化失败: {e}")
  37. GLOBAL_REDIS_CLIENT = None
  38. async def _background_ping():
  39. if GLOBAL_REDIS_CLIENT:
  40. try: await GLOBAL_REDIS_CLIENT.ping()
  41. except: pass
  42. async def get_redis_client():
  43. if GLOBAL_REDIS_CLIENT is None:
  44. await init_global_resources()
  45. return GLOBAL_REDIS_CLIENT
  46. # ==================== 3. 流式调用 (通过统一模型框架) ====================
  47. async def call_custom_api_stream(
  48. prompt: str, system_prompt: str = "", max_tokens: int = 2000,
  49. temperature: float = 0.7, trace_id: str = ""
  50. ) -> AsyncGenerator[tuple[str, Optional[float]], None]:
  51. """流式调用 LLM,通过 generate_model_client 统一管理"""
  52. # 截断过长的 Prompt
  53. max_prompt_len = 10000
  54. if len(prompt) > max_prompt_len:
  55. prompt = prompt[-max_prompt_len:]
  56. logger.debug(f"[{trace_id}] Prompt 已截断至 {max_prompt_len} 字符")
  57. start_time = time.time()
  58. first_token_time: Optional[float] = None
  59. try:
  60. for chunk in generate_model_client.get_model_generate_stream(
  61. trace_id=trace_id,
  62. system_prompt=system_prompt,
  63. user_prompt=prompt,
  64. function_name=CONTENT_COMPLETION_FUNCTION
  65. ):
  66. if first_token_time is None:
  67. first_token_time = time.time() - start_time
  68. yield (chunk, first_token_time)
  69. except Exception as e:
  70. logger.error(f"[{trace_id}] 流式请求异常: {e}")
  71. raise
  72. # ==================== 6. 数据模型 ====================
  73. class CompletionConfig(BaseModel):
  74. section_path: str = Field(..., description="章节路径")
  75. current_content: str = Field(default="", description="当前已有内容")
  76. context_window: int = Field(default=2000, ge=500, le=5000)
  77. completion_mode: str = Field(default="continue", description="模式")
  78. target_length: int = Field(default=1000, ge=100, le=5000)
  79. include_references: bool = Field(default=True)
  80. style_match: bool = Field(default=True)
  81. hint_keywords: Optional[List[str]] = Field(default=None)
  82. class ProjectInfoSimple(BaseModel):
  83. project_name: str = Field(default="施工方案")
  84. construct_location: Optional[str] = Field(default=None)
  85. engineering_type: Optional[str] = Field(default=None)
  86. class ContentCompletionRequest(BaseModel):
  87. task_id: Optional[str] = Field(default=None)
  88. user_id: str = Field(...)
  89. project_info: Optional[ProjectInfoSimple] = Field(default=None)
  90. completion_config: CompletionConfig = Field(...)
  91. model_name: Optional[str] = Field(default=None)
  92. class Config: extra = "forbid"
  93. class ContentCompletionResponse(BaseModel):
  94. code: int
  95. message: str
  96. data: Optional[Dict[str, Any]] = None
  97. # ==================== 7. 业务逻辑辅助 ====================
  98. CONTENT_COMPLETION_SYSTEM_PROMPT = "你是一位专业的施工方案编写专家。请直接输出生成的内容文本,不要添加任何解释、标注或格式标记。要求生成的内容不超过100字。"
  99. def build_content_completion_prompt(project_info, section_path, section_title, current_content, completion_mode, target_length, include_references, style_match, hint_keywords, context_before="", context_after=""):
  100. parts = []
  101. parts.append(f"【项目】{project_info.get('project_name', '未知')}")
  102. parts.append(f"【章节】{section_title} ({section_path})")
  103. parts.append(f"【模式】{completion_mode} (目标:{target_length})")
  104. if context_before: parts.append(f"【前文】...{context_before[-500:]}")
  105. if current_content: parts.append(f"【当前】{current_content}")
  106. if context_after: parts.append(f"【后文】{context_after[:500]}...")
  107. parts.append("【指令】请根据上述信息继续生成专业内容,直接输出正文:")
  108. return "\n".join(parts)
  109. def extract_chunk_content(chunk: Any) -> str:
  110. if isinstance(chunk, str): return chunk
  111. if hasattr(chunk, 'content'): return str(chunk.content) if chunk.content else ""
  112. if isinstance(chunk, dict): return str(chunk.get('content', ''))
  113. return str(chunk)
  114. def validate_user_id(user_id: str):
  115. supported_users = {'user-001', 'user-002', 'user-003'}
  116. if user_id not in supported_users:
  117. raise HTTPException(status_code=403, detail={"code": "INVALID_USER", "message": "用户标识无效"})
  118. def validate_completion_config(config: CompletionConfig):
  119. if not config.section_path or not all(p.isdigit() for p in config.section_path.split(".")):
  120. raise HTTPException(status_code=400, detail={"code": "INVALID_PATH", "message": "章节路径格式错误"})
  121. def validate_request(request: ContentCompletionRequest):
  122. if not request.task_id and not request.project_info:
  123. raise HTTPException(status_code=400, detail={"code": "MISSING_INFO", "message": "缺少任务 ID 或项目信息"})
  124. def format_sse_event(event_type: str, data: str) -> str:
  125. return f"event: {event_type}\ndata: {data}\n\n"
  126. # ==================== 8. 核心流式生成逻辑 ====================
  127. async def generate_content_stream(callback_task_id, source_task_id, user_id, request, redis_client):
  128. async def is_cancelled() -> bool:
  129. if not redis_client: return False
  130. try: return await redis_client.exists(f"terminate:{callback_task_id}") > 0
  131. except: return False
  132. stream_start_time = time.time()
  133. first_token_latency: Optional[float] = None
  134. full_content_parts: List[str] = []
  135. chunk_count = 0
  136. try:
  137. yield format_sse_event("connected", json.dumps({
  138. "callback_task_id": callback_task_id, "status": "connected", "timestamp": int(time.time())
  139. }, ensure_ascii=False))
  140. project_info = request.project_info.dict() if request.project_info else {}
  141. section_title = f"章节 {request.completion_config.section_path}"
  142. user_prompt = build_content_completion_prompt(
  143. project_info=project_info,
  144. section_path=request.completion_config.section_path,
  145. section_title=section_title,
  146. current_content=request.completion_config.current_content,
  147. completion_mode=request.completion_config.completion_mode,
  148. target_length=request.completion_config.target_length,
  149. include_references=request.completion_config.include_references,
  150. style_match=request.completion_config.style_match,
  151. hint_keywords=request.completion_config.hint_keywords
  152. )
  153. yield format_sse_event("generating", json.dumps({
  154. "status": "generating",
  155. "message": "正在调用 LLM 模型 (write_content_generate)...",
  156. "timestamp": int(time.time())
  157. }, ensure_ascii=False))
  158. async for content, ftl in call_custom_api_stream(
  159. prompt=user_prompt,
  160. system_prompt=CONTENT_COMPLETION_SYSTEM_PROMPT,
  161. max_tokens=min(request.completion_config.target_length, 4000),
  162. temperature=0.7,
  163. trace_id=callback_task_id
  164. ):
  165. if await is_cancelled():
  166. yield format_sse_event("cancelled", json.dumps({"status": "cancelled"}, ensure_ascii=False))
  167. return
  168. if content:
  169. full_content_parts.append(content)
  170. chunk_count += 1
  171. if first_token_latency is None:
  172. first_token_latency = ftl if ftl is not None else (time.time() - stream_start_time)
  173. logger.info(f"[{callback_task_id}] ⚡ 首字延迟: {first_token_latency:.3f}s")
  174. yield format_sse_event("chunk", json.dumps({
  175. "chunk": content,
  176. "first_token_latency": round(first_token_latency, 3),
  177. "timestamp": int(time.time())
  178. }, ensure_ascii=False))
  179. # 完成统计
  180. total_duration = time.time() - stream_start_time
  181. full_content = "".join(full_content_parts)
  182. logger.info(f"[{callback_task_id}] ✅ 完成 | 首字: {first_token_latency:.3f}s | 总耗时: {total_duration:.3f}s | 字数: {len(full_content)}")
  183. yield format_sse_event("completed", json.dumps({
  184. "callback_task_id": callback_task_id,
  185. "status": "completed",
  186. "metrics": {
  187. "first_token_latency": round(first_token_latency, 3) if first_token_latency else 0.0,
  188. "total_duration": round(total_duration, 3),
  189. "char_count": len(full_content),
  190. "chunk_count": chunk_count,
  191. "model_used": CONTENT_COMPLETION_FUNCTION
  192. },
  193. "full_content": full_content,
  194. "timestamp": int(time.time())
  195. }, ensure_ascii=False))
  196. except Exception as e:
  197. logger.error(f"[{callback_task_id}] ❌ 异常: {str(e)}", exc_info=True)
  198. yield format_sse_event("error", json.dumps({"status": "error", "message": str(e)}, ensure_ascii=False))
  199. # ==================== 9. API 路由 ====================
  200. @content_completion_router.post("/content_completion")
  201. @auto_trace(generate_if_missing=True)
  202. async def content_completion(request: ContentCompletionRequest):
  203. callback_task_id = f"ctx_{uuid.uuid4().hex[:12]}"
  204. TraceContext.set_trace_id(callback_task_id)
  205. receive_time = time.time()
  206. try:
  207. validate_user_id(request.user_id)
  208. validate_completion_config(request.completion_config)
  209. validate_request(request)
  210. redis_client = await get_redis_client()
  211. logger.info(f"[{callback_task_id}] 请求接收 (预处理耗时: {(time.time()-receive_time)*1000:.1f}ms)")
  212. return StreamingResponse(
  213. generate_content_stream(callback_task_id, request.task_id, request.user_id, request, redis_client),
  214. media_type="text/event-stream",
  215. headers={
  216. "Cache-Control": "no-cache, no-store, must-revalidate",
  217. "Pragma": "no-cache",
  218. "Expires": "0",
  219. "Connection": "keep-alive",
  220. "X-Accel-Buffering": "no",
  221. "Content-Type": "text/event-stream; charset=utf-8",
  222. "Access-Control-Allow-Origin": "*"
  223. }
  224. )
  225. except HTTPException:
  226. raise
  227. except Exception as e:
  228. logger.error(f"[{callback_task_id}] 全局异常: {str(e)}")
  229. raise HTTPException(status_code=500, detail=str(e))
  230. @content_completion_router.get("/content_completion_health")
  231. async def health_check():
  232. return {
  233. "status": "healthy",
  234. "provider": "Shutian",
  235. "current_model": CONTENT_COMPLETION_FUNCTION,
  236. "api_url_prefix": "https://dashscope.aliyuncs.com/compatible-mode/v1"
  237. }
  238. @content_completion_router.get("/content_completion_modes", response_model=ContentCompletionResponse)
  239. async def get_modes():
  240. modes = [
  241. {"mode": "continue", "name": "续写"}, {"mode": "expand", "name": "扩写"},
  242. {"mode": "polish", "name": "润色"}, {"mode": "complete", "name": "补全"}
  243. ]
  244. return ContentCompletionResponse(code=200, message="success", data={"modes": modes})
  245. @content_completion_router.get("/content_completion_api_status", response_model=ContentCompletionResponse)
  246. async def get_api_status():
  247. return ContentCompletionResponse(
  248. code=200, message="success",
  249. data={
  250. "enabled": True,
  251. "provider": "Shutian",
  252. "model": CONTENT_COMPLETION_FUNCTION
  253. }
  254. )