| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320 |
- """
- 报告兼容路由
- 完全对齐 Go 版本的接口实现,保持外部一致性
- """
- from fastapi import APIRouter, Request
- from fastapi.responses import StreamingResponse, JSONResponse
- import httpx
- import json
- import time
- from typing import AsyncGenerator
- from models.report import (
- ReportCompleteFlowRequest,
- UpdateAIMessageRequest,
- StopSSERequest,
- StreamChatRequest
- )
- from services.aichat_proxy import aichat_proxy
- from utils.token import is_local_token
- from utils.config import settings
- from utils.logger import logger
- from database import get_db
- from sqlalchemy.orm import Session
- from models.chat import AIMessage
# Router for the Go-compatible report endpoints, mounted under the /apiv1 prefix.
router = APIRouter(prefix="/apiv1", tags=["报告兼容"])
def get_request_token(request: "Request") -> str:
    """Extract the auth token from the request headers.

    Header names are checked in priority order ("token", "Token",
    "Authorization"); for an Authorization header the "Bearer " scheme
    prefix is stripped.

    Returns:
        The token string, or "" when no token header is present.
    """
    for header_name in ["token", "Token", "Authorization"]:
        header_value = request.headers.get(header_name, "").strip()
        if not header_value:
            continue
        if header_name == "Authorization" and header_value.startswith("Bearer "):
            # Strip only the leading scheme prefix. The previous
            # str.replace("Bearer ", "") removed EVERY occurrence of
            # "Bearer " inside the token value, corrupting tokens that
            # happen to contain that substring.
            return header_value[len("Bearer "):]
        return header_value
    return ""
def should_proxy_to_aichat(token: str) -> bool:
    """Decide whether a request should be forwarded to the aichat backend.

    An empty token is never proxied; otherwise only non-local (external)
    tokens are proxied.
    """
    # Empty token -> False; external (non-local) token -> True.
    return bool(token) and not is_local_token(token)
def _ensure_proxy_conversation_id(
    request_data: ReportCompleteFlowRequest,
    request: Request
) -> int:
    """Return a usable ai_conversation_id for the aichat proxy.

    A positive id from the request is passed through unchanged. Otherwise
    a positive id is generated (microsecond timestamp, well below 2^53 so
    front-end JS can represent it safely), with the last three digits
    replaced by a user fingerprint to reduce collisions under concurrency.
    This avoids the aichat transfer_session crash path triggered by
    ai_conversation_id=0.
    """
    existing_id = int(request_data.ai_conversation_id or 0)
    if existing_id > 0:
        return existing_id

    fallback_id = int(time.time() * 1_000_000)
    current_user = getattr(request.state, "user", None)
    owner_id = int(getattr(current_user, "user_id", 0) or 0)
    if owner_id > 0:
        # Stamp the lowest three digits with the user id remainder.
        fallback_id = fallback_id - (fallback_id % 1000) + (owner_id % 1000)
    # Last-resort guard: never hand aichat a non-positive id.
    return fallback_id if fallback_id > 0 else (int(time.time()) or 1)
def _build_aichat_complete_flow_body(
    request_data: ReportCompleteFlowRequest,
    request: Request
) -> bytes:
    """Serialize the complete-flow request for forwarding to aichat.

    The ai_conversation_id field is rewritten via
    _ensure_proxy_conversation_id so the ai_conversation_id=0 crash
    scenario is never forwarded.
    """
    body = request_data.dict()
    body["ai_conversation_id"] = _ensure_proxy_conversation_id(request_data, request)
    return json.dumps(body, ensure_ascii=False).encode("utf-8")
async def fallback_to_local_stream(
    request_data: ReportCompleteFlowRequest,
    request: Request
) -> StreamingResponse:
    """Degrade to a locally produced SSE stream matching the front-end
    report/complete-flow event format.

    Re-issues the user question against this service's own
    /apiv1/stream/chat-with-db endpoint, converts its SSE output into
    "online_answer" / "completed" events, and always terminates the
    stream with a "completed" event, including on errors.
    """
    logger.info("[报告兼容] 降级为本地 SSE 兼容输出")
    stream_request = StreamChatRequest(
        message=request_data.user_question,
        ai_conversation_id=request_data.ai_conversation_id,
        business_type=0
    )
    # Loop back into this same service's local streaming endpoint.
    local_url = f"http://127.0.0.1:{settings.app.port}/apiv1/stream/chat-with-db"
    async def stream_generator() -> AsyncGenerator[bytes, None]:
        # Ids are refreshed from the upstream "initial" event when one arrives.
        ai_conversation_id = request_data.ai_conversation_id or 0
        ai_message_id = 0
        full_response = ""
        try:
            # Forward auth headers so the local endpoint sees the caller's identity.
            headers = {"Content-Type": "application/json"}
            for header_name in ["Authorization", "Token", "token"]:
                if header_value := request.headers.get(header_name):
                    headers[header_name] = header_value
            async with httpx.AsyncClient(timeout=600) as client:
                async with client.stream(
                    "POST",
                    local_url,
                    json=stream_request.dict(),
                    headers=headers
                ) as response:
                    if response.status_code != 200:
                        # Upstream refused: emit an error event, then terminate.
                        error_msg = f"data: {{\"type\": \"online_error\", \"message\": \"Local stream failed: {response.status_code}\"}}\n\n"
                        yield error_msg.encode("utf-8")
                        yield b"data: {\"type\": \"completed\"}\n\n"
                        return
                    # NOTE(review): an SSE line split across two 4096-byte chunks
                    # is parsed here as two separate fragments — consider buffering
                    # partial lines across chunk boundaries.
                    async for raw_chunk in response.aiter_bytes(chunk_size=4096):
                        if not raw_chunk:
                            continue
                        chunk_text = raw_chunk.decode("utf-8", errors="ignore")
                        for line in chunk_text.split("\n"):
                            # Only SSE "data: " lines are relevant.
                            if not line.startswith("data: "):
                                continue
                            payload = line[6:].strip()
                            if not payload:
                                continue
                            if payload == "[DONE]":
                                # Upstream finished: flush the accumulated answer
                                # (if any), then signal completion.
                                if full_response:
                                    online_answer = {
                                        "type": "online_answer",
                                        "content": full_response,
                                        "ai_conversation_id": ai_conversation_id,
                                        "ai_message_id": ai_message_id,
                                    }
                                    yield f"data: {json.dumps(online_answer, ensure_ascii=False)}\n\n".encode("utf-8")
                                yield b"data: {\"type\": \"completed\"}\n\n"
                                return
                            if payload.startswith("{"):
                                try:
                                    data = json.loads(payload)
                                except Exception:
                                    data = None
                                # "initial" events carry the conversation/message
                                # ids; consume them without adding to the answer.
                                if isinstance(data, dict) and data.get("type") == "initial":
                                    ai_conversation_id = data.get("ai_conversation_id") or ai_conversation_id
                                    ai_message_id = data.get("ai_message_id") or ai_message_id
                                    continue
                            # Content fragment: unescape literal "\n" and append.
                            full_response += payload.replace("\\n", "\n")
        except Exception as e:
            logger.error(f"[报告兼容] 本地 SSE 处理异常: {e}")
            error_msg = f"data: {{\"type\": \"online_error\", \"message\": \"Local stream error: {str(e)}\"}}\n\n"
            yield error_msg.encode("utf-8")
            yield b"data: {\"type\": \"completed\"}\n\n"
    return StreamingResponse(
        stream_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "Access-Control-Allow-Origin": "*",
        }
    )
@router.post("/report/complete-flow")
async def complete_flow(request: Request):
    """Complete report-generation flow (SSE).

    Mirrors the Go implementation: parse/validate the request, proxy
    external tokens to aichat, and fall back to a local SSE stream for
    local tokens or when the proxy fails. Errors are reported as SSE
    "online_error" events followed by a "completed" event.
    """
    # Parse the raw body manually so parse errors can be reported as SSE.
    request_body = await request.body()

    try:
        request_data = ReportCompleteFlowRequest(**json.loads(request_body))
    except Exception as e:
        return StreamingResponse(
            iter([
                f"data: {{\"type\": \"online_error\", \"message\": \"Request parse error: {str(e)}\"}}\n\n".encode('utf-8'),
                b"data: {\"type\": \"completed\"}\n\n"
            ]),
            media_type="text/event-stream"
        )

    # Reject empty questions up front.
    if not request_data.user_question.strip():
        return StreamingResponse(
            iter([
                b"data: {\"type\": \"online_error\", \"message\": \"Question cannot be empty\"}\n\n",
                b"data: {\"type\": \"completed\"}\n\n"
            ]),
            media_type="text/event-stream"
        )
    token = get_request_token(request)
    if should_proxy_to_aichat(token):
        # External token: forward to aichat, with ai_conversation_id<=0
        # rewritten to a generated positive id (avoids an aichat crash path).
        proxy_body = _build_aichat_complete_flow_body(request_data, request)
        if int(request_data.ai_conversation_id or 0) <= 0:
            proxy_data = json.loads(proxy_body.decode("utf-8"))
            logger.info(
                f"[报告兼容] ai_conversation_id=0 已重写为 {proxy_data.get('ai_conversation_id')},代理到 aichat"
            )
        else:
            logger.info("[报告兼容] 代理 complete-flow 到 aichat")
        try:
            return await aichat_proxy.proxy_sse("/report/complete-flow", request, proxy_body)
        except Exception as e:
            logger.error(f"[报告兼容] 代理 complete-flow 到 aichat 失败,降级本地: {e}")
    # Local token, or the proxy failed: degrade to the local-compatible SSE.
    return await fallback_to_local_stream(request_data, request)
@router.post("/report/update-ai-message")
async def update_ai_message(request: Request):
    """Update the content of a stored AI message.

    Mirrors the Go implementation: external tokens are proxied to aichat;
    local tokens (or a failed proxy call) are handled locally by updating
    the AIMessage row.

    Returns:
        JSONResponse with "success"/"message"; 400 on bad input,
        500 on database failure.
    """
    request_body = await request.body()

    # Decide the routing strategy from the caller's token.
    token = get_request_token(request)

    if should_proxy_to_aichat(token):
        # External token: proxy to aichat.
        logger.info("[报告兼容] 代理更新消息到 aichat")
        try:
            return await aichat_proxy.proxy_json("/report/update-ai-message", request, request_body)
        except Exception as e:
            logger.error(f"[报告兼容] 代理更新消息失败: {e}")
            # Fall through to local handling.

    # Local handling.
    try:
        request_data = UpdateAIMessageRequest(**json.loads(request_body))
    except Exception as e:
        return JSONResponse(
            content={"success": False, "message": f"Request parse error: {str(e)}"},
            status_code=400
        )

    if request_data.ai_message_id == 0:
        return JSONResponse(
            content={"success": False, "message": "ai_message_id cannot be empty"},
            status_code=400
        )

    # Update the database. Drive the get_db() generator explicitly: the
    # previous bare next(get_db()) never finalized the generator, so its
    # cleanup (session close) never ran and the session leaked.
    db_gen = get_db()
    db = None
    try:
        db: Session = next(db_gen)
        db.query(AIMessage).filter(
            AIMessage.id == request_data.ai_message_id,
            AIMessage.is_deleted == 0
        ).update({"content": request_data.content})
        db.commit()

        return JSONResponse(
            content={"success": True, "message": "AI message updated"},
            status_code=200
        )
    except Exception as e:
        if db is not None:
            # Best-effort rollback so the pooled session stays reusable.
            try:
                db.rollback()
            except Exception:
                pass
        logger.error(f"[报告兼容] 更新消息失败: {e}")
        return JSONResponse(
            content={"success": False, "message": f"Update failed: {str(e)}"},
            status_code=500
        )
    finally:
        # Finalize the generator to trigger get_db()'s own cleanup.
        db_gen.close()
@router.post("/sse/stop")
async def stop_sse(request: Request):
    """Stop an in-flight SSE stream.

    Mirrors the Go implementation: external tokens are proxied to aichat;
    local tokens (or a failed proxy call) receive a simple local
    acknowledgement.
    """
    request_body = await request.body()

    # Decide the routing strategy from the caller's token.
    token = get_request_token(request)

    if should_proxy_to_aichat(token):
        # External token: forward the stop request to aichat.
        logger.info("[报告兼容] 代理停止请求到 aichat")
        try:
            return await aichat_proxy.proxy_json("/sse/stop", request, request_body)
        except Exception as e:
            logger.error(f"[报告兼容] 代理停止请求失败: {e}")
            # Fall through to the local acknowledgement below.

    # Local handling: parse the body and simply acknowledge the stop.
    try:
        request_data = StopSSERequest(**json.loads(request_body))
    except Exception as e:
        return JSONResponse(
            content={"success": False, "message": f"Request parse error: {str(e)}"},
            status_code=400
        )
    return JSONResponse(
        content={
            "success": True,
            "message": "Stop request received",
            "ai_conversation_id": request_data.ai_conversation_id
        },
        status_code=200
    )
|