package controllers

import (
	"bufio"
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strconv"
	"strings"
	"time"

	"shudao-chat-go/models"
	"shudao-chat-go/utils"

	beego "github.com/beego/beego/v2/server/web"
)

// ReportCompatController keeps the old AI chat frontend contract working
// while routing the request to the service that actually implements it.
type ReportCompatController struct {
	beego.Controller
}

// reportCompleteFlowRequest is the JSON body accepted by CompleteFlow.
// The local fallback only uses UserQuestion and AIConversationID; when the
// request is proxied, the raw body (including the other fields) is relayed
// unchanged.
type reportCompleteFlowRequest struct {
	UserQuestion           string `json:"user_question"`
	WindowSize             int    `json:"window_size"`
	NResults               int    `json:"n_results"`
	AIConversationID       uint64 `json:"ai_conversation_id"`
	IsNetworkSearchEnabled bool   `json:"is_network_search_enabled"`
	EnableOnlineModel      bool   `json:"enable_online_model"`
}

// updateAIMessageRequest is the JSON body accepted by UpdateAIMessage.
type updateAIMessageRequest struct {
	AIMessageID uint64 `json:"ai_message_id"`
	Content     string `json:"content"`
}

// stopSSERequest is the JSON body accepted by StopSSE.
type stopSSERequest struct {
	AIConversationID uint64 `json:"ai_conversation_id"`
}

// streamChatAggregateResult holds the IDs and the concatenated answer text
// parsed from the internal chat-with-db stream.
type streamChatAggregateResult struct {
	AIConversationID uint64
	AIMessageID      uint64
	Content          string
}

// reportCompatStreamInterruptedError wraps an SSE proxy failure that happened
// after upstream bytes had been relayed to the frontend, or while writing to
// the frontend. The response is already committed at that point, so callers
// must not fall back to producing a second answer.
type reportCompatStreamInterruptedError struct {
	err error
}

func (e *reportCompatStreamInterruptedError) Error() string { return e.err.Error() }

func (e *reportCompatStreamInterruptedError) Unwrap() error { return e.err }

// CompleteFlow answers a question over SSE using the legacy event contract:
// an "online_answer" or "online_error" event followed by a "completed" event.
//
// Requests whose token does not verify locally are relayed verbatim to the
// aichat service. If that relay fails before any data reached the frontend,
// the question is answered locally through the internal chat-with-db stream.
func (c *ReportCompatController) CompleteFlow() {
	c.setSSEHeaders()

	var requestData reportCompleteFlowRequest
	if err := json.Unmarshal(c.Ctx.Input.RequestBody, &requestData); err != nil {
		c.writeSSEJSON(map[string]interface{}{
			"type":    "online_error",
			"message": fmt.Sprintf("请求参数解析失败: %s", err.Error()),
		})
		c.writeSSEJSON(map[string]interface{}{"type": "completed"})
		return
	}

	userQuestion := strings.TrimSpace(requestData.UserQuestion)
	if userQuestion == "" {
		c.writeSSEJSON(map[string]interface{}{
			"type":    "online_error",
			"message": "问题不能为空",
		})
		c.writeSSEJSON(map[string]interface{}{"type": "completed"})
		return
	}

	if c.shouldProxyToAIChat() {
		err := c.proxyAIChatSSE("/report/complete-flow", c.Ctx.Input.RequestBody)
		if err == nil {
			return
		}
		var interrupted *reportCompatStreamInterruptedError
		if errors.As(err, &interrupted) {
			// Part of the aichat answer may already be on the wire. Falling back
			// would append a second, locally generated answer to the same
			// stream, so end the stream with an error instead. These writes are
			// best effort: writeSSEJSON ignores write failures if the frontend
			// is gone.
			fmt.Printf("[report-compat] aichat stream interrupted after relaying data, not falling back: %v\n", err)
			c.writeSSEJSON(map[string]interface{}{
				"type":    "online_error",
				"message": err.Error(),
			})
			c.writeSSEJSON(map[string]interface{}{"type": "completed"})
			return
		}
		fmt.Printf("[report-compat] proxy to aichat failed, fallback to local stream: %v\n", err)
	}

	result, err := c.callStreamChatWithDB(requestData)
	if err != nil {
		// result may still carry IDs parsed before the failure.
		c.writeSSEJSON(map[string]interface{}{
			"type":               "online_error",
			"ai_conversation_id": result.AIConversationID,
			"ai_message_id":      result.AIMessageID,
			"message":            err.Error(),
		})
		c.writeSSEJSON(map[string]interface{}{
			"type":               "completed",
			"ai_conversation_id": result.AIConversationID,
			"ai_message_id":      result.AIMessageID,
		})
		return
	}

	c.writeSSEJSON(map[string]interface{}{
		"type":               "online_answer",
		"ai_conversation_id": result.AIConversationID,
		"ai_message_id":      result.AIMessageID,
		"content":            result.Content,
	})
	c.writeSSEJSON(map[string]interface{}{
		"type":               "completed",
		"ai_conversation_id": result.AIConversationID,
		"ai_message_id":      result.AIMessageID,
	})
}

// UpdateAIMessage overwrites the content of a non-deleted AI message.
//
// Requests whose token does not verify locally are first proxied to aichat's
// /report/update-ai-message. If the proxy fails, the update is applied to the
// local database.
func (c *ReportCompatController) UpdateAIMessage() {
	if c.shouldProxyToAIChat() {
		err := c.proxyAIChatJSON("/report/update-ai-message", c.Ctx.Input.RequestBody)
		if err == nil {
			return
		}
		fmt.Printf("[report-compat] proxy update-ai-message failed, fallback to local update: %v\n", err)
	}

	var requestData updateAIMessageRequest
	if err := json.Unmarshal(c.Ctx.Input.RequestBody, &requestData); err != nil {
		c.Data["json"] = map[string]interface{}{
			"success": false,
			"message": fmt.Sprintf("请求参数解析失败: %s", err.Error()),
		}
		c.ServeJSON()
		return
	}

	if requestData.AIMessageID == 0 {
		c.Data["json"] = map[string]interface{}{
			"success": false,
			"message": "ai_message_id 不能为空",
		}
		c.ServeJSON()
		return
	}

	if err := models.DB.Model(&models.AIMessage{}).
		Where("id = ? AND is_deleted = ?", requestData.AIMessageID, 0).
		Update("content", requestData.Content).Error; err != nil {
		c.Data["json"] = map[string]interface{}{
			"success": false,
			"message": fmt.Sprintf("更新 AI 消息失败: %s", err.Error()),
		}
		c.ServeJSON()
		return
	}

	c.Data["json"] = map[string]interface{}{
		"success": true,
		"message": "AI 消息已更新",
	}
	c.ServeJSON()
}

// StopSSE acknowledges a request to stop an in-flight answer stream.
//
// When applicable it is proxied to aichat's /sse/stop. The local path only
// echoes the conversation ID back and does not interrupt anything.
func (c *ReportCompatController) StopSSE() {
	if c.shouldProxyToAIChat() {
		err := c.proxyAIChatJSON("/sse/stop", c.Ctx.Input.RequestBody)
		if err == nil {
			return
		}
		fmt.Printf("[report-compat] proxy sse/stop failed, fallback to local success response: %v\n", err)
	}

	var requestData stopSSERequest
	// Deliberately best effort: a malformed body still gets a success reply,
	// with ai_conversation_id left at 0.
	_ = json.Unmarshal(c.Ctx.Input.RequestBody, &requestData)

	c.Data["json"] = map[string]interface{}{
		"success":            true,
		"message":            "已接收停止请求",
		"ai_conversation_id": requestData.AIConversationID,
	}
	c.ServeJSON()
}

// setSSEHeaders prepares the response for a server-sent event stream and sets
// permissive CORS headers.
func (c *ReportCompatController) setSSEHeaders() {
	header := c.Ctx.ResponseWriter.Header()
	header.Set("Content-Type", "text/event-stream; charset=utf-8")
	header.Set("Cache-Control", "no-cache")
	header.Set("Connection", "keep-alive")
	header.Set("Access-Control-Allow-Origin", "*")
	header.Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
	header.Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Token, token")
}

// writeSSEJSON writes payload as a single SSE "data:" event and flushes it.
// Write errors (for example a disconnected client) are ignored.
func (c *ReportCompatController) writeSSEJSON(payload map[string]interface{}) {
	// Payloads here only hold strings and integers, so Marshal cannot fail.
	responseJSON, _ := json.Marshal(payload)
	fmt.Fprintf(c.Ctx.ResponseWriter, "data: %s\n\n", responseJSON)
	c.Ctx.ResponseWriter.Flush()
}

// shouldProxyToAIChat reports whether the request should be served by aichat.
//
// It returns true when a token is present but fails local verification
// (presumably a token issued by the aichat side — confirm). It returns false
// when there is no token or the token verifies locally.
func (c *ReportCompatController) shouldProxyToAIChat() bool {
	token := c.getRequestToken()
	if token == "" {
		return false
	}
	if _, err := utils.VerifyLocalToken(token); err == nil {
		return false
	}
	return true
}

// getRequestToken returns the caller's token.
//
// The "Token" header takes precedence. Otherwise it falls back to the
// Authorization header with an optional "Bearer " scheme prefix removed. It
// returns "" when neither header is set.
func (c *ReportCompatController) getRequestToken() string {
	const bearerPrefix = "Bearer "

	// Header.Get canonicalises names, so this also matches a "token" header.
	if token := strings.TrimSpace(c.Ctx.Request.Header.Get("Token")); token != "" {
		return token
	}

	authorization := strings.TrimSpace(c.Ctx.Request.Header.Get("Authorization"))
	// The auth scheme is case-insensitive (RFC 7235), so "bearer xyz" is accepted too.
	if len(authorization) > len(bearerPrefix) &&
		strings.EqualFold(authorization[:len(bearerPrefix)], bearerPrefix) {
		return strings.TrimSpace(authorization[len(bearerPrefix):])
	}
	return authorization
}

// getAIChatBaseURL returns the configured "aichat_api_url" with trailing
// slashes removed. If the setting is missing or blank, it returns the local
// default.
func (c *ReportCompatController) getAIChatBaseURL() string {
	baseURL, err := beego.AppConfig.String("aichat_api_url")
	if err != nil || strings.TrimSpace(baseURL) == "" {
		baseURL = "http://127.0.0.1:28002/api/v1"
	}
	return strings.TrimRight(baseURL, "/")
}

// proxyAIChatSSE POSTs requestBody to aichat's path and relays the SSE
// response byte-for-byte to the frontend, flushing after every chunk.
//
// Some errors leave the frontend response untouched, so the caller may still
// fall back:
//   - request creation failures,
//   - transport failures,
//   - a non-200 status,
//   - a read failure before any byte was relayed.
//
// Failures after data was relayed, and any failure writing to the frontend,
// are returned as *reportCompatStreamInterruptedError.
func (c *ReportCompatController) proxyAIChatSSE(path string, requestBody []byte) error {
	upstreamReq, err := http.NewRequest(
		http.MethodPost,
		c.getAIChatBaseURL()+path,
		bytes.NewReader(requestBody),
	)
	if err != nil {
		return fmt.Errorf("创建 aichat SSE 请求失败: %w", err)
	}
	upstreamReq.Header.Set("Content-Type", "application/json")
	c.forwardAuthHeaders(upstreamReq)

	// The timeout covers the whole exchange, including reading the stream.
	client := &http.Client{Timeout: 10 * time.Minute}
	resp, err := client.Do(upstreamReq)
	if err != nil {
		return fmt.Errorf("调用 aichat SSE 失败: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// The body is only used for the diagnostic message; cap how much is buffered.
		responseBody, _ := io.ReadAll(io.LimitReader(resp.Body, 64*1024))
		return fmt.Errorf("aichat SSE 返回异常状态: %d %s", resp.StatusCode, strings.TrimSpace(string(responseBody)))
	}

	relayed := false
	buffer := make([]byte, 4096)
	for {
		n, readErr := resp.Body.Read(buffer)
		if n > 0 {
			relayed = true
			if _, err := c.Ctx.ResponseWriter.Write(buffer[:n]); err != nil {
				return &reportCompatStreamInterruptedError{
					err: fmt.Errorf("写入前端 SSE 响应失败: %w", err),
				}
			}
			c.Ctx.ResponseWriter.Flush()
		}
		if readErr == io.EOF {
			return nil
		}
		if readErr != nil {
			wrapped := fmt.Errorf("读取 aichat SSE 响应失败: %w", readErr)
			if relayed {
				return &reportCompatStreamInterruptedError{err: wrapped}
			}
			return wrapped
		}
	}
}

// proxyAIChatJSON POSTs requestBody to aichat's path and copies the upstream
// status code, Content-Type and body to the frontend response.
//
// An error is returned only when nothing has been written to the frontend yet
// (request creation, transport failure, or reading the upstream body), so the
// caller may safely fall back.
func (c *ReportCompatController) proxyAIChatJSON(path string, requestBody []byte) error {
	upstreamReq, err := http.NewRequest(
		http.MethodPost,
		c.getAIChatBaseURL()+path,
		bytes.NewReader(requestBody),
	)
	if err != nil {
		return fmt.Errorf("创建 aichat JSON 请求失败: %w", err)
	}
	upstreamReq.Header.Set("Content-Type", "application/json")
	c.forwardAuthHeaders(upstreamReq)

	client := &http.Client{Timeout: 30 * time.Second}
	resp, err := client.Do(upstreamReq)
	if err != nil {
		return fmt.Errorf("调用 aichat JSON 接口失败: %w", err)
	}
	defer resp.Body.Close()

	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("读取 aichat JSON 响应失败: %w", err)
	}

	if contentType := resp.Header.Get("Content-Type"); contentType != "" {
		c.Ctx.Output.Header("Content-Type", contentType)
	}
	// In beego v2, Output.SetStatus only records the code; Output.Body is what
	// sends it together with the body. Writing to ResponseWriter directly would
	// commit an implicit 200 instead of the upstream status.
	c.Ctx.Output.SetStatus(resp.StatusCode)
	if err := c.Ctx.Output.Body(responseBody); err != nil {
		// The response is committed at this point. Returning an error would make
		// the caller write a second, fallback body, so only log the failure.
		fmt.Printf("[report-compat] write proxied aichat response failed: %v\n", err)
	}
	return nil
}

// callStreamChatWithDB answers requestData locally.
//
// It POSTs to this server's own /apiv1/stream/chat-with-db endpoint, using the
// "httpport" setting with a default of 22001. The stream is aggregated with
// parseStreamChatResponse. On error, the returned result may still carry IDs
// and partial content parsed before the failure.
func (c *ReportCompatController) callStreamChatWithDB(requestData reportCompleteFlowRequest) (streamChatAggregateResult, error) {
	upstreamBody := map[string]interface{}{
		"message":            requestData.UserQuestion,
		"ai_conversation_id": requestData.AIConversationID,
		"business_type":      0,
	}

	requestBody, err := json.Marshal(upstreamBody)
	if err != nil {
		return streamChatAggregateResult{}, fmt.Errorf("构建内部请求失败: %w", err)
	}

	httpPort, err := beego.AppConfig.Int("httpport")
	if err != nil || httpPort == 0 {
		httpPort = 22001
	}

	upstreamURL := fmt.Sprintf("http://127.0.0.1:%d/apiv1/stream/chat-with-db", httpPort)
	upstreamReq, err := http.NewRequest(http.MethodPost, upstreamURL, bytes.NewReader(requestBody))
	if err != nil {
		return streamChatAggregateResult{}, fmt.Errorf("创建内部请求失败: %w", err)
	}
	upstreamReq.Header.Set("Content-Type", "application/json")
	c.forwardAuthHeaders(upstreamReq)

	client := &http.Client{Timeout: 10 * time.Minute}
	resp, err := client.Do(upstreamReq)
	if err != nil {
		return streamChatAggregateResult{}, fmt.Errorf("调用聊天接口失败: %w", err)
	}
	defer resp.Body.Close()

	// Parse even on a bad status so that any IDs in the body reach the caller.
	result, parseErr := parseStreamChatResponse(resp.Body)
	if resp.StatusCode != http.StatusOK {
		return result, fmt.Errorf("聊天接口返回异常状态: %d", resp.StatusCode)
	}
	if parseErr != nil {
		return result, parseErr
	}
	if strings.TrimSpace(result.Content) == "" {
		return result, fmt.Errorf("聊天接口未返回有效内容")
	}
	return result, nil
}

// forwardAuthHeaders copies the caller's non-empty Authorization and Token
// headers onto req.
func (c *ReportCompatController) forwardAuthHeaders(req *http.Request) {
	// Header.Get/Set canonicalise names, so "Token" also covers a "token" header.
	for _, headerName := range []string{"Authorization", "Token"} {
		if headerValue := strings.TrimSpace(c.Ctx.Request.Header.Get(headerName)); headerValue != "" {
			req.Header.Set(headerName, headerValue)
		}
	}
}

// parseStreamChatResponse drains a chat-with-db response and aggregates it.
//
// Blank lines are skipped. Every other line is handled as follows:
//   - "data:[DONE]" ends parsing.
//   - A "data:" line holding a JSON object:
//     1. records the first non-zero ai_conversation_id and ai_message_id seen;
//     2. aborts with an error if it has a non-empty string "error";
//     3. otherwise appends its string "content", if any.
//   - Any other "data:" line aborts on a raw {"error":...} payload and is
//     otherwise appended as text.
//   - Lines without a "data:" prefix are appended as text.
//
// Literal `\n` escape sequences in appended text become real newlines. On
// error, the returned result still carries the IDs and the content gathered so
// far.
func parseStreamChatResponse(reader io.Reader) (streamChatAggregateResult, error) {
	scanner := bufio.NewScanner(reader)
	// A single line may carry a large chunk; allow lines up to 10 MiB.
	scanner.Buffer(make([]byte, 0, 64*1024), 10*1024*1024)

	var result streamChatAggregateResult
	var contentBuilder strings.Builder
	// fail stores the partial content in result before returning err.
	fail := func(err error) (streamChatAggregateResult, error) {
		result.Content = contentBuilder.String()
		return result, err
	}

	for scanner.Scan() {
		line := strings.TrimRight(scanner.Text(), "\r")
		if strings.TrimSpace(line) == "" {
			continue
		}
		if !strings.HasPrefix(line, "data:") {
			contentBuilder.WriteString(strings.ReplaceAll(line, "\\n", "\n"))
			continue
		}
		// In the SSE format, a single space after "data:" is optional and is
		// not part of the value.
		data := strings.TrimPrefix(strings.TrimPrefix(line, "data:"), " ")
		if data == "[DONE]" {
			break
		}

		var payload map[string]interface{}
		decoder := json.NewDecoder(strings.NewReader(data))
		// UseNumber keeps large integer IDs exact instead of rounding them through float64.
		decoder.UseNumber()
		// json.Valid also rejects trailing data, keeping json.Unmarshal's acceptance rules.
		if json.Valid([]byte(data)) && decoder.Decode(&payload) == nil {
			if result.AIConversationID == 0 {
				result.AIConversationID = getUint64FromMap(payload, "ai_conversation_id")
			}
			if result.AIMessageID == 0 {
				result.AIMessageID = getUint64FromMap(payload, "ai_message_id")
			}
			if errorMessage, ok := payload["error"].(string); ok && strings.TrimSpace(errorMessage) != "" {
				return fail(errors.New(errorMessage))
			}
			if content, ok := payload["content"].(string); ok {
				contentBuilder.WriteString(strings.ReplaceAll(content, "\\n", "\n"))
			}
			continue
		}

		if errorMessage, ok := extractRawErrorMessage(data); ok {
			return fail(errors.New(errorMessage))
		}
		contentBuilder.WriteString(strings.ReplaceAll(data, "\\n", "\n"))
	}

	if err := scanner.Err(); err != nil {
		return fail(fmt.Errorf("读取聊天流失败: %w", err))
	}

	result.Content = contentBuilder.String()
	return result, nil
}

// extractRawErrorMessage recognises a payload that starts with {"error": but
// is not valid JSON. It strips that prefix, a trailing "}", surrounding
// whitespace and surrounding double quotes, and returns the message. ok is
// false when the prefix is absent or the message is empty.
func extractRawErrorMessage(data string) (string, bool) {
	if !strings.HasPrefix(data, "{\"error\":") {
		return "", false
	}

	errorMessage := strings.TrimPrefix(data, "{\"error\":")
	errorMessage = strings.TrimSuffix(errorMessage, "}")
	errorMessage = strings.TrimSpace(errorMessage)
	errorMessage = strings.Trim(errorMessage, "\"")
	if errorMessage == "" {
		return "", false
	}
	return errorMessage, true
}

// getUint64FromMap reads data[key] as an unsigned ID.
//
// Accepted representations:
//   - json.Number or float64, as produced by encoding/json;
//   - native integer types;
//   - decimal strings.
//
// It returns 0 for missing keys, negative or out-of-range values, and any
// other type.
func getUint64FromMap(data map[string]interface{}, key string) uint64 {
	switch value := data[key].(type) {
	case json.Number:
		if parsed, err := strconv.ParseUint(value.String(), 10, 64); err == nil {
			return parsed
		}
		// Forms like "12.0" or "1e3" are not plain integers; go through float64.
		floatValue, err := value.Float64()
		if err != nil {
			return 0
		}
		return nonNegativeFloatToUint64(floatValue)
	case float64:
		return nonNegativeFloatToUint64(value)
	case int:
		if value < 0 {
			return 0
		}
		return uint64(value)
	case int64:
		if value < 0 {
			return 0
		}
		return uint64(value)
	case uint64:
		return value
	case string:
		parsed, err := strconv.ParseUint(strings.TrimSpace(value), 10, 64)
		if err != nil {
			return 0
		}
		return parsed
	default:
		return 0
	}
}

// nonNegativeFloatToUint64 truncates value toward zero.
//
// It returns 0 for NaN, negative values, and values of 2^64 or more. Go leaves
// the result of converting such out-of-range values implementation-defined.
func nonNegativeFloatToUint64(value float64) uint64 {
	const twoToThe64 = 1 << 64
	// Written as a negated range check so that NaN (which fails every comparison) maps to 0.
	if !(value >= 0 && value < twoToThe64) {
		return 0
	}
	return uint64(value)
}