report_compat.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. """
  2. 报告兼容路由
  3. 完全对齐 Go 版本的接口实现,保持外部一致性
  4. """
  5. from fastapi import APIRouter, Request
  6. from fastapi.responses import StreamingResponse, JSONResponse
  7. import httpx
  8. import json
  9. import time
  10. from typing import AsyncGenerator
  11. from models.report import (
  12. ReportCompleteFlowRequest,
  13. UpdateAIMessageRequest,
  14. StopSSERequest,
  15. StreamChatRequest
  16. )
  17. from services.aichat_proxy import aichat_proxy
  18. from utils.token import is_local_token
  19. from utils.config import settings
  20. from utils.logger import logger
  21. from database import get_db
  22. from sqlalchemy.orm import Session
  23. from models.chat import AIMessage
  24. router = APIRouter(prefix="/apiv1", tags=["报告兼容"])
  25. def get_request_token(request: Request) -> str:
  26. """获取请求中的 token"""
  27. # 支持多种 header 名称
  28. for header_name in ["token", "Token", "Authorization"]:
  29. header_value = request.headers.get(header_name, "").strip()
  30. if header_value:
  31. if header_name == "Authorization" and header_value.startswith("Bearer "):
  32. return header_value.replace("Bearer ", "")
  33. return header_value
  34. return ""
  35. def should_proxy_to_aichat(token: str) -> bool:
  36. """判断是否应该代理到 aichat"""
  37. if not token:
  38. return False
  39. # 本地 token 不代理,外部 token 代理
  40. return not is_local_token(token)
  41. def _ensure_proxy_conversation_id(
  42. request_data: ReportCompleteFlowRequest,
  43. request: Request
  44. ) -> int:
  45. """
  46. 为 aichat 生成可用的 ai_conversation_id。
  47. 目的:避免 ai_conversation_id=0 触发 aichat 内部 transfer_session 崩溃分支。
  48. """
  49. conversation_id = int(request_data.ai_conversation_id or 0)
  50. if conversation_id > 0:
  51. return conversation_id
  52. # 生成一个正数会话ID(<= 2^53,前端 JS 可安全表示)
  53. generated_id = int(time.time() * 1_000_000)
  54. # 叠加少量用户特征,降低并发冲突概率
  55. user = getattr(request.state, "user", None)
  56. user_id = int(getattr(user, "user_id", 0) or 0)
  57. if user_id > 0:
  58. generated_id = generated_id - (generated_id % 1000) + (user_id % 1000)
  59. if generated_id <= 0:
  60. generated_id = int(time.time()) or 1
  61. return generated_id
  62. def _build_aichat_complete_flow_body(
  63. request_data: ReportCompleteFlowRequest,
  64. request: Request
  65. ) -> bytes:
  66. """构建转发到 aichat 的请求体,并规避 ai_conversation_id=0 的崩溃场景。"""
  67. payload = request_data.dict()
  68. payload["ai_conversation_id"] = _ensure_proxy_conversation_id(request_data, request)
  69. return json.dumps(payload, ensure_ascii=False).encode("utf-8")
  70. async def fallback_to_local_stream(
  71. request_data: ReportCompleteFlowRequest,
  72. request: Request
  73. ) -> StreamingResponse:
  74. """降级为本地 SSE 兼容输出(匹配前端 report/complete-flow 事件格式)。"""
  75. logger.info("[报告兼容] 降级为本地 SSE 兼容输出")
  76. stream_request = StreamChatRequest(
  77. message=request_data.user_question,
  78. ai_conversation_id=request_data.ai_conversation_id,
  79. business_type=0
  80. )
  81. local_url = f"http://127.0.0.1:{settings.app.port}/apiv1/stream/chat-with-db"
  82. async def stream_generator() -> AsyncGenerator[bytes, None]:
  83. ai_conversation_id = request_data.ai_conversation_id or 0
  84. ai_message_id = 0
  85. full_response = ""
  86. try:
  87. headers = {"Content-Type": "application/json"}
  88. for header_name in ["Authorization", "Token", "token"]:
  89. if header_value := request.headers.get(header_name):
  90. headers[header_name] = header_value
  91. async with httpx.AsyncClient(timeout=600) as client:
  92. async with client.stream(
  93. "POST",
  94. local_url,
  95. json=stream_request.dict(),
  96. headers=headers
  97. ) as response:
  98. if response.status_code != 200:
  99. error_msg = f"data: {{\"type\": \"online_error\", \"message\": \"Local stream failed: {response.status_code}\"}}\n\n"
  100. yield error_msg.encode("utf-8")
  101. yield b"data: {\"type\": \"completed\"}\n\n"
  102. return
  103. async for raw_chunk in response.aiter_bytes(chunk_size=4096):
  104. if not raw_chunk:
  105. continue
  106. chunk_text = raw_chunk.decode("utf-8", errors="ignore")
  107. for line in chunk_text.split("\n"):
  108. if not line.startswith("data: "):
  109. continue
  110. payload = line[6:].strip()
  111. if not payload:
  112. continue
  113. if payload == "[DONE]":
  114. if full_response:
  115. online_answer = {
  116. "type": "online_answer",
  117. "content": full_response,
  118. "ai_conversation_id": ai_conversation_id,
  119. "ai_message_id": ai_message_id,
  120. }
  121. yield f"data: {json.dumps(online_answer, ensure_ascii=False)}\n\n".encode("utf-8")
  122. yield b"data: {\"type\": \"completed\"}\n\n"
  123. return
  124. if payload.startswith("{"):
  125. try:
  126. data = json.loads(payload)
  127. except Exception:
  128. data = None
  129. if isinstance(data, dict) and data.get("type") == "initial":
  130. ai_conversation_id = data.get("ai_conversation_id") or ai_conversation_id
  131. ai_message_id = data.get("ai_message_id") or ai_message_id
  132. continue
  133. full_response += payload.replace("\\n", "\n")
  134. except Exception as e:
  135. logger.error(f"[报告兼容] 本地 SSE 处理异常: {e}")
  136. error_msg = f"data: {{\"type\": \"online_error\", \"message\": \"Local stream error: {str(e)}\"}}\n\n"
  137. yield error_msg.encode("utf-8")
  138. yield b"data: {\"type\": \"completed\"}\n\n"
  139. return StreamingResponse(
  140. stream_generator(),
  141. media_type="text/event-stream",
  142. headers={
  143. "Cache-Control": "no-cache",
  144. "Connection": "keep-alive",
  145. "Access-Control-Allow-Origin": "*",
  146. }
  147. )
  148. @router.post("/report/complete-flow")
  149. async def complete_flow(request: Request):
  150. """
  151. 完整报告生成流程(SSE)
  152. 完全对齐 Go 版本的实现
  153. """
  154. # 解析请求体
  155. request_body = await request.body()
  156. try:
  157. request_data = ReportCompleteFlowRequest(**json.loads(request_body))
  158. except Exception as e:
  159. return StreamingResponse(
  160. iter([
  161. f"data: {{\"type\": \"online_error\", \"message\": \"Request parse error: {str(e)}\"}}\n\n".encode('utf-8'),
  162. b"data: {\"type\": \"completed\"}\n\n"
  163. ]),
  164. media_type="text/event-stream"
  165. )
  166. # 验证问题不为空
  167. if not request_data.user_question.strip():
  168. return StreamingResponse(
  169. iter([
  170. b"data: {\"type\": \"online_error\", \"message\": \"Question cannot be empty\"}\n\n",
  171. b"data: {\"type\": \"completed\"}\n\n"
  172. ]),
  173. media_type="text/event-stream"
  174. )
  175. token = get_request_token(request)
  176. if should_proxy_to_aichat(token):
  177. proxy_body = _build_aichat_complete_flow_body(request_data, request)
  178. if int(request_data.ai_conversation_id or 0) <= 0:
  179. proxy_data = json.loads(proxy_body.decode("utf-8"))
  180. logger.info(
  181. f"[报告兼容] ai_conversation_id=0 已重写为 {proxy_data.get('ai_conversation_id')},代理到 aichat"
  182. )
  183. else:
  184. logger.info("[报告兼容] 代理 complete-flow 到 aichat")
  185. try:
  186. return await aichat_proxy.proxy_sse("/report/complete-flow", request, proxy_body)
  187. except Exception as e:
  188. logger.error(f"[报告兼容] 代理 complete-flow 到 aichat 失败,降级本地: {e}")
  189. # 本地 token 或代理失败:降级到本地兼容 SSE
  190. return await fallback_to_local_stream(request_data, request)
  191. @router.post("/report/update-ai-message")
  192. async def update_ai_message(request: Request):
  193. """
  194. 更新 AI 消息内容
  195. 完全对齐 Go 版本的实现
  196. """
  197. request_body = await request.body()
  198. # 获取 token 并判断路由策略
  199. token = get_request_token(request)
  200. if should_proxy_to_aichat(token):
  201. # 外部 token,代理到 aichat
  202. logger.info("[报告兼容] 代理更新消息到 aichat")
  203. try:
  204. return await aichat_proxy.proxy_json("/report/update-ai-message", request, request_body)
  205. except Exception as e:
  206. logger.error(f"[报告兼容] 代理更新消息失败: {e}")
  207. # 降级到本地处理
  208. # 本地处理
  209. try:
  210. request_data = UpdateAIMessageRequest(**json.loads(request_body))
  211. except Exception as e:
  212. return JSONResponse(
  213. content={"success": False, "message": f"Request parse error: {str(e)}"},
  214. status_code=400
  215. )
  216. if request_data.ai_message_id == 0:
  217. return JSONResponse(
  218. content={"success": False, "message": "ai_message_id cannot be empty"},
  219. status_code=400
  220. )
  221. # 更新数据库
  222. try:
  223. db: Session = next(get_db())
  224. db.query(AIMessage).filter(
  225. AIMessage.id == request_data.ai_message_id,
  226. AIMessage.is_deleted == 0
  227. ).update({"content": request_data.content})
  228. db.commit()
  229. return JSONResponse(
  230. content={"success": True, "message": "AI message updated"},
  231. status_code=200
  232. )
  233. except Exception as e:
  234. logger.error(f"[报告兼容] 更新消息失败: {e}")
  235. return JSONResponse(
  236. content={"success": False, "message": f"Update failed: {str(e)}"},
  237. status_code=500
  238. )
  239. @router.post("/sse/stop")
  240. async def stop_sse(request: Request):
  241. """
  242. 停止 SSE 流
  243. 完全对齐 Go 版本的实现
  244. """
  245. request_body = await request.body()
  246. # 获取 token 并判断路由策略
  247. token = get_request_token(request)
  248. if should_proxy_to_aichat(token):
  249. # 外部 token,代理到 aichat
  250. logger.info("[报告兼容] 代理停止请求到 aichat")
  251. try:
  252. return await aichat_proxy.proxy_json("/sse/stop", request, request_body)
  253. except Exception as e:
  254. logger.error(f"[报告兼容] 代理停止请求失败: {e}")
  255. # 降级到本地处理
  256. # 本地处理(简单返回成功)
  257. try:
  258. request_data = StopSSERequest(**json.loads(request_body))
  259. return JSONResponse(
  260. content={
  261. "success": True,
  262. "message": "Stop request received",
  263. "ai_conversation_id": request_data.ai_conversation_id
  264. },
  265. status_code=200
  266. )
  267. except Exception as e:
  268. return JSONResponse(
  269. content={"success": False, "message": f"Request parse error: {str(e)}"},
  270. status_code=400
  271. )