| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245 |
- """
- 报告兼容路由
- 完全对齐 Go 版本的接口实现,保持外部一致性
- """
- from fastapi import APIRouter, Request
- from fastapi.responses import StreamingResponse, JSONResponse
- import httpx
- import json
- from typing import AsyncGenerator
- from models.report import (
- ReportCompleteFlowRequest,
- UpdateAIMessageRequest,
- StopSSERequest,
- StreamChatRequest
- )
- from services.aichat_proxy import aichat_proxy
- from utils.token import is_local_token
- from utils.config import settings
- from utils.logger import logger
- from database import get_db
- from sqlalchemy.orm import Session
- from models.chat import AIMessage
# Router for the report-compatibility endpoints; every route below is mounted
# under the /apiv1 prefix.
router = APIRouter(
    prefix="/apiv1",
    tags=["报告兼容"],
)
def get_request_token(request: Request) -> str:
    """Extract the caller's auth token from the request headers.

    Lookup order:
      1. The ``token`` header. Starlette header lookup is case-insensitive,
         so this also matches ``Token``.
      2. The ``Authorization`` header. A leading ``Bearer`` scheme (matched
         case-insensitively, as auth-schemes are) is stripped. Any other
         value is returned as-is.

    Args:
        request: Incoming FastAPI request.

    Returns:
        The token string, or ``""`` when no non-blank token header is present.
    """
    token = request.headers.get("token", "").strip()
    if token:
        return token

    authorization = request.headers.get("Authorization", "").strip()
    scheme, separator, credentials = authorization.partition(" ")
    if separator and scheme.lower() == "bearer":
        # Strip only the leading scheme. A str.replace() would also delete any
        # "Bearer " that happens to occur inside the credentials themselves.
        return credentials.strip()
    return authorization
def should_proxy_to_aichat(token: str) -> bool:
    """Decide whether a request carrying *token* should be forwarded to aichat.

    An empty token is never proxied. A non-empty token is proxied unless
    ``is_local_token`` recognises it as a local token. Local tokens are
    handled in-process; all other tokens are treated as external.
    """
    return bool(token) and not is_local_token(token)
async def fallback_to_local_stream(
    request_data: ReportCompleteFlowRequest,
    request: Request
) -> StreamingResponse:
    """Serve a report request through this service's own streaming-chat endpoint.

    Used when the caller holds a local token, or when proxying to aichat
    failed. The user's question is re-posted over loopback to the local
    ``/apiv1/stream/chat-with-db`` endpoint, and that endpoint's byte stream
    is relayed to the client unchanged.

    If the local call returns a non-200 status or raises, the stream instead
    emits an ``online_error`` event followed by a ``completed`` event.

    Args:
        request_data: Parsed complete-flow request. Only ``user_question`` and
            ``ai_conversation_id`` are forwarded.
        request: Original request. Its auth headers are forwarded so the
            local endpoint can authenticate the same caller.

    Returns:
        A ``text/event-stream`` StreamingResponse.
    """
    logger.info("[报告兼容] 降级到本地流式聊天")

    # Build the local streaming-chat request.
    stream_request = StreamChatRequest(
        message=request_data.user_question,
        ai_conversation_id=request_data.ai_conversation_id,
        business_type=0,
    )
    local_url = f"http://127.0.0.1:{settings.app.port}/apiv1/stream/chat-with-db"

    # Forward auth headers. Header lookup is case-insensitive, so "Token" and
    # "token" are the same header; forward it once rather than duplicating it.
    headers = {"Content-Type": "application/json"}
    for header_name in ("Authorization", "token"):
        header_value = request.headers.get(header_name)
        if header_value:
            headers[header_name] = header_value

    completed_event = b"data: {\"type\": \"completed\"}\n\n"

    def error_event(message: str) -> bytes:
        # Serialize with json.dumps so quotes and line breaks in the message
        # (e.g. exception text) are escaped. Unescaped text would produce
        # invalid JSON, and a raw newline would split the SSE "data:" line.
        payload = json.dumps({"type": "online_error", "message": message}, ensure_ascii=False)
        return f"data: {payload}\n\n".encode("utf-8")

    async def stream_generator() -> AsyncGenerator[bytes, None]:
        try:
            async with httpx.AsyncClient(timeout=600) as client:
                async with client.stream(
                    "POST",
                    local_url,
                    json=stream_request.dict(),
                    headers=headers,
                ) as response:
                    if response.status_code != 200:
                        yield error_event(f"Local stream failed: {response.status_code}")
                        yield completed_event
                        return

                    # Relay the local SSE stream byte-for-byte.
                    async for chunk in response.aiter_bytes(chunk_size=4096):
                        yield chunk
        except Exception as e:
            logger.error(f"[报告兼容] 本地流式聊天异常: {e}")
            yield error_event(f"Local stream error: {e}")
            yield completed_event

    return StreamingResponse(
        stream_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
        },
    )
@router.post("/report/complete-flow")
async def complete_flow(request: Request):
    """Full report-generation flow, streamed as SSE.

    Keeps the same external interface as the Go implementation.

    Flow:
      1. Parse the body into a ReportCompleteFlowRequest. On failure, stream
         an ``online_error`` event followed by ``completed``.
      2. Reject an empty or blank ``user_question`` the same way.
      3. External tokens are proxied to aichat. If the proxy call raises, or
         the token is local, serve via ``fallback_to_local_stream``.
    """

    def sse_error_response(message: str) -> StreamingResponse:
        # json.dumps escapes quotes and newlines. Pydantic validation errors
        # are multi-line and quoted, which previously produced invalid JSON and
        # broke SSE framing.
        payload = json.dumps({"type": "online_error", "message": message}, ensure_ascii=False)
        return StreamingResponse(
            iter([
                f"data: {payload}\n\n".encode("utf-8"),
                b"data: {\"type\": \"completed\"}\n\n",
            ]),
            media_type="text/event-stream",
        )

    request_body = await request.body()

    try:
        request_data = ReportCompleteFlowRequest(**json.loads(request_body))
    except Exception as e:
        return sse_error_response(f"Request parse error: {e}")

    # Guard against a None question as well as an empty or blank one.
    if not (request_data.user_question or "").strip():
        return sse_error_response("Question cannot be empty")

    token = get_request_token(request)
    if not should_proxy_to_aichat(token):
        # Local token: serve via the local streaming chat.
        return await fallback_to_local_stream(request_data, request)

    # External token: proxy to aichat, degrading to local handling on failure.
    logger.info("[报告兼容] 代理到 aichat 服务")
    try:
        return await aichat_proxy.proxy_sse("/report/complete-flow", request, request_body)
    except Exception as e:
        logger.error(f"[报告兼容] 代理到 aichat 失败: {e}")
        return await fallback_to_local_stream(request_data, request)
@router.post("/report/update-ai-message")
async def update_ai_message(request: Request):
    """Update the content of a stored AI message.

    Keeps the same external interface as the Go implementation.

    External tokens are proxied to aichat first. If the proxy call raises,
    handling deliberately falls through to the local update below.

    Local path: parse an UpdateAIMessageRequest and set ``content`` on the
    AIMessage row with that id whose ``is_deleted`` is 0.

    Responses:
        400 -- the body cannot be parsed, or ``ai_message_id`` is 0/empty.
        200 -- the update statement was committed.
        500 -- the database operation failed.
    """
    request_body = await request.body()

    token = get_request_token(request)
    if should_proxy_to_aichat(token):
        logger.info("[报告兼容] 代理更新消息到 aichat")
        try:
            return await aichat_proxy.proxy_json("/report/update-ai-message", request, request_body)
        except Exception as e:
            logger.error(f"[报告兼容] 代理更新消息失败: {e}")
            # Intentional fall-through: degrade to local handling below.

    try:
        request_data = UpdateAIMessageRequest(**json.loads(request_body))
    except Exception as e:
        return JSONResponse(
            content={"success": False, "message": f"Request parse error: {str(e)}"},
            status_code=400,
        )

    # `not` rejects a None id as well as 0.
    if not request_data.ai_message_id:
        return JSONResponse(
            content={"success": False, "message": "ai_message_id cannot be empty"},
            status_code=400,
        )

    # NOTE(review): no ownership check ties the message to the caller's token;
    # confirm whether that is enforced elsewhere.
    #
    # get_db() is consumed with next(), i.e. it is a generator dependency.
    # Keep a reference and close it in `finally` so its cleanup runs
    # deterministically, instead of depending on when the discarded generator
    # happens to be garbage-collected.
    db_gen = get_db()
    db = None
    try:
        db = next(db_gen)
        matched = db.query(AIMessage).filter(
            AIMessage.id == request_data.ai_message_id,
            AIMessage.is_deleted == 0,
        ).update({"content": request_data.content})
        db.commit()

        if not matched:
            # The response is kept as success for compatibility, but record it.
            logger.warning(
                f"[报告兼容] 未找到待更新的消息: ai_message_id={request_data.ai_message_id}"
            )

        return JSONResponse(
            content={"success": True, "message": "AI message updated"},
            status_code=200,
        )
    except Exception as e:
        logger.error(f"[报告兼容] 更新消息失败: {e}")
        if db is not None:
            # Leave the session clean before it is released; a failed rollback
            # must not mask the original error response.
            try:
                db.rollback()
            except Exception as rollback_error:
                logger.error(f"[报告兼容] 回滚失败: {rollback_error}")
        return JSONResponse(
            content={"success": False, "message": f"Update failed: {str(e)}"},
            status_code=500,
        )
    finally:
        db_gen.close()
@router.post("/sse/stop")
async def stop_sse(request: Request):
    """Stop an SSE stream.

    Keeps the same external interface as the Go implementation.

    External tokens are forwarded to aichat's ``/sse/stop``. If that call
    raises, handling falls back to the local path. Locally nothing is
    cancelled: the body is validated as a StopSSERequest and an
    acknowledgement echoing ``ai_conversation_id`` is returned, or a 400 if
    the body does not parse.
    """
    request_body = await request.body()

    if should_proxy_to_aichat(get_request_token(request)):
        logger.info("[报告兼容] 代理停止请求到 aichat")
        try:
            return await aichat_proxy.proxy_json("/sse/stop", request, request_body)
        except Exception as e:
            # Proxy failed: continue with the local acknowledgement below.
            logger.error(f"[报告兼容] 代理停止请求失败: {e}")

    try:
        stop_request = StopSSERequest(**json.loads(request_body))
        acknowledgement = {
            "success": True,
            "message": "Stop request received",
            "ai_conversation_id": stop_request.ai_conversation_id,
        }
        return JSONResponse(content=acknowledgement, status_code=200)
    except Exception as e:
        failure = {"success": False, "message": f"Request parse error: {str(e)}"}
        return JSONResponse(content=failure, status_code=400)
|