report_compat.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. """
  2. 报告兼容路由
  3. 完全对齐 Go 版本的接口实现,保持外部一致性
  4. """
  5. from fastapi import APIRouter, Request
  6. from fastapi.responses import StreamingResponse, JSONResponse
  7. import httpx
  8. import json
  9. from typing import AsyncGenerator
  10. from models.report import (
  11. ReportCompleteFlowRequest,
  12. UpdateAIMessageRequest,
  13. StopSSERequest,
  14. StreamChatRequest
  15. )
  16. from services.aichat_proxy import aichat_proxy
  17. from utils.token import is_local_token
  18. from utils.config import settings
  19. from utils.logger import logger
  20. from database import get_db
  21. from sqlalchemy.orm import Session
  22. from models.chat import AIMessage
  23. router = APIRouter(prefix="/apiv1", tags=["报告兼容"])
  24. def get_request_token(request: Request) -> str:
  25. """获取请求中的 token"""
  26. # 支持多种 header 名称
  27. for header_name in ["token", "Token", "Authorization"]:
  28. header_value = request.headers.get(header_name, "").strip()
  29. if header_value:
  30. if header_name == "Authorization" and header_value.startswith("Bearer "):
  31. return header_value.replace("Bearer ", "")
  32. return header_value
  33. return ""
  34. def should_proxy_to_aichat(token: str) -> bool:
  35. """判断是否应该代理到 aichat"""
  36. if not token:
  37. return False
  38. # 本地 token 不代理,外部 token 代理
  39. return not is_local_token(token)
  40. async def fallback_to_local_stream(
  41. request_data: ReportCompleteFlowRequest,
  42. request: Request
  43. ) -> StreamingResponse:
  44. """降级到本地流式聊天"""
  45. logger.info("[报告兼容] 降级到本地流式聊天")
  46. # 构建本地流式聊天请求
  47. stream_request = StreamChatRequest(
  48. message=request_data.user_question,
  49. ai_conversation_id=request_data.ai_conversation_id,
  50. business_type=0
  51. )
  52. # 调用本地流式聊天接口
  53. local_url = f"http://127.0.0.1:{settings.app.port}/apiv1/stream/chat-with-db"
  54. async def stream_generator() -> AsyncGenerator[bytes, None]:
  55. try:
  56. # 转发认证 headers
  57. headers = {"Content-Type": "application/json"}
  58. for header_name in ["Authorization", "Token", "token"]:
  59. if header_value := request.headers.get(header_name):
  60. headers[header_name] = header_value
  61. async with httpx.AsyncClient(timeout=600) as client:
  62. async with client.stream(
  63. "POST",
  64. local_url,
  65. json=stream_request.dict(),
  66. headers=headers
  67. ) as response:
  68. if response.status_code != 200:
  69. error_msg = f"data: {{\"type\": \"online_error\", \"message\": \"Local stream failed: {response.status_code}\"}}\n\n"
  70. yield error_msg.encode('utf-8')
  71. yield b"data: {\"type\": \"completed\"}\n\n"
  72. return
  73. # 转发流式响应
  74. async for chunk in response.aiter_bytes(chunk_size=4096):
  75. yield chunk
  76. except Exception as e:
  77. logger.error(f"[报告兼容] 本地流式聊天异常: {e}")
  78. error_msg = f"data: {{\"type\": \"online_error\", \"message\": \"Local stream error: {str(e)}\"}}\n\n"
  79. yield error_msg.encode('utf-8')
  80. yield b"data: {\"type\": \"completed\"}\n\n"
  81. return StreamingResponse(
  82. stream_generator(),
  83. media_type="text/event-stream",
  84. headers={
  85. "Cache-Control": "no-cache",
  86. "Connection": "keep-alive",
  87. "Access-Control-Allow-Origin": "*",
  88. }
  89. )
  90. @router.post("/report/complete-flow")
  91. async def complete_flow(request: Request):
  92. """
  93. 完整报告生成流程(SSE)
  94. 完全对齐 Go 版本的实现
  95. """
  96. # 解析请求体
  97. request_body = await request.body()
  98. try:
  99. request_data = ReportCompleteFlowRequest(**json.loads(request_body))
  100. except Exception as e:
  101. return StreamingResponse(
  102. iter([
  103. f"data: {{\"type\": \"online_error\", \"message\": \"Request parse error: {str(e)}\"}}\n\n".encode('utf-8'),
  104. b"data: {\"type\": \"completed\"}\n\n"
  105. ]),
  106. media_type="text/event-stream"
  107. )
  108. # 验证问题不为空
  109. if not request_data.user_question.strip():
  110. return StreamingResponse(
  111. iter([
  112. b"data: {\"type\": \"online_error\", \"message\": \"Question cannot be empty\"}\n\n",
  113. b"data: {\"type\": \"completed\"}\n\n"
  114. ]),
  115. media_type="text/event-stream"
  116. )
  117. # 获取 token 并判断路由策略
  118. token = get_request_token(request)
  119. if should_proxy_to_aichat(token):
  120. # 外部 token,代理到 aichat
  121. logger.info("[报告兼容] 代理到 aichat 服务")
  122. try:
  123. return await aichat_proxy.proxy_sse("/report/complete-flow", request, request_body)
  124. except Exception as e:
  125. logger.error(f"[报告兼容] 代理到 aichat 失败: {e}")
  126. # 降级到本地处理
  127. return await fallback_to_local_stream(request_data, request)
  128. else:
  129. # 本地 token,降级到本地流式聊天
  130. return await fallback_to_local_stream(request_data, request)
  131. @router.post("/report/update-ai-message")
  132. async def update_ai_message(request: Request):
  133. """
  134. 更新 AI 消息内容
  135. 完全对齐 Go 版本的实现
  136. """
  137. request_body = await request.body()
  138. # 获取 token 并判断路由策略
  139. token = get_request_token(request)
  140. if should_proxy_to_aichat(token):
  141. # 外部 token,代理到 aichat
  142. logger.info("[报告兼容] 代理更新消息到 aichat")
  143. try:
  144. return await aichat_proxy.proxy_json("/report/update-ai-message", request, request_body)
  145. except Exception as e:
  146. logger.error(f"[报告兼容] 代理更新消息失败: {e}")
  147. # 降级到本地处理
  148. # 本地处理
  149. try:
  150. request_data = UpdateAIMessageRequest(**json.loads(request_body))
  151. except Exception as e:
  152. return JSONResponse(
  153. content={"success": False, "message": f"Request parse error: {str(e)}"},
  154. status_code=400
  155. )
  156. if request_data.ai_message_id == 0:
  157. return JSONResponse(
  158. content={"success": False, "message": "ai_message_id cannot be empty"},
  159. status_code=400
  160. )
  161. # 更新数据库
  162. try:
  163. db: Session = next(get_db())
  164. db.query(AIMessage).filter(
  165. AIMessage.id == request_data.ai_message_id,
  166. AIMessage.is_deleted == 0
  167. ).update({"content": request_data.content})
  168. db.commit()
  169. return JSONResponse(
  170. content={"success": True, "message": "AI message updated"},
  171. status_code=200
  172. )
  173. except Exception as e:
  174. logger.error(f"[报告兼容] 更新消息失败: {e}")
  175. return JSONResponse(
  176. content={"success": False, "message": f"Update failed: {str(e)}"},
  177. status_code=500
  178. )
  179. @router.post("/sse/stop")
  180. async def stop_sse(request: Request):
  181. """
  182. 停止 SSE 流
  183. 完全对齐 Go 版本的实现
  184. """
  185. request_body = await request.body()
  186. # 获取 token 并判断路由策略
  187. token = get_request_token(request)
  188. if should_proxy_to_aichat(token):
  189. # 外部 token,代理到 aichat
  190. logger.info("[报告兼容] 代理停止请求到 aichat")
  191. try:
  192. return await aichat_proxy.proxy_json("/sse/stop", request, request_body)
  193. except Exception as e:
  194. logger.error(f"[报告兼容] 代理停止请求失败: {e}")
  195. # 降级到本地处理
  196. # 本地处理(简单返回成功)
  197. try:
  198. request_data = StopSSERequest(**json.loads(request_body))
  199. return JSONResponse(
  200. content={
  201. "success": True,
  202. "message": "Stop request received",
  203. "ai_conversation_id": request_data.ai_conversation_id
  204. },
  205. status_code=200
  206. )
  207. except Exception as e:
  208. return JSONResponse(
  209. content={"success": False, "message": f"Request parse error: {str(e)}"},
  210. status_code=400
  211. )