"""
Report compatibility routes.

Mirrors the interface of the Go implementation so that external callers see
consistent behavior.
"""
import json
import time
from typing import AsyncGenerator

import httpx
from fastapi import APIRouter, Request
from fastapi.responses import StreamingResponse, JSONResponse
from sqlalchemy.orm import Session

from models.report import (
    ReportCompleteFlowRequest,
    UpdateAIMessageRequest,
    StopSSERequest,
    StreamChatRequest
)
from services.aichat_proxy import aichat_proxy
from utils.token import is_local_token
from utils.config import settings
from utils.logger import logger
from database import get_db
from models.chat import AIMessage

router = APIRouter(prefix="/apiv1", tags=["报告兼容"])

_BEARER_PREFIX = "Bearer "


def _sse_event(payload: dict) -> bytes:
    """Serialize ``payload`` as one UTF-8 encoded SSE ``data:`` frame.

    Going through ``json.dumps`` guarantees the frame is valid JSON even when
    a value (e.g. an exception message) contains quotes, backslashes or
    newlines.
    """
    return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n".encode("utf-8")


def get_request_token(request: Request) -> str:
    """Return the token carried by the request, or ``""`` when there is none.

    The ``token``, ``Token`` and ``Authorization`` headers are checked in that
    order; the first non-blank value wins. A leading ``Bearer `` scheme is
    removed from the ``Authorization`` header only.
    """
    for header_name in ["token", "Token", "Authorization"]:
        header_value = request.headers.get(header_name, "").strip()
        if not header_value:
            continue
        if header_name == "Authorization" and header_value.startswith(_BEARER_PREFIX):
            # Drop only the leading scheme; a token that happens to contain
            # "Bearer " further in must be left intact.
            return header_value[len(_BEARER_PREFIX):]
        return header_value
    return ""


def should_proxy_to_aichat(token: str) -> bool:
    """Return True when the request should be proxied to aichat.

    Requests without a token and requests with a local token (as decided by
    ``is_local_token``) are handled locally; every other token is proxied.
    """
    if not token:
        return False
    return not is_local_token(token)


def _ensure_proxy_conversation_id(
    request_data: ReportCompleteFlowRequest,
    request: Request
) -> int:
    """Return a positive ``ai_conversation_id`` to send to aichat.

    A positive id supplied by the client is returned unchanged. Otherwise an
    id is generated so that ``ai_conversation_id=0`` is never forwarded (per
    the original author's note, 0 triggers a crashing ``transfer_session``
    branch inside aichat).
    """
    conversation_id = int(request_data.ai_conversation_id or 0)
    if conversation_id > 0:
        return conversation_id

    # Microsecond timestamp: positive and, per the original note, within the
    # range a JavaScript front end can represent safely (<= 2^53).
    generated_id = int(time.time() * 1_000_000)

    # Fold the low three digits of the user id into the low three digits of
    # the generated id to lower the chance of collisions under concurrency.
    user = getattr(request.state, "user", None)
    user_id = int(getattr(user, "user_id", 0) or 0)
    if user_id > 0:
        generated_id = generated_id - (generated_id % 1000) + (user_id % 1000)

    if generated_id <= 0:
        generated_id = int(time.time()) or 1
    return generated_id


def _build_aichat_complete_flow_body(
    request_data: ReportCompleteFlowRequest,
    request: Request
) -> bytes:
    """Build the UTF-8 JSON body forwarded to aichat.

    The body is the request model's dict form with ``ai_conversation_id``
    replaced by the value from ``_ensure_proxy_conversation_id``.
    """
    payload = request_data.dict()
    payload["ai_conversation_id"] = _ensure_proxy_conversation_id(request_data, request)
    return json.dumps(payload, ensure_ascii=False).encode("utf-8")


async def fallback_to_local_stream(
    request_data: ReportCompleteFlowRequest,
    request: Request
) -> StreamingResponse:
    """Serve ``report/complete-flow`` from the local chat stream.

    Calls the local ``/apiv1/stream/chat-with-db`` endpoint, accumulates its
    SSE output and re-emits it in the event format the front end expects:
    an optional ``online_answer`` event carrying the full text, or an
    ``online_error`` event, always followed by a ``completed`` event.
    """
    logger.info("[报告兼容] 降级为本地 SSE 兼容输出")

    stream_request = StreamChatRequest(
        message=request_data.user_question,
        ai_conversation_id=request_data.ai_conversation_id,
        business_type=0
    )
    local_url = f"http://127.0.0.1:{settings.app.port}/apiv1/stream/chat-with-db"

    async def stream_generator() -> AsyncGenerator[bytes, None]:
        ai_conversation_id = request_data.ai_conversation_id or 0
        ai_message_id = 0
        answer_parts = []

        def consume_line(raw_line: bytes) -> bool:
            """Fold one upstream SSE line into the state.

            Returns True only when the line is the ``[DONE]`` sentinel.
            """
            nonlocal ai_conversation_id, ai_message_id
            line = raw_line.decode("utf-8", errors="ignore")
            if not line.startswith("data: "):
                return False
            payload = line[6:].strip()
            if not payload:
                return False
            if payload == "[DONE]":
                return True
            if payload.startswith("{"):
                try:
                    data = json.loads(payload)
                except (ValueError, RecursionError):
                    data = None
                if isinstance(data, dict) and data.get("type") == "initial":
                    # The "initial" event only carries ids; it is not content.
                    ai_conversation_id = data.get("ai_conversation_id") or ai_conversation_id
                    ai_message_id = data.get("ai_message_id") or ai_message_id
                    return False
            # Upstream sends newlines as a literal backslash-n; restore them.
            answer_parts.append(payload.replace("\\n", "\n"))
            return False

        try:
            headers = {"Content-Type": "application/json"}
            for header_name in ["Authorization", "Token", "token"]:
                if header_value := request.headers.get(header_name):
                    headers[header_name] = header_value

            async with httpx.AsyncClient(timeout=600) as client:
                async with client.stream(
                    "POST",
                    local_url,
                    json=stream_request.dict(),
                    headers=headers
                ) as response:
                    if response.status_code != 200:
                        yield _sse_event({
                            "type": "online_error",
                            "message": f"Local stream failed: {response.status_code}",
                        })
                        yield _sse_event({"type": "completed"})
                        return

                    # Buffer raw bytes and only decode complete lines: a chunk
                    # boundary may fall inside a line or inside a multi-byte
                    # UTF-8 character. The byte 0x0A never occurs inside a
                    # multi-byte sequence, so splitting on it is safe.
                    pending = b""
                    done = False
                    async for raw_chunk in response.aiter_bytes(chunk_size=4096):
                        if not raw_chunk:
                            continue
                        pending += raw_chunk
                        *raw_lines, pending = pending.split(b"\n")
                        for raw_line in raw_lines:
                            if consume_line(raw_line):
                                done = True
                                break
                        if done:
                            break

                    # A final line without a trailing newline is still data.
                    if not done and pending:
                        consume_line(pending)

            # Always close the stream for the client, even if the upstream
            # ended without sending the "[DONE]" sentinel.
            full_response = "".join(answer_parts)
            if full_response:
                yield _sse_event({
                    "type": "online_answer",
                    "content": full_response,
                    "ai_conversation_id": ai_conversation_id,
                    "ai_message_id": ai_message_id,
                })
            yield _sse_event({"type": "completed"})
        except Exception as e:
            logger.error(f"[报告兼容] 本地 SSE 处理异常: {e}")
            yield _sse_event({
                "type": "online_error",
                "message": f"Local stream error: {e}",
            })
            yield _sse_event({"type": "completed"})

    return StreamingResponse(
        stream_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
        }
    )


@router.post("/report/complete-flow")
async def complete_flow(request: Request):
    """Run the complete report generation flow as an SSE response.

    Invalid or empty requests are answered with an ``online_error`` event
    followed by ``completed``. Requests with a non-local token are proxied to
    aichat; requests with a local or missing token, and requests whose proxy
    call raises, are served by ``fallback_to_local_stream``.
    """
    request_body = await request.body()

    try:
        request_data = ReportCompleteFlowRequest(**json.loads(request_body))
    except Exception as e:
        return StreamingResponse(
            iter([
                _sse_event({
                    "type": "online_error",
                    "message": f"Request parse error: {e}",
                }),
                _sse_event({"type": "completed"}),
            ]),
            media_type="text/event-stream"
        )

    # Reject blank questions; `or ""` guards against the field being None.
    if not (request_data.user_question or "").strip():
        return StreamingResponse(
            iter([
                _sse_event({
                    "type": "online_error",
                    "message": "Question cannot be empty",
                }),
                _sse_event({"type": "completed"}),
            ]),
            media_type="text/event-stream"
        )

    token = get_request_token(request)
    if should_proxy_to_aichat(token):
        proxy_body = _build_aichat_complete_flow_body(request_data, request)
        if int(request_data.ai_conversation_id or 0) <= 0:
            # Read the id back from the body so the log shows the exact value
            # that is forwarded.
            proxy_data = json.loads(proxy_body.decode("utf-8"))
            logger.info(
                f"[报告兼容] ai_conversation_id=0 已重写为 {proxy_data.get('ai_conversation_id')},代理到 aichat"
            )
        else:
            logger.info("[报告兼容] 代理 complete-flow 到 aichat")

        try:
            return await aichat_proxy.proxy_sse("/report/complete-flow", request, proxy_body)
        except Exception as e:
            logger.error(f"[报告兼容] 代理 complete-flow 到 aichat 失败,降级本地: {e}")

    # Local token, or the proxy call failed: fall back to the local SSE output.
    return await fallback_to_local_stream(request_data, request)


@router.post("/report/update-ai-message")
async def update_ai_message(request: Request):
    """Update the content of an AI message.

    Requests with a non-local token are proxied to aichat; if the proxy call
    raises, or the token is local or missing, the message is updated in the
    local database. Responds 400 on an unparsable body or a zero
    ``ai_message_id``, and 500 when the database update fails.
    """
    request_body = await request.body()

    token = get_request_token(request)
    if should_proxy_to_aichat(token):
        logger.info("[报告兼容] 代理更新消息到 aichat")
        try:
            return await aichat_proxy.proxy_json("/report/update-ai-message", request, request_body)
        except Exception as e:
            logger.error(f"[报告兼容] 代理更新消息失败: {e}")
            # Fall through to local handling.

    try:
        request_data = UpdateAIMessageRequest(**json.loads(request_body))
    except Exception as e:
        return JSONResponse(
            content={"success": False, "message": f"Request parse error: {str(e)}"},
            status_code=400
        )

    if request_data.ai_message_id == 0:
        return JSONResponse(
            content={"success": False, "message": "ai_message_id cannot be empty"},
            status_code=400
        )

    # get_db is consumed with next(), i.e. used as a generator dependency;
    # keep a handle on the generator so its cleanup can run afterwards.
    db_gen = get_db()
    try:
        db: Session = next(db_gen)
        try:
            db.query(AIMessage).filter(
                AIMessage.id == request_data.ai_message_id,
                AIMessage.is_deleted == 0
            ).update({"content": request_data.content})
            db.commit()
        except Exception:
            # Do not leave the session in a failed transaction.
            db.rollback()
            raise

        return JSONResponse(
            content={"success": True, "message": "AI message updated"},
            status_code=200
        )
    except Exception as e:
        logger.error(f"[报告兼容] 更新消息失败: {e}")
        return JSONResponse(
            content={"success": False, "message": f"Update failed: {str(e)}"},
            status_code=500
        )
    finally:
        # Closing the generator runs whatever cleanup get_db performs after
        # its yield, so the session is not leaked.
        db_gen.close()


@router.post("/sse/stop")
async def stop_sse(request: Request):
    """Stop an SSE stream.

    Requests with a non-local token are proxied to aichat; if the proxy call
    raises, or the token is local or missing, the request is acknowledged
    locally by echoing ``ai_conversation_id``. Responds 400 on an unparsable
    body.
    """
    request_body = await request.body()

    token = get_request_token(request)
    if should_proxy_to_aichat(token):
        logger.info("[报告兼容] 代理停止请求到 aichat")
        try:
            return await aichat_proxy.proxy_json("/sse/stop", request, request_body)
        except Exception as e:
            logger.error(f"[报告兼容] 代理停止请求失败: {e}")
            # Fall through to local handling.

    # Local handling simply acknowledges the request.
    try:
        request_data = StopSSERequest(**json.loads(request_body))
        return JSONResponse(
            content={
                "success": True,
                "message": "Stop request received",
                "ai_conversation_id": request_data.ai_conversation_id
            },
            status_code=200
        )
    except Exception as e:
        return JSONResponse(
            content={"success": False, "message": f"Request parse error: {str(e)}"},
            status_code=400
        )