report_compat.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. """
  2. 报告兼容路由
  3. 完全对齐 Go 版本的接口实现,保持外部一致性
  4. """
  5. from fastapi import APIRouter, File, Request, UploadFile
  6. from fastapi.responses import StreamingResponse, JSONResponse
  7. import httpx
  8. import json
  9. from typing import AsyncGenerator
  10. from models.report import (
  11. ReportCompleteFlowRequest,
  12. UpdateAIMessageRequest,
  13. StopSSERequest,
  14. StreamChatRequest
  15. )
  16. from services.aichat_proxy import aichat_proxy
  17. from utils.token import is_local_token
  18. from utils.config import settings
  19. from utils.logger import logger
  20. from database import get_db
  21. from sqlalchemy.orm import Session
  22. from models.chat import AIMessage
  23. router = APIRouter(prefix="/apiv1", tags=["报告兼容"])
  24. def get_request_token(request: Request) -> str:
  25. """获取请求中的 token"""
  26. # 支持多种 header 名称
  27. for header_name in ["token", "Token", "Authorization"]:
  28. header_value = request.headers.get(header_name, "").strip()
  29. if header_value:
  30. if header_name == "Authorization" and header_value.startswith("Bearer "):
  31. return header_value.replace("Bearer ", "")
  32. return header_value
  33. return ""
  34. def should_proxy_to_aichat(token: str) -> bool:
  35. """判断是否应该代理到 aichat"""
  36. if not token:
  37. return False
  38. # 本地 token 不代理,外部 token 代理
  39. return not is_local_token(token)
  40. def _normalize_proxy_conversation_id(
  41. request_data: ReportCompleteFlowRequest,
  42. request: Request
  43. ) -> int:
  44. """
  45. 规范化转发给 aichat 的 ai_conversation_id。
  46. 新会话必须保持为 0,让 aichat 创建 ai_conversation 主记录并返回真实 ID。
  47. """
  48. conversation_id = int(request_data.ai_conversation_id or 0)
  49. return conversation_id if conversation_id > 0 else 0
  50. def _build_aichat_complete_flow_body(
  51. request_data: ReportCompleteFlowRequest,
  52. request: Request
  53. ) -> bytes:
  54. """构建转发到 aichat 的请求体。"""
  55. payload = request_data.model_dump()
  56. if payload.get("user_id") is None:
  57. user = getattr(request.state, "user", None)
  58. user_id = getattr(user, "user_id", None)
  59. if user_id is not None:
  60. payload["user_id"] = int(user_id)
  61. payload["ai_conversation_id"] = _normalize_proxy_conversation_id(request_data, request)
  62. return json.dumps(payload, ensure_ascii=False).encode("utf-8")
  63. def _augment_message_with_uploaded_documents(message: str, uploaded_documents) -> str:
  64. parts = [message]
  65. for item in uploaded_documents or []:
  66. content = (item.content or "").strip()
  67. if not item.file_name or not content:
  68. continue
  69. parts.append(
  70. "\n".join([
  71. "【用户上传文档】",
  72. f"文件名:{item.file_name}",
  73. f"文件类型:{item.file_type or '未知'}",
  74. "文档内容:",
  75. content,
  76. ])
  77. )
  78. return "\n\n".join(parts)
  79. async def fallback_to_local_stream(
  80. request_data: ReportCompleteFlowRequest,
  81. request: Request
  82. ) -> StreamingResponse:
  83. """降级为本地 SSE 兼容输出(匹配前端 report/complete-flow 事件格式)。"""
  84. logger.info("[报告兼容] 降级为本地 SSE 兼容输出")
  85. stream_request = StreamChatRequest(
  86. message=_augment_message_with_uploaded_documents(
  87. request_data.user_question,
  88. request_data.uploaded_documents,
  89. ),
  90. ai_conversation_id=request_data.ai_conversation_id,
  91. business_type=0
  92. )
  93. local_url = f"http://127.0.0.1:{settings.app.port}/apiv1/stream/chat-with-db"
  94. async def stream_generator() -> AsyncGenerator[bytes, None]:
  95. ai_conversation_id = request_data.ai_conversation_id or 0
  96. ai_message_id = 0
  97. full_response = ""
  98. try:
  99. headers = {"Content-Type": "application/json"}
  100. for header_name in ["Authorization", "Token", "token"]:
  101. if header_value := request.headers.get(header_name):
  102. headers[header_name] = header_value
  103. async with httpx.AsyncClient(timeout=600) as client:
  104. async with client.stream(
  105. "POST",
  106. local_url,
  107. json=stream_request.dict(),
  108. headers=headers
  109. ) as response:
  110. if response.status_code != 200:
  111. error_msg = f"data: {{\"type\": \"online_error\", \"message\": \"Local stream failed: {response.status_code}\"}}\n\n"
  112. yield error_msg.encode("utf-8")
  113. yield b"data: {\"type\": \"completed\"}\n\n"
  114. return
  115. async for raw_chunk in response.aiter_bytes(chunk_size=4096):
  116. if not raw_chunk:
  117. continue
  118. chunk_text = raw_chunk.decode("utf-8", errors="ignore")
  119. for line in chunk_text.split("\n"):
  120. if not line.startswith("data: "):
  121. continue
  122. payload = line[6:].strip()
  123. if not payload:
  124. continue
  125. if payload == "[DONE]":
  126. if full_response:
  127. online_answer = {
  128. "type": "online_answer",
  129. "content": full_response,
  130. "ai_conversation_id": ai_conversation_id,
  131. "ai_message_id": ai_message_id,
  132. }
  133. yield f"data: {json.dumps(online_answer, ensure_ascii=False)}\n\n".encode("utf-8")
  134. yield b"data: {\"type\": \"completed\"}\n\n"
  135. return
  136. if payload.startswith("{"):
  137. try:
  138. data = json.loads(payload)
  139. except Exception:
  140. data = None
  141. if isinstance(data, dict) and data.get("type") == "initial":
  142. ai_conversation_id = data.get("ai_conversation_id") or ai_conversation_id
  143. ai_message_id = data.get("ai_message_id") or ai_message_id
  144. continue
  145. full_response += payload.replace("\\n", "\n")
  146. except Exception as e:
  147. logger.error(f"[报告兼容] 本地 SSE 处理异常: {e}")
  148. error_msg = f"data: {{\"type\": \"online_error\", \"message\": \"Local stream error: {str(e)}\"}}\n\n"
  149. yield error_msg.encode("utf-8")
  150. yield b"data: {\"type\": \"completed\"}\n\n"
  151. return StreamingResponse(
  152. stream_generator(),
  153. media_type="text/event-stream",
  154. headers={
  155. "Cache-Control": "no-cache",
  156. "Connection": "keep-alive",
  157. "Access-Control-Allow-Origin": "*",
  158. }
  159. )
  160. @router.post("/report/complete-flow")
  161. async def complete_flow(request: Request):
  162. """
  163. 完整报告生成流程(SSE)
  164. 完全对齐 Go 版本的实现
  165. """
  166. # 解析请求体
  167. request_body = await request.body()
  168. try:
  169. request_data = ReportCompleteFlowRequest(**json.loads(request_body))
  170. except Exception as e:
  171. return StreamingResponse(
  172. iter([
  173. f"data: {{\"type\": \"online_error\", \"message\": \"Request parse error: {str(e)}\"}}\n\n".encode('utf-8'),
  174. b"data: {\"type\": \"completed\"}\n\n"
  175. ]),
  176. media_type="text/event-stream"
  177. )
  178. # 验证问题不为空
  179. if not request_data.user_question.strip():
  180. return StreamingResponse(
  181. iter([
  182. b"data: {\"type\": \"online_error\", \"message\": \"Question cannot be empty\"}\n\n",
  183. b"data: {\"type\": \"completed\"}\n\n"
  184. ]),
  185. media_type="text/event-stream"
  186. )
  187. token = get_request_token(request)
  188. if should_proxy_to_aichat(token):
  189. proxy_body = _build_aichat_complete_flow_body(request_data, request)
  190. if int(request_data.ai_conversation_id or 0) <= 0:
  191. logger.info("[报告兼容] 新会话保持 ai_conversation_id=0,交由 aichat 创建历史记录")
  192. else:
  193. logger.info("[报告兼容] 代理 complete-flow 到 aichat")
  194. try:
  195. return await aichat_proxy.proxy_sse("/report/complete-flow", request, proxy_body)
  196. except Exception as e:
  197. logger.error(f"[报告兼容] 代理 complete-flow 到 aichat 失败,降级本地: {e}")
  198. # 本地 token 或代理失败:降级到本地兼容 SSE
  199. return await fallback_to_local_stream(request_data, request)
  200. @router.post("/attachments/parse")
  201. async def parse_attachment(request: Request, file: UploadFile = File(...)):
  202. """Parse an uploaded attachment through aichat."""
  203. user = getattr(request.state, "user", None)
  204. if not user:
  205. return JSONResponse(content={"statusCode": 401, "msg": "未授权"}, status_code=401)
  206. return await aichat_proxy.proxy_upload("/attachments/parse", request, file)
  207. @router.post("/report/update-ai-message")
  208. async def update_ai_message(request: Request):
  209. """
  210. 更新 AI 消息内容
  211. 完全对齐 Go 版本的实现
  212. """
  213. request_body = await request.body()
  214. # 获取 token 并判断路由策略
  215. token = get_request_token(request)
  216. if should_proxy_to_aichat(token):
  217. # 外部 token,代理到 aichat
  218. logger.info("[报告兼容] 代理更新消息到 aichat")
  219. try:
  220. return await aichat_proxy.proxy_json("/report/update-ai-message", request, request_body)
  221. except Exception as e:
  222. logger.error(f"[报告兼容] 代理更新消息失败: {e}")
  223. # 降级到本地处理
  224. # 本地处理
  225. try:
  226. request_data = UpdateAIMessageRequest(**json.loads(request_body))
  227. except Exception as e:
  228. return JSONResponse(
  229. content={"success": False, "message": f"Request parse error: {str(e)}"},
  230. status_code=400
  231. )
  232. if request_data.ai_message_id == 0:
  233. return JSONResponse(
  234. content={"success": False, "message": "ai_message_id cannot be empty"},
  235. status_code=400
  236. )
  237. # 更新数据库
  238. try:
  239. db: Session = next(get_db())
  240. db.query(AIMessage).filter(
  241. AIMessage.id == request_data.ai_message_id,
  242. AIMessage.is_deleted == 0
  243. ).update({"content": request_data.content})
  244. db.commit()
  245. return JSONResponse(
  246. content={"success": True, "message": "AI message updated"},
  247. status_code=200
  248. )
  249. except Exception as e:
  250. logger.error(f"[报告兼容] 更新消息失败: {e}")
  251. return JSONResponse(
  252. content={"success": False, "message": f"Update failed: {str(e)}"},
  253. status_code=500
  254. )
  255. @router.post("/sse/stop")
  256. async def stop_sse(request: Request):
  257. """
  258. 停止 SSE 流
  259. 完全对齐 Go 版本的实现
  260. """
  261. request_body = await request.body()
  262. # 获取 token 并判断路由策略
  263. token = get_request_token(request)
  264. if should_proxy_to_aichat(token):
  265. # 外部 token,代理到 aichat
  266. logger.info("[报告兼容] 代理停止请求到 aichat")
  267. try:
  268. return await aichat_proxy.proxy_json("/sse/stop", request, request_body)
  269. except Exception as e:
  270. logger.error(f"[报告兼容] 代理停止请求失败: {e}")
  271. # 降级到本地处理
  272. # 本地处理(简单返回成功)
  273. try:
  274. request_data = StopSSERequest(**json.loads(request_body))
  275. return JSONResponse(
  276. content={
  277. "success": True,
  278. "message": "Stop request received",
  279. "ai_conversation_id": request_data.ai_conversation_id
  280. },
  281. status_code=200
  282. )
  283. except Exception as e:
  284. return JSONResponse(
  285. content={"success": False, "message": f"Request parse error: {str(e)}"},
  286. status_code=400
  287. )