|
@@ -0,0 +1,245 @@
|
|
|
|
|
+"""
|
|
|
|
|
+报告兼容路由
|
|
|
|
|
+完全对齐 Go 版本的接口实现,保持外部一致性
|
|
|
|
|
+"""
|
|
|
|
|
+from fastapi import APIRouter, Request
|
|
|
|
|
+from fastapi.responses import StreamingResponse, JSONResponse
|
|
|
|
|
+import httpx
|
|
|
|
|
+import json
|
|
|
|
|
+from typing import AsyncGenerator
|
|
|
|
|
+from models.report import (
|
|
|
|
|
+ ReportCompleteFlowRequest,
|
|
|
|
|
+ UpdateAIMessageRequest,
|
|
|
|
|
+ StopSSERequest,
|
|
|
|
|
+ StreamChatRequest
|
|
|
|
|
+)
|
|
|
|
|
+from services.aichat_proxy import aichat_proxy
|
|
|
|
|
+from utils.token import is_local_token
|
|
|
|
|
+from utils.config import settings
|
|
|
|
|
+from utils.logger import logger
|
|
|
|
|
+from database import get_db
|
|
|
|
|
+from sqlalchemy.orm import Session
|
|
|
|
|
+from models.chat import AIMessage
|
|
|
|
|
+
|
|
|
|
|
+router = APIRouter(prefix="/apiv1", tags=["报告兼容"])
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def get_request_token(request: Request) -> str:
|
|
|
|
|
+ """获取请求中的 token"""
|
|
|
|
|
+ # 支持多种 header 名称
|
|
|
|
|
+ for header_name in ["token", "Token", "Authorization"]:
|
|
|
|
|
+ header_value = request.headers.get(header_name, "").strip()
|
|
|
|
|
+ if header_value:
|
|
|
|
|
+ if header_name == "Authorization" and header_value.startswith("Bearer "):
|
|
|
|
|
+ return header_value.replace("Bearer ", "")
|
|
|
|
|
+ return header_value
|
|
|
|
|
+ return ""
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+def should_proxy_to_aichat(token: str) -> bool:
|
|
|
|
|
+ """判断是否应该代理到 aichat"""
|
|
|
|
|
+ if not token:
|
|
|
|
|
+ return False
|
|
|
|
|
+
|
|
|
|
|
+ # 本地 token 不代理,外部 token 代理
|
|
|
|
|
+ return not is_local_token(token)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+async def fallback_to_local_stream(
|
|
|
|
|
+ request_data: ReportCompleteFlowRequest,
|
|
|
|
|
+ request: Request
|
|
|
|
|
+) -> StreamingResponse:
|
|
|
|
|
+ """降级到本地流式聊天"""
|
|
|
|
|
+ logger.info("[报告兼容] 降级到本地流式聊天")
|
|
|
|
|
+
|
|
|
|
|
+ # 构建本地流式聊天请求
|
|
|
|
|
+ stream_request = StreamChatRequest(
|
|
|
|
|
+ message=request_data.user_question,
|
|
|
|
|
+ ai_conversation_id=request_data.ai_conversation_id,
|
|
|
|
|
+ business_type=0
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # 调用本地流式聊天接口
|
|
|
|
|
+ local_url = f"http://127.0.0.1:{settings.app.port}/apiv1/stream/chat-with-db"
|
|
|
|
|
+
|
|
|
|
|
+ async def stream_generator() -> AsyncGenerator[bytes, None]:
|
|
|
|
|
+ try:
|
|
|
|
|
+ # 转发认证 headers
|
|
|
|
|
+ headers = {"Content-Type": "application/json"}
|
|
|
|
|
+ for header_name in ["Authorization", "Token", "token"]:
|
|
|
|
|
+ if header_value := request.headers.get(header_name):
|
|
|
|
|
+ headers[header_name] = header_value
|
|
|
|
|
+
|
|
|
|
|
+ async with httpx.AsyncClient(timeout=600) as client:
|
|
|
|
|
+ async with client.stream(
|
|
|
|
|
+ "POST",
|
|
|
|
|
+ local_url,
|
|
|
|
|
+ json=stream_request.dict(),
|
|
|
|
|
+ headers=headers
|
|
|
|
|
+ ) as response:
|
|
|
|
|
+ if response.status_code != 200:
|
|
|
|
|
+ error_msg = f"data: {{\"type\": \"online_error\", \"message\": \"Local stream failed: {response.status_code}\"}}\n\n"
|
|
|
|
|
+ yield error_msg.encode('utf-8')
|
|
|
|
|
+ yield b"data: {\"type\": \"completed\"}\n\n"
|
|
|
|
|
+ return
|
|
|
|
|
+
|
|
|
|
|
+ # 转发流式响应
|
|
|
|
|
+ async for chunk in response.aiter_bytes(chunk_size=4096):
|
|
|
|
|
+ yield chunk
|
|
|
|
|
+
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"[报告兼容] 本地流式聊天异常: {e}")
|
|
|
|
|
+ error_msg = f"data: {{\"type\": \"online_error\", \"message\": \"Local stream error: {str(e)}\"}}\n\n"
|
|
|
|
|
+ yield error_msg.encode('utf-8')
|
|
|
|
|
+ yield b"data: {\"type\": \"completed\"}\n\n"
|
|
|
|
|
+
|
|
|
|
|
+ return StreamingResponse(
|
|
|
|
|
+ stream_generator(),
|
|
|
|
|
+ media_type="text/event-stream",
|
|
|
|
|
+ headers={
|
|
|
|
|
+ "Cache-Control": "no-cache",
|
|
|
|
|
+ "Connection": "keep-alive",
|
|
|
|
|
+ "Access-Control-Allow-Origin": "*",
|
|
|
|
|
+ }
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@router.post("/report/complete-flow")
|
|
|
|
|
+async def complete_flow(request: Request):
|
|
|
|
|
+ """
|
|
|
|
|
+ 完整报告生成流程(SSE)
|
|
|
|
|
+ 完全对齐 Go 版本的实现
|
|
|
|
|
+ """
|
|
|
|
|
+ # 解析请求体
|
|
|
|
|
+ request_body = await request.body()
|
|
|
|
|
+
|
|
|
|
|
+ try:
|
|
|
|
|
+ request_data = ReportCompleteFlowRequest(**json.loads(request_body))
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ return StreamingResponse(
|
|
|
|
|
+ iter([
|
|
|
|
|
+ f"data: {{\"type\": \"online_error\", \"message\": \"Request parse error: {str(e)}\"}}\n\n".encode('utf-8'),
|
|
|
|
|
+ b"data: {\"type\": \"completed\"}\n\n"
|
|
|
|
|
+ ]),
|
|
|
|
|
+ media_type="text/event-stream"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # 验证问题不为空
|
|
|
|
|
+ if not request_data.user_question.strip():
|
|
|
|
|
+ return StreamingResponse(
|
|
|
|
|
+ iter([
|
|
|
|
|
+ b"data: {\"type\": \"online_error\", \"message\": \"Question cannot be empty\"}\n\n",
|
|
|
|
|
+ b"data: {\"type\": \"completed\"}\n\n"
|
|
|
|
|
+ ]),
|
|
|
|
|
+ media_type="text/event-stream"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # 获取 token 并判断路由策略
|
|
|
|
|
+ token = get_request_token(request)
|
|
|
|
|
+
|
|
|
|
|
+ if should_proxy_to_aichat(token):
|
|
|
|
|
+ # 外部 token,代理到 aichat
|
|
|
|
|
+ logger.info("[报告兼容] 代理到 aichat 服务")
|
|
|
|
|
+ try:
|
|
|
|
|
+ return await aichat_proxy.proxy_sse("/report/complete-flow", request, request_body)
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"[报告兼容] 代理到 aichat 失败: {e}")
|
|
|
|
|
+ # 降级到本地处理
|
|
|
|
|
+ return await fallback_to_local_stream(request_data, request)
|
|
|
|
|
+ else:
|
|
|
|
|
+ # 本地 token,降级到本地流式聊天
|
|
|
|
|
+ return await fallback_to_local_stream(request_data, request)
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@router.post("/report/update-ai-message")
|
|
|
|
|
+async def update_ai_message(request: Request):
|
|
|
|
|
+ """
|
|
|
|
|
+ 更新 AI 消息内容
|
|
|
|
|
+ 完全对齐 Go 版本的实现
|
|
|
|
|
+ """
|
|
|
|
|
+ request_body = await request.body()
|
|
|
|
|
+
|
|
|
|
|
+ # 获取 token 并判断路由策略
|
|
|
|
|
+ token = get_request_token(request)
|
|
|
|
|
+
|
|
|
|
|
+ if should_proxy_to_aichat(token):
|
|
|
|
|
+ # 外部 token,代理到 aichat
|
|
|
|
|
+ logger.info("[报告兼容] 代理更新消息到 aichat")
|
|
|
|
|
+ try:
|
|
|
|
|
+ return await aichat_proxy.proxy_json("/report/update-ai-message", request, request_body)
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"[报告兼容] 代理更新消息失败: {e}")
|
|
|
|
|
+ # 降级到本地处理
|
|
|
|
|
+
|
|
|
|
|
+ # 本地处理
|
|
|
|
|
+ try:
|
|
|
|
|
+ request_data = UpdateAIMessageRequest(**json.loads(request_body))
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ return JSONResponse(
|
|
|
|
|
+ content={"success": False, "message": f"Request parse error: {str(e)}"},
|
|
|
|
|
+ status_code=400
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ if request_data.ai_message_id == 0:
|
|
|
|
|
+ return JSONResponse(
|
|
|
|
|
+ content={"success": False, "message": "ai_message_id cannot be empty"},
|
|
|
|
|
+ status_code=400
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # 更新数据库
|
|
|
|
|
+ try:
|
|
|
|
|
+ db: Session = next(get_db())
|
|
|
|
|
+ db.query(AIMessage).filter(
|
|
|
|
|
+ AIMessage.id == request_data.ai_message_id,
|
|
|
|
|
+ AIMessage.is_deleted == 0
|
|
|
|
|
+ ).update({"content": request_data.content})
|
|
|
|
|
+ db.commit()
|
|
|
|
|
+
|
|
|
|
|
+ return JSONResponse(
|
|
|
|
|
+ content={"success": True, "message": "AI message updated"},
|
|
|
|
|
+ status_code=200
|
|
|
|
|
+ )
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"[报告兼容] 更新消息失败: {e}")
|
|
|
|
|
+ return JSONResponse(
|
|
|
|
|
+ content={"success": False, "message": f"Update failed: {str(e)}"},
|
|
|
|
|
+ status_code=500
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+@router.post("/sse/stop")
|
|
|
|
|
+async def stop_sse(request: Request):
|
|
|
|
|
+ """
|
|
|
|
|
+ 停止 SSE 流
|
|
|
|
|
+ 完全对齐 Go 版本的实现
|
|
|
|
|
+ """
|
|
|
|
|
+ request_body = await request.body()
|
|
|
|
|
+
|
|
|
|
|
+ # 获取 token 并判断路由策略
|
|
|
|
|
+ token = get_request_token(request)
|
|
|
|
|
+
|
|
|
|
|
+ if should_proxy_to_aichat(token):
|
|
|
|
|
+ # 外部 token,代理到 aichat
|
|
|
|
|
+ logger.info("[报告兼容] 代理停止请求到 aichat")
|
|
|
|
|
+ try:
|
|
|
|
|
+ return await aichat_proxy.proxy_json("/sse/stop", request, request_body)
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ logger.error(f"[报告兼容] 代理停止请求失败: {e}")
|
|
|
|
|
+ # 降级到本地处理
|
|
|
|
|
+
|
|
|
|
|
+ # 本地处理(简单返回成功)
|
|
|
|
|
+ try:
|
|
|
|
|
+ request_data = StopSSERequest(**json.loads(request_body))
|
|
|
|
|
+ return JSONResponse(
|
|
|
|
|
+ content={
|
|
|
|
|
+ "success": True,
|
|
|
|
|
+ "message": "Stop request received",
|
|
|
|
|
+ "ai_conversation_id": request_data.ai_conversation_id
|
|
|
|
|
+ },
|
|
|
|
|
+ status_code=200
|
|
|
|
|
+ )
|
|
|
|
|
+ except Exception as e:
|
|
|
|
|
+ return JSONResponse(
|
|
|
|
|
+ content={"success": False, "message": f"Request parse error: {str(e)}"},
|
|
|
|
|
+ status_code=400
|
|
|
|
|
+ )
|