report_compat.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. """
  2. 报告兼容路由
  3. 完全对齐 Go 版本的接口实现,保持外部一致性
  4. """
  5. from fastapi import APIRouter, Request
  6. from fastapi.responses import StreamingResponse, JSONResponse
  7. import httpx
  8. import json
  9. from typing import AsyncGenerator
  10. from models.report import (
  11. ReportCompleteFlowRequest,
  12. UpdateAIMessageRequest,
  13. StopSSERequest,
  14. StreamChatRequest
  15. )
  16. from services.aichat_proxy import aichat_proxy
  17. from utils.token import is_local_token
  18. from utils.config import settings
  19. from utils.logger import logger
  20. from database import get_db
  21. from sqlalchemy.orm import Session
  22. from models.chat import AIMessage
  23. router = APIRouter(prefix="/apiv1", tags=["报告兼容"])
  24. def get_request_token(request: Request) -> str:
  25. """获取请求中的 token"""
  26. # 支持多种 header 名称
  27. for header_name in ["token", "Token", "Authorization"]:
  28. header_value = request.headers.get(header_name, "").strip()
  29. if header_value:
  30. if header_name == "Authorization" and header_value.startswith("Bearer "):
  31. return header_value.replace("Bearer ", "")
  32. return header_value
  33. return ""
  34. def should_proxy_to_aichat(token: str) -> bool:
  35. """判断是否应该代理到 aichat"""
  36. if not token:
  37. return False
  38. # 本地 token 不代理,外部 token 代理
  39. return not is_local_token(token)
  40. def _normalize_proxy_conversation_id(
  41. request_data: ReportCompleteFlowRequest,
  42. request: Request
  43. ) -> int:
  44. """
  45. 规范化转发给 aichat 的 ai_conversation_id。
  46. 新会话必须保持为 0,让 aichat 创建 ai_conversation 主记录并返回真实 ID。
  47. """
  48. conversation_id = int(request_data.ai_conversation_id or 0)
  49. return conversation_id if conversation_id > 0 else 0
  50. def _build_aichat_complete_flow_body(
  51. request_data: ReportCompleteFlowRequest,
  52. request: Request
  53. ) -> bytes:
  54. """构建转发到 aichat 的请求体。"""
  55. payload = request_data.model_dump()
  56. if payload.get("user_id") is None:
  57. user = getattr(request.state, "user", None)
  58. user_id = getattr(user, "user_id", None)
  59. if user_id is not None:
  60. payload["user_id"] = int(user_id)
  61. payload["ai_conversation_id"] = _normalize_proxy_conversation_id(request_data, request)
  62. return json.dumps(payload, ensure_ascii=False).encode("utf-8")
  63. async def fallback_to_local_stream(
  64. request_data: ReportCompleteFlowRequest,
  65. request: Request
  66. ) -> StreamingResponse:
  67. """降级为本地 SSE 兼容输出(匹配前端 report/complete-flow 事件格式)。"""
  68. logger.info("[报告兼容] 降级为本地 SSE 兼容输出")
  69. stream_request = StreamChatRequest(
  70. message=request_data.user_question,
  71. ai_conversation_id=request_data.ai_conversation_id,
  72. business_type=0
  73. )
  74. local_url = f"http://127.0.0.1:{settings.app.port}/apiv1/stream/chat-with-db"
  75. async def stream_generator() -> AsyncGenerator[bytes, None]:
  76. ai_conversation_id = request_data.ai_conversation_id or 0
  77. ai_message_id = 0
  78. full_response = ""
  79. try:
  80. headers = {"Content-Type": "application/json"}
  81. for header_name in ["Authorization", "Token", "token"]:
  82. if header_value := request.headers.get(header_name):
  83. headers[header_name] = header_value
  84. async with httpx.AsyncClient(timeout=600) as client:
  85. async with client.stream(
  86. "POST",
  87. local_url,
  88. json=stream_request.dict(),
  89. headers=headers
  90. ) as response:
  91. if response.status_code != 200:
  92. error_msg = f"data: {{\"type\": \"online_error\", \"message\": \"Local stream failed: {response.status_code}\"}}\n\n"
  93. yield error_msg.encode("utf-8")
  94. yield b"data: {\"type\": \"completed\"}\n\n"
  95. return
  96. async for raw_chunk in response.aiter_bytes(chunk_size=4096):
  97. if not raw_chunk:
  98. continue
  99. chunk_text = raw_chunk.decode("utf-8", errors="ignore")
  100. for line in chunk_text.split("\n"):
  101. if not line.startswith("data: "):
  102. continue
  103. payload = line[6:].strip()
  104. if not payload:
  105. continue
  106. if payload == "[DONE]":
  107. if full_response:
  108. online_answer = {
  109. "type": "online_answer",
  110. "content": full_response,
  111. "ai_conversation_id": ai_conversation_id,
  112. "ai_message_id": ai_message_id,
  113. }
  114. yield f"data: {json.dumps(online_answer, ensure_ascii=False)}\n\n".encode("utf-8")
  115. yield b"data: {\"type\": \"completed\"}\n\n"
  116. return
  117. if payload.startswith("{"):
  118. try:
  119. data = json.loads(payload)
  120. except Exception:
  121. data = None
  122. if isinstance(data, dict) and data.get("type") == "initial":
  123. ai_conversation_id = data.get("ai_conversation_id") or ai_conversation_id
  124. ai_message_id = data.get("ai_message_id") or ai_message_id
  125. continue
  126. full_response += payload.replace("\\n", "\n")
  127. except Exception as e:
  128. logger.error(f"[报告兼容] 本地 SSE 处理异常: {e}")
  129. error_msg = f"data: {{\"type\": \"online_error\", \"message\": \"Local stream error: {str(e)}\"}}\n\n"
  130. yield error_msg.encode("utf-8")
  131. yield b"data: {\"type\": \"completed\"}\n\n"
  132. return StreamingResponse(
  133. stream_generator(),
  134. media_type="text/event-stream",
  135. headers={
  136. "Cache-Control": "no-cache",
  137. "Connection": "keep-alive",
  138. "Access-Control-Allow-Origin": "*",
  139. }
  140. )
  141. @router.post("/report/complete-flow")
  142. async def complete_flow(request: Request):
  143. """
  144. 完整报告生成流程(SSE)
  145. 完全对齐 Go 版本的实现
  146. """
  147. # 解析请求体
  148. request_body = await request.body()
  149. try:
  150. request_data = ReportCompleteFlowRequest(**json.loads(request_body))
  151. except Exception as e:
  152. return StreamingResponse(
  153. iter([
  154. f"data: {{\"type\": \"online_error\", \"message\": \"Request parse error: {str(e)}\"}}\n\n".encode('utf-8'),
  155. b"data: {\"type\": \"completed\"}\n\n"
  156. ]),
  157. media_type="text/event-stream"
  158. )
  159. # 验证问题不为空
  160. if not request_data.user_question.strip():
  161. return StreamingResponse(
  162. iter([
  163. b"data: {\"type\": \"online_error\", \"message\": \"Question cannot be empty\"}\n\n",
  164. b"data: {\"type\": \"completed\"}\n\n"
  165. ]),
  166. media_type="text/event-stream"
  167. )
  168. token = get_request_token(request)
  169. if should_proxy_to_aichat(token):
  170. proxy_body = _build_aichat_complete_flow_body(request_data, request)
  171. if int(request_data.ai_conversation_id or 0) <= 0:
  172. logger.info("[报告兼容] 新会话保持 ai_conversation_id=0,交由 aichat 创建历史记录")
  173. else:
  174. logger.info("[报告兼容] 代理 complete-flow 到 aichat")
  175. try:
  176. return await aichat_proxy.proxy_sse("/report/complete-flow", request, proxy_body)
  177. except Exception as e:
  178. logger.error(f"[报告兼容] 代理 complete-flow 到 aichat 失败,降级本地: {e}")
  179. # 本地 token 或代理失败:降级到本地兼容 SSE
  180. return await fallback_to_local_stream(request_data, request)
  181. @router.post("/report/update-ai-message")
  182. async def update_ai_message(request: Request):
  183. """
  184. 更新 AI 消息内容
  185. 完全对齐 Go 版本的实现
  186. """
  187. request_body = await request.body()
  188. # 获取 token 并判断路由策略
  189. token = get_request_token(request)
  190. if should_proxy_to_aichat(token):
  191. # 外部 token,代理到 aichat
  192. logger.info("[报告兼容] 代理更新消息到 aichat")
  193. try:
  194. return await aichat_proxy.proxy_json("/report/update-ai-message", request, request_body)
  195. except Exception as e:
  196. logger.error(f"[报告兼容] 代理更新消息失败: {e}")
  197. # 降级到本地处理
  198. # 本地处理
  199. try:
  200. request_data = UpdateAIMessageRequest(**json.loads(request_body))
  201. except Exception as e:
  202. return JSONResponse(
  203. content={"success": False, "message": f"Request parse error: {str(e)}"},
  204. status_code=400
  205. )
  206. if request_data.ai_message_id == 0:
  207. return JSONResponse(
  208. content={"success": False, "message": "ai_message_id cannot be empty"},
  209. status_code=400
  210. )
  211. # 更新数据库
  212. try:
  213. db: Session = next(get_db())
  214. db.query(AIMessage).filter(
  215. AIMessage.id == request_data.ai_message_id,
  216. AIMessage.is_deleted == 0
  217. ).update({"content": request_data.content})
  218. db.commit()
  219. return JSONResponse(
  220. content={"success": True, "message": "AI message updated"},
  221. status_code=200
  222. )
  223. except Exception as e:
  224. logger.error(f"[报告兼容] 更新消息失败: {e}")
  225. return JSONResponse(
  226. content={"success": False, "message": f"Update failed: {str(e)}"},
  227. status_code=500
  228. )
  229. @router.post("/sse/stop")
  230. async def stop_sse(request: Request):
  231. """
  232. 停止 SSE 流
  233. 完全对齐 Go 版本的实现
  234. """
  235. request_body = await request.body()
  236. # 获取 token 并判断路由策略
  237. token = get_request_token(request)
  238. if should_proxy_to_aichat(token):
  239. # 外部 token,代理到 aichat
  240. logger.info("[报告兼容] 代理停止请求到 aichat")
  241. try:
  242. return await aichat_proxy.proxy_json("/sse/stop", request, request_body)
  243. except Exception as e:
  244. logger.error(f"[报告兼容] 代理停止请求失败: {e}")
  245. # 降级到本地处理
  246. # 本地处理(简单返回成功)
  247. try:
  248. request_data = StopSSERequest(**json.loads(request_body))
  249. return JSONResponse(
  250. content={
  251. "success": True,
  252. "message": "Stop request received",
  253. "ai_conversation_id": request_data.ai_conversation_id
  254. },
  255. status_code=200
  256. )
  257. except Exception as e:
  258. return JSONResponse(
  259. content={"success": False, "message": f"Request parse error: {str(e)}"},
  260. status_code=400
  261. )