report_compat.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. package controllers
import (
	"bufio"
	"bytes"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"strings"
	"time"

	"shudao-chat-go/models"
	"shudao-chat-go/utils"

	beego "github.com/beego/beego/v2/server/web"
)
  15. // ReportCompatController keeps the old AI chat frontend contract working
  16. // while routing the request to the service that actually implements it.
  17. type ReportCompatController struct {
  18. beego.Controller
  19. }
  20. type reportCompleteFlowRequest struct {
  21. UserQuestion string `json:"user_question"`
  22. WindowSize int `json:"window_size"`
  23. NResults int `json:"n_results"`
  24. AIConversationID uint64 `json:"ai_conversation_id"`
  25. IsNetworkSearchEnabled bool `json:"is_network_search_enabled"`
  26. EnableOnlineModel bool `json:"enable_online_model"`
  27. }
  28. type updateAIMessageRequest struct {
  29. AIMessageID uint64 `json:"ai_message_id"`
  30. Content string `json:"content"`
  31. }
  32. type stopSSERequest struct {
  33. AIConversationID uint64 `json:"ai_conversation_id"`
  34. }
  35. type streamChatAggregateResult struct {
  36. AIConversationID uint64
  37. AIMessageID uint64
  38. Content string
  39. }
  40. func (c *ReportCompatController) CompleteFlow() {
  41. c.setSSEHeaders()
  42. var requestData reportCompleteFlowRequest
  43. if err := json.Unmarshal(c.Ctx.Input.RequestBody, &requestData); err != nil {
  44. c.writeSSEJSON(map[string]interface{}{
  45. "type": "online_error",
  46. "message": fmt.Sprintf("请求参数解析失败: %s", err.Error()),
  47. })
  48. c.writeSSEJSON(map[string]interface{}{"type": "completed"})
  49. return
  50. }
  51. userQuestion := strings.TrimSpace(requestData.UserQuestion)
  52. if userQuestion == "" {
  53. c.writeSSEJSON(map[string]interface{}{
  54. "type": "online_error",
  55. "message": "问题不能为空",
  56. })
  57. c.writeSSEJSON(map[string]interface{}{"type": "completed"})
  58. return
  59. }
  60. if c.shouldProxyToAIChat() {
  61. if err := c.proxyAIChatSSE("/report/complete-flow", c.Ctx.Input.RequestBody); err == nil {
  62. return
  63. } else {
  64. fmt.Printf("[report-compat] proxy to aichat failed, fallback to local stream: %v\n", err)
  65. }
  66. }
  67. result, err := c.callStreamChatWithDB(requestData)
  68. if err != nil {
  69. c.writeSSEJSON(map[string]interface{}{
  70. "type": "online_error",
  71. "ai_conversation_id": result.AIConversationID,
  72. "ai_message_id": result.AIMessageID,
  73. "message": err.Error(),
  74. })
  75. c.writeSSEJSON(map[string]interface{}{
  76. "type": "completed",
  77. "ai_conversation_id": result.AIConversationID,
  78. "ai_message_id": result.AIMessageID,
  79. })
  80. return
  81. }
  82. c.writeSSEJSON(map[string]interface{}{
  83. "type": "online_answer",
  84. "ai_conversation_id": result.AIConversationID,
  85. "ai_message_id": result.AIMessageID,
  86. "content": result.Content,
  87. })
  88. c.writeSSEJSON(map[string]interface{}{
  89. "type": "completed",
  90. "ai_conversation_id": result.AIConversationID,
  91. "ai_message_id": result.AIMessageID,
  92. })
  93. }
  94. func (c *ReportCompatController) UpdateAIMessage() {
  95. if c.shouldProxyToAIChat() {
  96. if err := c.proxyAIChatJSON("/report/update-ai-message", c.Ctx.Input.RequestBody); err == nil {
  97. return
  98. } else {
  99. fmt.Printf("[report-compat] proxy update-ai-message failed, fallback to local update: %v\n", err)
  100. }
  101. }
  102. var requestData updateAIMessageRequest
  103. if err := json.Unmarshal(c.Ctx.Input.RequestBody, &requestData); err != nil {
  104. c.Data["json"] = map[string]interface{}{
  105. "success": false,
  106. "message": fmt.Sprintf("请求参数解析失败: %s", err.Error()),
  107. }
  108. c.ServeJSON()
  109. return
  110. }
  111. if requestData.AIMessageID == 0 {
  112. c.Data["json"] = map[string]interface{}{
  113. "success": false,
  114. "message": "ai_message_id 不能为空",
  115. }
  116. c.ServeJSON()
  117. return
  118. }
  119. if err := models.DB.Model(&models.AIMessage{}).
  120. Where("id = ? AND is_deleted = ?", requestData.AIMessageID, 0).
  121. Update("content", requestData.Content).Error; err != nil {
  122. c.Data["json"] = map[string]interface{}{
  123. "success": false,
  124. "message": fmt.Sprintf("更新 AI 消息失败: %s", err.Error()),
  125. }
  126. c.ServeJSON()
  127. return
  128. }
  129. c.Data["json"] = map[string]interface{}{
  130. "success": true,
  131. "message": "AI 消息已更新",
  132. }
  133. c.ServeJSON()
  134. }
  135. func (c *ReportCompatController) StopSSE() {
  136. if c.shouldProxyToAIChat() {
  137. if err := c.proxyAIChatJSON("/sse/stop", c.Ctx.Input.RequestBody); err == nil {
  138. return
  139. } else {
  140. fmt.Printf("[report-compat] proxy sse/stop failed, fallback to local success response: %v\n", err)
  141. }
  142. }
  143. var requestData stopSSERequest
  144. _ = json.Unmarshal(c.Ctx.Input.RequestBody, &requestData)
  145. c.Data["json"] = map[string]interface{}{
  146. "success": true,
  147. "message": "已接收停止请求",
  148. "ai_conversation_id": requestData.AIConversationID,
  149. }
  150. c.ServeJSON()
  151. }
  152. func (c *ReportCompatController) setSSEHeaders() {
  153. c.Ctx.ResponseWriter.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
  154. c.Ctx.ResponseWriter.Header().Set("Cache-Control", "no-cache")
  155. c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive")
  156. c.Ctx.ResponseWriter.Header().Set("Access-Control-Allow-Origin", "*")
  157. c.Ctx.ResponseWriter.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
  158. c.Ctx.ResponseWriter.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Token, token")
  159. }
  160. func (c *ReportCompatController) writeSSEJSON(payload map[string]interface{}) {
  161. responseJSON, _ := json.Marshal(payload)
  162. fmt.Fprintf(c.Ctx.ResponseWriter, "data: %s\n\n", responseJSON)
  163. c.Ctx.ResponseWriter.Flush()
  164. }
  165. func (c *ReportCompatController) shouldProxyToAIChat() bool {
  166. token := c.getRequestToken()
  167. if token == "" {
  168. return false
  169. }
  170. if _, err := utils.VerifyLocalToken(token); err == nil {
  171. return false
  172. }
  173. return true
  174. }
  175. func (c *ReportCompatController) getRequestToken() string {
  176. for _, headerName := range []string{"token", "Token", "Authorization"} {
  177. headerValue := strings.TrimSpace(c.Ctx.Request.Header.Get(headerName))
  178. if headerValue == "" {
  179. continue
  180. }
  181. if headerName == "Authorization" && strings.HasPrefix(headerValue, "Bearer ") {
  182. return strings.TrimPrefix(headerValue, "Bearer ")
  183. }
  184. return headerValue
  185. }
  186. return ""
  187. }
  188. func (c *ReportCompatController) getAIChatBaseURL() string {
  189. baseURL, err := beego.AppConfig.String("aichat_api_url")
  190. if err != nil || strings.TrimSpace(baseURL) == "" {
  191. baseURL = "http://127.0.0.1:28002/api/v1"
  192. }
  193. return strings.TrimRight(baseURL, "/")
  194. }
  195. func (c *ReportCompatController) proxyAIChatSSE(path string, requestBody []byte) error {
  196. upstreamReq, err := http.NewRequest(
  197. http.MethodPost,
  198. c.getAIChatBaseURL()+path,
  199. bytes.NewBuffer(requestBody),
  200. )
  201. if err != nil {
  202. return fmt.Errorf("创建 aichat SSE 请求失败: %w", err)
  203. }
  204. upstreamReq.Header.Set("Content-Type", "application/json")
  205. c.forwardAuthHeaders(upstreamReq)
  206. client := &http.Client{Timeout: 10 * time.Minute}
  207. resp, err := client.Do(upstreamReq)
  208. if err != nil {
  209. return fmt.Errorf("调用 aichat SSE 失败: %w", err)
  210. }
  211. defer resp.Body.Close()
  212. if resp.StatusCode != http.StatusOK {
  213. responseBody, _ := io.ReadAll(resp.Body)
  214. return fmt.Errorf("aichat SSE 返回异常状态: %d %s", resp.StatusCode, strings.TrimSpace(string(responseBody)))
  215. }
  216. buffer := make([]byte, 4096)
  217. for {
  218. n, readErr := resp.Body.Read(buffer)
  219. if n > 0 {
  220. if _, err := c.Ctx.ResponseWriter.Write(buffer[:n]); err != nil {
  221. return fmt.Errorf("写入前端 SSE 响应失败: %w", err)
  222. }
  223. c.Ctx.ResponseWriter.Flush()
  224. }
  225. if readErr == io.EOF {
  226. return nil
  227. }
  228. if readErr != nil {
  229. return fmt.Errorf("读取 aichat SSE 响应失败: %w", readErr)
  230. }
  231. }
  232. }
  233. func (c *ReportCompatController) proxyAIChatJSON(path string, requestBody []byte) error {
  234. upstreamReq, err := http.NewRequest(
  235. http.MethodPost,
  236. c.getAIChatBaseURL()+path,
  237. bytes.NewBuffer(requestBody),
  238. )
  239. if err != nil {
  240. return fmt.Errorf("创建 aichat JSON 请求失败: %w", err)
  241. }
  242. upstreamReq.Header.Set("Content-Type", "application/json")
  243. c.forwardAuthHeaders(upstreamReq)
  244. client := &http.Client{Timeout: 30 * time.Second}
  245. resp, err := client.Do(upstreamReq)
  246. if err != nil {
  247. return fmt.Errorf("调用 aichat JSON 接口失败: %w", err)
  248. }
  249. defer resp.Body.Close()
  250. responseBody, err := io.ReadAll(resp.Body)
  251. if err != nil {
  252. return fmt.Errorf("读取 aichat JSON 响应失败: %w", err)
  253. }
  254. c.Ctx.Output.SetStatus(resp.StatusCode)
  255. c.Ctx.Output.Header("Content-Type", resp.Header.Get("Content-Type"))
  256. _, _ = c.Ctx.ResponseWriter.Write(responseBody)
  257. return nil
  258. }
  259. func (c *ReportCompatController) callStreamChatWithDB(requestData reportCompleteFlowRequest) (streamChatAggregateResult, error) {
  260. upstreamBody := map[string]interface{}{
  261. "message": requestData.UserQuestion,
  262. "ai_conversation_id": requestData.AIConversationID,
  263. "business_type": 0,
  264. }
  265. requestBody, err := json.Marshal(upstreamBody)
  266. if err != nil {
  267. return streamChatAggregateResult{}, fmt.Errorf("构建内部请求失败: %w", err)
  268. }
  269. httpPort, err := beego.AppConfig.Int("httpport")
  270. if err != nil || httpPort == 0 {
  271. httpPort = 22001
  272. }
  273. upstreamURL := fmt.Sprintf("http://127.0.0.1:%d/apiv1/stream/chat-with-db", httpPort)
  274. upstreamReq, err := http.NewRequest(http.MethodPost, upstreamURL, bytes.NewBuffer(requestBody))
  275. if err != nil {
  276. return streamChatAggregateResult{}, fmt.Errorf("创建内部请求失败: %w", err)
  277. }
  278. upstreamReq.Header.Set("Content-Type", "application/json")
  279. c.forwardAuthHeaders(upstreamReq)
  280. client := &http.Client{Timeout: 10 * time.Minute}
  281. resp, err := client.Do(upstreamReq)
  282. if err != nil {
  283. return streamChatAggregateResult{}, fmt.Errorf("调用聊天接口失败: %w", err)
  284. }
  285. defer resp.Body.Close()
  286. result, parseErr := parseStreamChatResponse(resp.Body)
  287. if resp.StatusCode != http.StatusOK {
  288. return result, fmt.Errorf("聊天接口返回异常状态: %d", resp.StatusCode)
  289. }
  290. if parseErr != nil {
  291. return result, parseErr
  292. }
  293. if strings.TrimSpace(result.Content) == "" {
  294. return result, fmt.Errorf("聊天接口未返回有效内容")
  295. }
  296. return result, nil
  297. }
  298. func (c *ReportCompatController) forwardAuthHeaders(req *http.Request) {
  299. for _, headerName := range []string{"Authorization", "Token", "token"} {
  300. if headerValue := strings.TrimSpace(c.Ctx.Request.Header.Get(headerName)); headerValue != "" {
  301. req.Header.Set(headerName, headerValue)
  302. }
  303. }
  304. }
  305. func parseStreamChatResponse(reader io.Reader) (streamChatAggregateResult, error) {
  306. scanner := bufio.NewScanner(reader)
  307. scanner.Buffer(make([]byte, 0, 64*1024), 10*1024*1024)
  308. var result streamChatAggregateResult
  309. var contentBuilder strings.Builder
  310. for scanner.Scan() {
  311. line := strings.TrimRight(scanner.Text(), "\r")
  312. if strings.TrimSpace(line) == "" {
  313. continue
  314. }
  315. if strings.HasPrefix(line, "data: ") {
  316. data := strings.TrimPrefix(line, "data: ")
  317. if data == "[DONE]" {
  318. break
  319. }
  320. var payload map[string]interface{}
  321. if err := json.Unmarshal([]byte(data), &payload); err == nil {
  322. if result.AIConversationID == 0 {
  323. result.AIConversationID = getUint64FromMap(payload, "ai_conversation_id")
  324. }
  325. if result.AIMessageID == 0 {
  326. result.AIMessageID = getUint64FromMap(payload, "ai_message_id")
  327. }
  328. if errorMessage, ok := payload["error"].(string); ok && strings.TrimSpace(errorMessage) != "" {
  329. result.Content = contentBuilder.String()
  330. return result, fmt.Errorf("%s", errorMessage)
  331. }
  332. if content, ok := payload["content"].(string); ok {
  333. contentBuilder.WriteString(strings.ReplaceAll(content, "\\n", "\n"))
  334. }
  335. continue
  336. }
  337. if errorMessage, ok := extractRawErrorMessage(data); ok {
  338. result.Content = contentBuilder.String()
  339. return result, fmt.Errorf("%s", errorMessage)
  340. }
  341. contentBuilder.WriteString(strings.ReplaceAll(data, "\\n", "\n"))
  342. continue
  343. }
  344. contentBuilder.WriteString(strings.ReplaceAll(line, "\\n", "\n"))
  345. }
  346. if err := scanner.Err(); err != nil {
  347. result.Content = contentBuilder.String()
  348. return result, fmt.Errorf("读取聊天流失败: %w", err)
  349. }
  350. result.Content = contentBuilder.String()
  351. return result, nil
  352. }
  353. func extractRawErrorMessage(data string) (string, bool) {
  354. if !strings.HasPrefix(data, "{\"error\":") {
  355. return "", false
  356. }
  357. errorMessage := strings.TrimPrefix(data, "{\"error\":")
  358. errorMessage = strings.TrimSuffix(errorMessage, "}")
  359. errorMessage = strings.TrimSpace(errorMessage)
  360. errorMessage = strings.Trim(errorMessage, "\"")
  361. if errorMessage == "" {
  362. return "", false
  363. }
  364. return errorMessage, true
  365. }
  366. func getUint64FromMap(data map[string]interface{}, key string) uint64 {
  367. rawValue, ok := data[key]
  368. if !ok {
  369. return 0
  370. }
  371. switch value := rawValue.(type) {
  372. case float64:
  373. return uint64(value)
  374. case int:
  375. return uint64(value)
  376. case int64:
  377. return uint64(value)
  378. case uint64:
  379. return value
  380. default:
  381. return 0
  382. }
  383. }