""" 报告兼容路由 完全对齐 Go 版本的接口实现,保持外部一致性 """ from fastapi import APIRouter, Request from fastapi.responses import StreamingResponse, JSONResponse import httpx import json from typing import AsyncGenerator from models.report import ( ReportCompleteFlowRequest, UpdateAIMessageRequest, StopSSERequest, StreamChatRequest ) from services.aichat_proxy import aichat_proxy from utils.token import is_local_token from utils.config import settings from utils.logger import logger from database import get_db from sqlalchemy.orm import Session from models.chat import AIMessage router = APIRouter(prefix="/apiv1", tags=["报告兼容"]) def get_request_token(request: Request) -> str: """获取请求中的 token""" # 支持多种 header 名称 for header_name in ["token", "Token", "Authorization"]: header_value = request.headers.get(header_name, "").strip() if header_value: if header_name == "Authorization" and header_value.startswith("Bearer "): return header_value.replace("Bearer ", "") return header_value return "" def should_proxy_to_aichat(token: str) -> bool: """判断是否应该代理到 aichat""" if not token: return False # 本地 token 不代理,外部 token 代理 return not is_local_token(token) async def fallback_to_local_stream( request_data: ReportCompleteFlowRequest, request: Request ) -> StreamingResponse: """降级到本地流式聊天""" logger.info("[报告兼容] 降级到本地流式聊天") # 构建本地流式聊天请求 stream_request = StreamChatRequest( message=request_data.user_question, ai_conversation_id=request_data.ai_conversation_id, business_type=0 ) # 调用本地流式聊天接口 local_url = f"http://127.0.0.1:{settings.app.port}/apiv1/stream/chat-with-db" async def stream_generator() -> AsyncGenerator[bytes, None]: try: # 转发认证 headers headers = {"Content-Type": "application/json"} for header_name in ["Authorization", "Token", "token"]: if header_value := request.headers.get(header_name): headers[header_name] = header_value async with httpx.AsyncClient(timeout=600) as client: async with client.stream( "POST", local_url, json=stream_request.dict(), headers=headers ) as response: if response.status_code != 200: error_msg = f"data: {{\"type\": \"online_error\", \"message\": \"Local stream failed: {response.status_code}\"}}\n\n" yield error_msg.encode('utf-8') yield b"data: {\"type\": \"completed\"}\n\n" return # 转发流式响应 async for chunk in response.aiter_bytes(chunk_size=4096): yield chunk except Exception as e: logger.error(f"[报告兼容] 本地流式聊天异常: {e}") error_msg = f"data: {{\"type\": \"online_error\", \"message\": \"Local stream error: {str(e)}\"}}\n\n" yield error_msg.encode('utf-8') yield b"data: {\"type\": \"completed\"}\n\n" return StreamingResponse( stream_generator(), media_type="text/event-stream", headers={ "Cache-Control": "no-cache", "Connection": "keep-alive", "Access-Control-Allow-Origin": "*", } ) @router.post("/report/complete-flow") async def complete_flow(request: Request): """ 完整报告生成流程(SSE) 完全对齐 Go 版本的实现 """ # 解析请求体 request_body = await request.body() try: request_data = ReportCompleteFlowRequest(**json.loads(request_body)) except Exception as e: return StreamingResponse( iter([ f"data: {{\"type\": \"online_error\", \"message\": \"Request parse error: {str(e)}\"}}\n\n".encode('utf-8'), b"data: {\"type\": \"completed\"}\n\n" ]), media_type="text/event-stream" ) # 验证问题不为空 if not request_data.user_question.strip(): return StreamingResponse( iter([ b"data: {\"type\": \"online_error\", \"message\": \"Question cannot be empty\"}\n\n", b"data: {\"type\": \"completed\"}\n\n" ]), media_type="text/event-stream" ) # 获取 token 并判断路由策略 token = get_request_token(request) if should_proxy_to_aichat(token): # 外部 token,代理到 aichat logger.info("[报告兼容] 代理到 aichat 服务") try: return await aichat_proxy.proxy_sse("/report/complete-flow", request, request_body) except Exception as e: logger.error(f"[报告兼容] 代理到 aichat 失败: {e}") # 降级到本地处理 return await fallback_to_local_stream(request_data, request) else: # 本地 token,降级到本地流式聊天 return await fallback_to_local_stream(request_data, request) @router.post("/report/update-ai-message") async def update_ai_message(request: Request): """ 更新 AI 消息内容 完全对齐 Go 版本的实现 """ request_body = await request.body() # 获取 token 并判断路由策略 token = get_request_token(request) if should_proxy_to_aichat(token): # 外部 token,代理到 aichat logger.info("[报告兼容] 代理更新消息到 aichat") try: return await aichat_proxy.proxy_json("/report/update-ai-message", request, request_body) except Exception as e: logger.error(f"[报告兼容] 代理更新消息失败: {e}") # 降级到本地处理 # 本地处理 try: request_data = UpdateAIMessageRequest(**json.loads(request_body)) except Exception as e: return JSONResponse( content={"success": False, "message": f"Request parse error: {str(e)}"}, status_code=400 ) if request_data.ai_message_id == 0: return JSONResponse( content={"success": False, "message": "ai_message_id cannot be empty"}, status_code=400 ) # 更新数据库 try: db: Session = next(get_db()) db.query(AIMessage).filter( AIMessage.id == request_data.ai_message_id, AIMessage.is_deleted == 0 ).update({"content": request_data.content}) db.commit() return JSONResponse( content={"success": True, "message": "AI message updated"}, status_code=200 ) except Exception as e: logger.error(f"[报告兼容] 更新消息失败: {e}") return JSONResponse( content={"success": False, "message": f"Update failed: {str(e)}"}, status_code=500 ) @router.post("/sse/stop") async def stop_sse(request: Request): """ 停止 SSE 流 完全对齐 Go 版本的实现 """ request_body = await request.body() # 获取 token 并判断路由策略 token = get_request_token(request) if should_proxy_to_aichat(token): # 外部 token,代理到 aichat logger.info("[报告兼容] 代理停止请求到 aichat") try: return await aichat_proxy.proxy_json("/sse/stop", request, request_body) except Exception as e: logger.error(f"[报告兼容] 代理停止请求失败: {e}") # 降级到本地处理 # 本地处理(简单返回成功) try: request_data = StopSSERequest(**json.loads(request_body)) return JSONResponse( content={ "success": True, "message": "Stop request received", "ai_conversation_id": request_data.ai_conversation_id }, status_code=200 ) except Exception as e: return JSONResponse( content={"success": False, "message": f"Request parse error: {str(e)}"}, status_code=400 )