content_completion.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464
  1. # -*- coding: utf-8 -*-
  2. """
  3. 上下文生成接口 - 极速版 (Shutian Optimized)
  4. 目标平台:蜀天算力 Qwen3.5-122B-A10B
  5. API: 蜀天算力 (通过统一配置管理)
  6. 模型:Qwen3.5-122B-A10B
  7. """
  8. import uuid
  9. import json
  10. import time
  11. import asyncio
  12. import aiohttp
  13. from typing import Optional, List, Dict, Any, AsyncGenerator
  14. from pydantic import BaseModel, Field
  15. from fastapi import APIRouter, HTTPException
  16. from fastapi.responses import StreamingResponse
  17. from foundation.observability.logger.loggering import write_logger as logger
  18. from foundation.infrastructure.tracing import TraceContext, auto_trace
  19. from foundation.infrastructure.config.config import config_handler
  20. from core.base.workflow_manager import workflow_manager
  21. from redis.asyncio import Redis as AsyncRedis
  22. # ==================== 1. 配置与路径初始化 ====================
  23. content_completion_router = APIRouter(prefix="/sgbx", tags=["施工方案编写"])
  24. # ==================== 2. 全局资源池 (速度优化核心) ====================
  25. GLOBAL_HTTP_SESSION: Optional[aiohttp.ClientSession] = None
  26. GLOBAL_REDIS_CLIENT: Optional[AsyncRedis] = None
  27. async def init_global_resources():
  28. """初始化全局连接池"""
  29. global GLOBAL_HTTP_SESSION, GLOBAL_REDIS_CLIENT
  30. if GLOBAL_HTTP_SESSION is None or GLOBAL_HTTP_SESSION.closed:
  31. # 增加 DNS 缓存和连接复用,针对蜀天算力域名优化
  32. connector = aiohttp.TCPConnector(limit=100, limit_per_host=20, ttl_dns_cache=300, force_close=False)
  33. GLOBAL_HTTP_SESSION = aiohttp.ClientSession(
  34. timeout=aiohttp.ClientTimeout(total=120, connect=10, sock_read=10), # 连接超时稍长以防网络波动
  35. connector=connector,
  36. headers={"User-Agent": "FastAPI-Shutian-Optimized/2.0"}
  37. )
  38. logger.info("✅ 全局 HTTP 连接池已初始化 (Shutian Ready)")
  39. if GLOBAL_REDIS_CLIENT is None:
  40. try:
  41. GLOBAL_REDIS_CLIENT = AsyncRedis(
  42. host='127.0.0.1', port=6379, password='123456', db=0,
  43. decode_responses=True, socket_connect_timeout=1,
  44. socket_keepalive=True, max_connections=50
  45. )
  46. asyncio.create_task(_background_ping())
  47. logger.info("✅ 全局 Redis 连接池已初始化")
  48. except Exception as e:
  49. logger.warning(f"⚠️ Redis 初始化失败: {e}")
  50. GLOBAL_REDIS_CLIENT = None
  51. async def _background_ping():
  52. if GLOBAL_REDIS_CLIENT:
  53. try: await GLOBAL_REDIS_CLIENT.ping()
  54. except: pass
  55. async def get_http_session():
  56. if GLOBAL_HTTP_SESSION is None or GLOBAL_HTTP_SESSION.closed:
  57. await init_global_resources()
  58. return GLOBAL_HTTP_SESSION
  59. async def get_redis_client():
  60. if GLOBAL_REDIS_CLIENT is None:
  61. await init_global_resources()
  62. return GLOBAL_REDIS_CLIENT
  63. # ==================== 3. 文件操作工具 ====================
  64. # ==================== 4. 自定义 API 配置 (蜀天算力 Qwen3.5-122B) ====================
  65. class CustomAPIConfig:
  66. # model_setting.yaml 中的功能名称
  67. FUNCTION_NAME = "write_content_generate"
  68. # 兜底默认值(蜀天 Qwen3.5-122B-A10B)
  69. SHUTIAN_SERVER_URL_DEFAULT = "http://183.220.37.46:25423/v1"
  70. SHUTIAN_API_KEY_DEFAULT = "lq123456"
  71. DEFAULT_MODEL_NAME = "/model/Qwen3.5-122B-A10B"
  72. @staticmethod
  73. def _resolve_from_model_handler():
  74. """通过 model_handler 统一解析模型配置(url, api_key, model_id)"""
  75. try:
  76. from foundation.ai.models.model_handler import model_handler
  77. llm = model_handler.get_model_by_function(CustomAPIConfig.FUNCTION_NAME)
  78. url = getattr(llm, 'base_url', None) or getattr(llm, 'openai_api_base', '')
  79. url = str(url) if url else ''
  80. model_id = getattr(llm, 'model_name', None) or getattr(llm, 'model', '')
  81. model_id = str(model_id) if model_id else ''
  82. api_key = getattr(llm, 'openai_api_key', None)
  83. if api_key:
  84. api_key = api_key.get_secret_value() if hasattr(api_key, 'get_secret_value') else str(api_key)
  85. else:
  86. api_key = ''
  87. if url and api_key:
  88. return url, api_key, model_id
  89. except Exception:
  90. pass
  91. return None, None, None
  92. @staticmethod
  93. def get_api_url() -> str:
  94. configured_url = config_handler.get("custom_api", "API_URL", "")
  95. if configured_url:
  96. return configured_url
  97. url, _, _ = CustomAPIConfig._resolve_from_model_handler()
  98. if url:
  99. return url
  100. return config_handler.get("shutian", "SHUTIAN_122B_SERVER_URL", CustomAPIConfig.SHUTIAN_SERVER_URL_DEFAULT)
  101. @staticmethod
  102. def get_api_key() -> str:
  103. configured_key = config_handler.get("custom_api", "API_KEY", "")
  104. if configured_key:
  105. return configured_key
  106. _, api_key, _ = CustomAPIConfig._resolve_from_model_handler()
  107. if api_key:
  108. return api_key
  109. return config_handler.get("shutian", "SHUTIAN_122B_API_KEY", CustomAPIConfig.SHUTIAN_API_KEY_DEFAULT)
  110. @staticmethod
  111. def get_model_name() -> str:
  112. configured_model = config_handler.get("custom_api", "MODEL_NAME", "")
  113. if configured_model:
  114. return configured_model
  115. _, _, model_id = CustomAPIConfig._resolve_from_model_handler()
  116. if model_id:
  117. return model_id
  118. return config_handler.get("shutian", "SHUTIAN_122B_MODEL_ID", CustomAPIConfig.DEFAULT_MODEL_NAME)
  119. @staticmethod
  120. def is_enabled() -> bool:
  121. return bool(CustomAPIConfig.get_api_key()) and bool(CustomAPIConfig.get_api_url())
  122. # ==================== 5. 极速流式调用 (核心优化) ====================
  123. async def call_custom_api_stream(
  124. prompt: str, system_prompt: str = "", max_tokens: int = 2000,
  125. temperature: float = 0.7, trace_id: str = ""
  126. ) -> AsyncGenerator[tuple[str, Optional[float]], None]:
  127. api_url = CustomAPIConfig.get_api_url()
  128. model_name = CustomAPIConfig.get_model_name()
  129. api_key = CustomAPIConfig.get_api_key()
  130. logger.debug(f"[{trace_id}] 正在调用蜀天算力: {model_name} @ {api_url}")
  131. # 截断过长的 Prompt (服务端对输入长度有限制,且为了速度)
  132. max_prompt_len = 10000
  133. if len(prompt) > max_prompt_len:
  134. prompt = prompt[-max_prompt_len:]
  135. logger.debug(f"[{trace_id}] Prompt 已截断至 {max_prompt_len} 字符")
  136. payload = {
  137. "model": model_name,
  138. "messages": [
  139. {"role": "system", "content": system_prompt},
  140. {"role": "user", "content": prompt}
  141. ],
  142. "max_tokens": max_tokens,
  143. "temperature": temperature,
  144. "stream": True,
  145. "incremental_output": True # 蜀天算力兼容模式支持此参数,优化流式体验
  146. }
  147. headers = {
  148. "Content-Type": "application/json",
  149. "Authorization": f"Bearer {api_key}"
  150. }
  151. start_time = time.time()
  152. first_token_time: Optional[float] = None
  153. buffer = ""
  154. session = await get_http_session()
  155. try:
  156. # 蜀天算力 HTTP 连接,保持 read_bufsize=1 以获取最快首字
  157. async with session.post(api_url, json=payload, headers=headers, read_bufsize=1) as response:
  158. if response.status != 200:
  159. error_text = await response.text()
  160. logger.error(f"[{trace_id}] API 错误 {response.status}: {error_text}")
  161. raise Exception(f"API 错误 {response.status}: {error_text}")
  162. async for chunk in response.content.iter_any():
  163. if not chunk: continue
  164. try:
  165. text = chunk.decode('utf-8', errors='ignore')
  166. if not text: continue
  167. buffer += text
  168. while '\n' in buffer:
  169. line, buffer = buffer.split('\n', 1)
  170. line = line.strip()
  171. if line.startswith('data: '):
  172. data = line[6:]
  173. if data == '[DONE]':
  174. return
  175. try:
  176. event_data = json.loads(data)
  177. # 处理服务端可能的错误格式
  178. if "error" in event_data:
  179. err_msg = event_data["error"].get("message", "Unknown Error")
  180. logger.error(f"[{trace_id}] 流式数据中包含错误: {err_msg}")
  181. continue
  182. choices = event_data.get("choices", [])
  183. if choices:
  184. delta = choices[0].get("delta", {})
  185. content = delta.get("content", "")
  186. if content:
  187. if first_token_time is None:
  188. first_token_time = time.time() - start_time
  189. yield (content, first_token_time)
  190. except json.JSONDecodeError:
  191. continue
  192. except UnicodeDecodeError:
  193. continue
  194. except Exception as e:
  195. logger.error(f"[{trace_id}] API 流式请求异常: {e}")
  196. raise
  197. # ==================== 6. 数据模型 ====================
  198. class CompletionConfig(BaseModel):
  199. section_path: str = Field(..., description="章节路径")
  200. current_content: str = Field(default="", description="当前已有内容")
  201. context_window: int = Field(default=2000, ge=500, le=5000)
  202. completion_mode: str = Field(default="continue", description="模式")
  203. target_length: int = Field(default=1000, ge=100, le=5000)
  204. include_references: bool = Field(default=True)
  205. style_match: bool = Field(default=True)
  206. hint_keywords: Optional[List[str]] = Field(default=None)
  207. class ProjectInfoSimple(BaseModel):
  208. project_name: str = Field(default="施工方案")
  209. construct_location: Optional[str] = Field(default=None)
  210. engineering_type: Optional[str] = Field(default=None)
  211. class ContentCompletionRequest(BaseModel):
  212. task_id: Optional[str] = Field(default=None)
  213. user_id: str = Field(...)
  214. project_info: Optional[ProjectInfoSimple] = Field(default=None)
  215. completion_config: CompletionConfig = Field(...)
  216. model_name: Optional[str] = Field(default=None)
  217. class Config: extra = "forbid"
  218. class ContentCompletionResponse(BaseModel):
  219. code: int
  220. message: str
  221. data: Optional[Dict[str, Any]] = None
  222. # ==================== 7. 业务逻辑辅助 ====================
  223. CONTENT_COMPLETION_SYSTEM_PROMPT = "你是一位专业的施工方案编写专家。请直接输出生成的内容文本,不要添加任何解释、标注或格式标记。要求生成的内容不超过100字。"
  224. def build_content_completion_prompt(project_info, section_path, section_title, current_content, completion_mode, target_length, include_references, style_match, hint_keywords, context_before="", context_after=""):
  225. parts = []
  226. parts.append(f"【项目】{project_info.get('project_name', '未知')}")
  227. parts.append(f"【章节】{section_title} ({section_path})")
  228. parts.append(f"【模式】{completion_mode} (目标:{target_length})")
  229. if context_before: parts.append(f"【前文】...{context_before[-500:]}")
  230. if current_content: parts.append(f"【当前】{current_content}")
  231. if context_after: parts.append(f"【后文】{context_after[:500]}...")
  232. parts.append("【指令】请根据上述信息继续生成专业内容,直接输出正文:")
  233. return "\n".join(parts)
  234. def extract_chunk_content(chunk: Any) -> str:
  235. if isinstance(chunk, str): return chunk
  236. if hasattr(chunk, 'content'): return str(chunk.content) if chunk.content else ""
  237. if isinstance(chunk, dict): return str(chunk.get('content', ''))
  238. return str(chunk)
  239. def validate_user_id(user_id: str):
  240. supported_users = {'user-001', 'user-002', 'user-003'}
  241. if user_id not in supported_users:
  242. raise HTTPException(status_code=403, detail={"code": "INVALID_USER", "message": "用户标识无效"})
  243. def validate_completion_config(config: CompletionConfig):
  244. if not config.section_path or not all(p.isdigit() for p in config.section_path.split(".")):
  245. raise HTTPException(status_code=400, detail={"code": "INVALID_PATH", "message": "章节路径格式错误"})
  246. def validate_request(request: ContentCompletionRequest):
  247. if not request.task_id and not request.project_info:
  248. raise HTTPException(status_code=400, detail={"code": "MISSING_INFO", "message": "缺少任务 ID 或项目信息"})
  249. def format_sse_event(event_type: str, data: str) -> str:
  250. return f"event: {event_type}\ndata: {data}\n\n"
  251. # ==================== 8. 核心流式生成逻辑 ====================
  252. async def generate_content_stream(callback_task_id, source_task_id, user_id, request, redis_client):
  253. async def is_cancelled() -> bool:
  254. if not redis_client: return False
  255. try: return await redis_client.exists(f"terminate:{callback_task_id}") > 0
  256. except: return False
  257. stream_start_time = time.time()
  258. first_token_latency: Optional[float] = None
  259. full_content_parts: List[str] = []
  260. chunk_count = 0
  261. try:
  262. yield format_sse_event("connected", json.dumps({
  263. "callback_task_id": callback_task_id, "status": "connected", "timestamp": int(time.time())
  264. }, ensure_ascii=False))
  265. project_info = request.project_info.dict() if request.project_info else {}
  266. section_title = f"章节 {request.completion_config.section_path}"
  267. user_prompt = build_content_completion_prompt(
  268. project_info=project_info,
  269. section_path=request.completion_config.section_path,
  270. section_title=section_title,
  271. current_content=request.completion_config.current_content,
  272. completion_mode=request.completion_config.completion_mode,
  273. target_length=request.completion_config.target_length,
  274. include_references=request.completion_config.include_references,
  275. style_match=request.completion_config.style_match,
  276. hint_keywords=request.completion_config.hint_keywords
  277. )
  278. yield format_sse_event("generating", json.dumps({
  279. "status": "generating",
  280. "message": f"正在调用蜀天 Qwen3.5-122B ({CustomAPIConfig.get_model_name()})...",
  281. "timestamp": int(time.time())
  282. }, ensure_ascii=False))
  283. # 执行生成
  284. if CustomAPIConfig.is_enabled():
  285. logger.info(f"[{callback_task_id}] 使用蜀天算力 API (模型:{CustomAPIConfig.get_model_name()})")
  286. async for content, ftl in call_custom_api_stream(
  287. prompt=user_prompt,
  288. system_prompt=CONTENT_COMPLETION_SYSTEM_PROMPT,
  289. max_tokens=min(request.completion_config.target_length, 4000),
  290. temperature=0.7,
  291. trace_id=callback_task_id
  292. ):
  293. if await is_cancelled():
  294. yield format_sse_event("cancelled", json.dumps({"status": "cancelled"}, ensure_ascii=False))
  295. return
  296. if content:
  297. full_content_parts.append(content)
  298. chunk_count += 1
  299. if first_token_latency is None:
  300. first_token_latency = ftl if ftl is not None else (time.time() - stream_start_time)
  301. logger.info(f"[{callback_task_id}] ⚡ 首字延迟: {first_token_latency:.3f}s (Model: {CustomAPIConfig.get_model_name()})")
  302. yield format_sse_event("chunk", json.dumps({
  303. "chunk": content,
  304. "first_token_latency": round(first_token_latency, 3),
  305. "timestamp": int(time.time())
  306. }, ensure_ascii=False))
  307. else:
  308. # 备用逻辑 (理论上不会触发,因为 Key 已硬编码)
  309. logger.warning(f"[{callback_task_id}] API 配置失效,回退到默认模型 (不应发生)")
  310. raise Exception("API 配置未生效,请检查 CustomAPIConfig")
  311. # 完成统计
  312. total_duration = time.time() - stream_start_time
  313. full_content = "".join(full_content_parts)
  314. logger.info(f"[{callback_task_id}] ✅ 完成 | 首字: {first_token_latency:.3f}s | 总耗时: {total_duration:.3f}s | 字数: {len(full_content)}")
  315. yield format_sse_event("completed", json.dumps({
  316. "callback_task_id": callback_task_id,
  317. "status": "completed",
  318. "metrics": {
  319. "first_token_latency": round(first_token_latency, 3) if first_token_latency else 0.0,
  320. "total_duration": round(total_duration, 3),
  321. "char_count": len(full_content),
  322. "chunk_count": chunk_count,
  323. "model_used": CustomAPIConfig.get_model_name()
  324. },
  325. "full_content": full_content,
  326. "timestamp": int(time.time())
  327. }, ensure_ascii=False))
  328. except Exception as e:
  329. logger.error(f"[{callback_task_id}] ❌ 异常: {str(e)}", exc_info=True)
  330. yield format_sse_event("error", json.dumps({"status": "error", "message": str(e)}, ensure_ascii=False))
  331. # ==================== 9. API 路由 ====================
  332. @content_completion_router.post("/content_completion")
  333. @auto_trace(generate_if_missing=True)
  334. async def content_completion(request: ContentCompletionRequest):
  335. callback_task_id = f"ctx_{uuid.uuid4().hex[:12]}"
  336. TraceContext.set_trace_id(callback_task_id)
  337. receive_time = time.time()
  338. try:
  339. validate_user_id(request.user_id)
  340. validate_completion_config(request.completion_config)
  341. validate_request(request)
  342. redis_client = await get_redis_client()
  343. logger.info(f"[{callback_task_id}] 请求接收 (预处理耗时: {(time.time()-receive_time)*1000:.1f}ms)")
  344. return StreamingResponse(
  345. generate_content_stream(callback_task_id, request.task_id, request.user_id, request, redis_client),
  346. media_type="text/event-stream",
  347. headers={
  348. "Cache-Control": "no-cache, no-store, must-revalidate",
  349. "Pragma": "no-cache",
  350. "Expires": "0",
  351. "Connection": "keep-alive",
  352. "X-Accel-Buffering": "no",
  353. "Content-Type": "text/event-stream; charset=utf-8",
  354. "Access-Control-Allow-Origin": "*"
  355. }
  356. )
  357. except HTTPException:
  358. raise
  359. except Exception as e:
  360. logger.error(f"[{callback_task_id}] 全局异常: {str(e)}")
  361. raise HTTPException(status_code=500, detail=str(e))
  362. @content_completion_router.get("/content_completion_health")
  363. async def health_check():
  364. return {
  365. "status": "healthy",
  366. "provider": "Shutian",
  367. "current_model": CustomAPIConfig.get_model_name(),
  368. "api_url_prefix": "https://dashscope.aliyuncs.com/compatible-mode/v1"
  369. }
  370. @content_completion_router.get("/content_completion_modes", response_model=ContentCompletionResponse)
  371. async def get_modes():
  372. modes = [
  373. {"mode": "continue", "name": "续写"}, {"mode": "expand", "name": "扩写"},
  374. {"mode": "polish", "name": "润色"}, {"mode": "complete", "name": "补全"}
  375. ]
  376. return ContentCompletionResponse(code=200, message="success", data={"modes": modes})
  377. @content_completion_router.get("/content_completion_api_status", response_model=ContentCompletionResponse)
  378. async def get_api_status():
  379. enabled = CustomAPIConfig.is_enabled()
  380. return ContentCompletionResponse(
  381. code=200, message="success",
  382. data={
  383. "enabled": enabled,
  384. "provider": "Shutian",
  385. "model": CustomAPIConfig.get_model_name()
  386. }
  387. )