|
|
@@ -0,0 +1,448 @@
|
|
|
+package controllers
|
|
|
+
|
|
|
+import (
|
|
|
+ "bufio"
|
|
|
+ "bytes"
|
|
|
+ "encoding/json"
|
|
|
+ "fmt"
|
|
|
+ "io"
|
|
|
+ "net/http"
|
|
|
+ "strings"
|
|
|
+ "time"
|
|
|
+
|
|
|
+ "shudao-chat-go/models"
|
|
|
+ "shudao-chat-go/utils"
|
|
|
+
|
|
|
+ beego "github.com/beego/beego/v2/server/web"
|
|
|
+)
|
|
|
+
|
|
|
+// ReportCompatController keeps the old AI chat frontend contract working
|
|
|
+// while routing the request to the service that actually implements it.
|
|
|
+type ReportCompatController struct {
|
|
|
+ beego.Controller
|
|
|
+}
|
|
|
+
|
|
|
+type reportCompleteFlowRequest struct {
|
|
|
+ UserQuestion string `json:"user_question"`
|
|
|
+ WindowSize int `json:"window_size"`
|
|
|
+ NResults int `json:"n_results"`
|
|
|
+ AIConversationID uint64 `json:"ai_conversation_id"`
|
|
|
+ IsNetworkSearchEnabled bool `json:"is_network_search_enabled"`
|
|
|
+ EnableOnlineModel bool `json:"enable_online_model"`
|
|
|
+}
|
|
|
+
|
|
|
+type updateAIMessageRequest struct {
|
|
|
+ AIMessageID uint64 `json:"ai_message_id"`
|
|
|
+ Content string `json:"content"`
|
|
|
+}
|
|
|
+
|
|
|
+type stopSSERequest struct {
|
|
|
+ AIConversationID uint64 `json:"ai_conversation_id"`
|
|
|
+}
|
|
|
+
|
|
|
+type streamChatAggregateResult struct {
|
|
|
+ AIConversationID uint64
|
|
|
+ AIMessageID uint64
|
|
|
+ Content string
|
|
|
+}
|
|
|
+
|
|
|
+func (c *ReportCompatController) CompleteFlow() {
|
|
|
+ c.setSSEHeaders()
|
|
|
+
|
|
|
+ var requestData reportCompleteFlowRequest
|
|
|
+ if err := json.Unmarshal(c.Ctx.Input.RequestBody, &requestData); err != nil {
|
|
|
+ c.writeSSEJSON(map[string]interface{}{
|
|
|
+ "type": "online_error",
|
|
|
+ "message": fmt.Sprintf("请求参数解析失败: %s", err.Error()),
|
|
|
+ })
|
|
|
+ c.writeSSEJSON(map[string]interface{}{"type": "completed"})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ userQuestion := strings.TrimSpace(requestData.UserQuestion)
|
|
|
+ if userQuestion == "" {
|
|
|
+ c.writeSSEJSON(map[string]interface{}{
|
|
|
+ "type": "online_error",
|
|
|
+ "message": "问题不能为空",
|
|
|
+ })
|
|
|
+ c.writeSSEJSON(map[string]interface{}{"type": "completed"})
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if c.shouldProxyToAIChat() {
|
|
|
+ if err := c.proxyAIChatSSE("/report/complete-flow", c.Ctx.Input.RequestBody); err == nil {
|
|
|
+ return
|
|
|
+ } else {
|
|
|
+ fmt.Printf("[report-compat] proxy to aichat failed, fallback to local stream: %v\n", err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ result, err := c.callStreamChatWithDB(requestData)
|
|
|
+ if err != nil {
|
|
|
+ c.writeSSEJSON(map[string]interface{}{
|
|
|
+ "type": "online_error",
|
|
|
+ "ai_conversation_id": result.AIConversationID,
|
|
|
+ "ai_message_id": result.AIMessageID,
|
|
|
+ "message": err.Error(),
|
|
|
+ })
|
|
|
+ c.writeSSEJSON(map[string]interface{}{
|
|
|
+ "type": "completed",
|
|
|
+ "ai_conversation_id": result.AIConversationID,
|
|
|
+ "ai_message_id": result.AIMessageID,
|
|
|
+ })
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ c.writeSSEJSON(map[string]interface{}{
|
|
|
+ "type": "online_answer",
|
|
|
+ "ai_conversation_id": result.AIConversationID,
|
|
|
+ "ai_message_id": result.AIMessageID,
|
|
|
+ "content": result.Content,
|
|
|
+ })
|
|
|
+ c.writeSSEJSON(map[string]interface{}{
|
|
|
+ "type": "completed",
|
|
|
+ "ai_conversation_id": result.AIConversationID,
|
|
|
+ "ai_message_id": result.AIMessageID,
|
|
|
+ })
|
|
|
+}
|
|
|
+
|
|
|
+func (c *ReportCompatController) UpdateAIMessage() {
|
|
|
+ if c.shouldProxyToAIChat() {
|
|
|
+ if err := c.proxyAIChatJSON("/report/update-ai-message", c.Ctx.Input.RequestBody); err == nil {
|
|
|
+ return
|
|
|
+ } else {
|
|
|
+ fmt.Printf("[report-compat] proxy update-ai-message failed, fallback to local update: %v\n", err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ var requestData updateAIMessageRequest
|
|
|
+ if err := json.Unmarshal(c.Ctx.Input.RequestBody, &requestData); err != nil {
|
|
|
+ c.Data["json"] = map[string]interface{}{
|
|
|
+ "success": false,
|
|
|
+ "message": fmt.Sprintf("请求参数解析失败: %s", err.Error()),
|
|
|
+ }
|
|
|
+ c.ServeJSON()
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if requestData.AIMessageID == 0 {
|
|
|
+ c.Data["json"] = map[string]interface{}{
|
|
|
+ "success": false,
|
|
|
+ "message": "ai_message_id 不能为空",
|
|
|
+ }
|
|
|
+ c.ServeJSON()
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ if err := models.DB.Model(&models.AIMessage{}).
|
|
|
+ Where("id = ? AND is_deleted = ?", requestData.AIMessageID, 0).
|
|
|
+ Update("content", requestData.Content).Error; err != nil {
|
|
|
+ c.Data["json"] = map[string]interface{}{
|
|
|
+ "success": false,
|
|
|
+ "message": fmt.Sprintf("更新 AI 消息失败: %s", err.Error()),
|
|
|
+ }
|
|
|
+ c.ServeJSON()
|
|
|
+ return
|
|
|
+ }
|
|
|
+
|
|
|
+ c.Data["json"] = map[string]interface{}{
|
|
|
+ "success": true,
|
|
|
+ "message": "AI 消息已更新",
|
|
|
+ }
|
|
|
+ c.ServeJSON()
|
|
|
+}
|
|
|
+
|
|
|
+func (c *ReportCompatController) StopSSE() {
|
|
|
+ if c.shouldProxyToAIChat() {
|
|
|
+ if err := c.proxyAIChatJSON("/sse/stop", c.Ctx.Input.RequestBody); err == nil {
|
|
|
+ return
|
|
|
+ } else {
|
|
|
+ fmt.Printf("[report-compat] proxy sse/stop failed, fallback to local success response: %v\n", err)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ var requestData stopSSERequest
|
|
|
+ _ = json.Unmarshal(c.Ctx.Input.RequestBody, &requestData)
|
|
|
+
|
|
|
+ c.Data["json"] = map[string]interface{}{
|
|
|
+ "success": true,
|
|
|
+ "message": "已接收停止请求",
|
|
|
+ "ai_conversation_id": requestData.AIConversationID,
|
|
|
+ }
|
|
|
+ c.ServeJSON()
|
|
|
+}
|
|
|
+
|
|
|
+func (c *ReportCompatController) setSSEHeaders() {
|
|
|
+ c.Ctx.ResponseWriter.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
|
|
|
+ c.Ctx.ResponseWriter.Header().Set("Cache-Control", "no-cache")
|
|
|
+ c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive")
|
|
|
+ c.Ctx.ResponseWriter.Header().Set("Access-Control-Allow-Origin", "*")
|
|
|
+ c.Ctx.ResponseWriter.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
|
|
|
+ c.Ctx.ResponseWriter.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Token, token")
|
|
|
+}
|
|
|
+
|
|
|
+func (c *ReportCompatController) writeSSEJSON(payload map[string]interface{}) {
|
|
|
+ responseJSON, _ := json.Marshal(payload)
|
|
|
+ fmt.Fprintf(c.Ctx.ResponseWriter, "data: %s\n\n", responseJSON)
|
|
|
+ c.Ctx.ResponseWriter.Flush()
|
|
|
+}
|
|
|
+
|
|
|
+func (c *ReportCompatController) shouldProxyToAIChat() bool {
|
|
|
+ token := c.getRequestToken()
|
|
|
+ if token == "" {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ if _, err := utils.VerifyLocalToken(token); err == nil {
|
|
|
+ return false
|
|
|
+ }
|
|
|
+
|
|
|
+ return true
|
|
|
+}
|
|
|
+
|
|
|
+func (c *ReportCompatController) getRequestToken() string {
|
|
|
+ for _, headerName := range []string{"token", "Token", "Authorization"} {
|
|
|
+ headerValue := strings.TrimSpace(c.Ctx.Request.Header.Get(headerName))
|
|
|
+ if headerValue == "" {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+ if headerName == "Authorization" && strings.HasPrefix(headerValue, "Bearer ") {
|
|
|
+ return strings.TrimPrefix(headerValue, "Bearer ")
|
|
|
+ }
|
|
|
+ return headerValue
|
|
|
+ }
|
|
|
+
|
|
|
+ return ""
|
|
|
+}
|
|
|
+
|
|
|
+func (c *ReportCompatController) getAIChatBaseURL() string {
|
|
|
+ baseURL, err := beego.AppConfig.String("aichat_api_url")
|
|
|
+ if err != nil || strings.TrimSpace(baseURL) == "" {
|
|
|
+ baseURL = "http://127.0.0.1:28002/api/v1"
|
|
|
+ }
|
|
|
+
|
|
|
+ return strings.TrimRight(baseURL, "/")
|
|
|
+}
|
|
|
+
|
|
|
+func (c *ReportCompatController) proxyAIChatSSE(path string, requestBody []byte) error {
|
|
|
+ upstreamReq, err := http.NewRequest(
|
|
|
+ http.MethodPost,
|
|
|
+ c.getAIChatBaseURL()+path,
|
|
|
+ bytes.NewBuffer(requestBody),
|
|
|
+ )
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("创建 aichat SSE 请求失败: %w", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ upstreamReq.Header.Set("Content-Type", "application/json")
|
|
|
+ c.forwardAuthHeaders(upstreamReq)
|
|
|
+
|
|
|
+ client := &http.Client{Timeout: 10 * time.Minute}
|
|
|
+ resp, err := client.Do(upstreamReq)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("调用 aichat SSE 失败: %w", err)
|
|
|
+ }
|
|
|
+ defer resp.Body.Close()
|
|
|
+
|
|
|
+ if resp.StatusCode != http.StatusOK {
|
|
|
+ responseBody, _ := io.ReadAll(resp.Body)
|
|
|
+ return fmt.Errorf("aichat SSE 返回异常状态: %d %s", resp.StatusCode, strings.TrimSpace(string(responseBody)))
|
|
|
+ }
|
|
|
+
|
|
|
+ buffer := make([]byte, 4096)
|
|
|
+ for {
|
|
|
+ n, readErr := resp.Body.Read(buffer)
|
|
|
+ if n > 0 {
|
|
|
+ if _, err := c.Ctx.ResponseWriter.Write(buffer[:n]); err != nil {
|
|
|
+ return fmt.Errorf("写入前端 SSE 响应失败: %w", err)
|
|
|
+ }
|
|
|
+ c.Ctx.ResponseWriter.Flush()
|
|
|
+ }
|
|
|
+
|
|
|
+ if readErr == io.EOF {
|
|
|
+ return nil
|
|
|
+ }
|
|
|
+ if readErr != nil {
|
|
|
+ return fmt.Errorf("读取 aichat SSE 响应失败: %w", readErr)
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func (c *ReportCompatController) proxyAIChatJSON(path string, requestBody []byte) error {
|
|
|
+ upstreamReq, err := http.NewRequest(
|
|
|
+ http.MethodPost,
|
|
|
+ c.getAIChatBaseURL()+path,
|
|
|
+ bytes.NewBuffer(requestBody),
|
|
|
+ )
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("创建 aichat JSON 请求失败: %w", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ upstreamReq.Header.Set("Content-Type", "application/json")
|
|
|
+ c.forwardAuthHeaders(upstreamReq)
|
|
|
+
|
|
|
+ client := &http.Client{Timeout: 30 * time.Second}
|
|
|
+ resp, err := client.Do(upstreamReq)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("调用 aichat JSON 接口失败: %w", err)
|
|
|
+ }
|
|
|
+ defer resp.Body.Close()
|
|
|
+
|
|
|
+ responseBody, err := io.ReadAll(resp.Body)
|
|
|
+ if err != nil {
|
|
|
+ return fmt.Errorf("读取 aichat JSON 响应失败: %w", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ c.Ctx.Output.SetStatus(resp.StatusCode)
|
|
|
+ c.Ctx.Output.Header("Content-Type", resp.Header.Get("Content-Type"))
|
|
|
+ _, _ = c.Ctx.ResponseWriter.Write(responseBody)
|
|
|
+ return nil
|
|
|
+}
|
|
|
+
|
|
|
+func (c *ReportCompatController) callStreamChatWithDB(requestData reportCompleteFlowRequest) (streamChatAggregateResult, error) {
|
|
|
+ upstreamBody := map[string]interface{}{
|
|
|
+ "message": requestData.UserQuestion,
|
|
|
+ "ai_conversation_id": requestData.AIConversationID,
|
|
|
+ "business_type": 0,
|
|
|
+ }
|
|
|
+
|
|
|
+ requestBody, err := json.Marshal(upstreamBody)
|
|
|
+ if err != nil {
|
|
|
+ return streamChatAggregateResult{}, fmt.Errorf("构建内部请求失败: %w", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ httpPort, err := beego.AppConfig.Int("httpport")
|
|
|
+ if err != nil || httpPort == 0 {
|
|
|
+ httpPort = 22001
|
|
|
+ }
|
|
|
+ upstreamURL := fmt.Sprintf("http://127.0.0.1:%d/apiv1/stream/chat-with-db", httpPort)
|
|
|
+
|
|
|
+ upstreamReq, err := http.NewRequest(http.MethodPost, upstreamURL, bytes.NewBuffer(requestBody))
|
|
|
+ if err != nil {
|
|
|
+ return streamChatAggregateResult{}, fmt.Errorf("创建内部请求失败: %w", err)
|
|
|
+ }
|
|
|
+ upstreamReq.Header.Set("Content-Type", "application/json")
|
|
|
+ c.forwardAuthHeaders(upstreamReq)
|
|
|
+
|
|
|
+ client := &http.Client{Timeout: 10 * time.Minute}
|
|
|
+ resp, err := client.Do(upstreamReq)
|
|
|
+ if err != nil {
|
|
|
+ return streamChatAggregateResult{}, fmt.Errorf("调用聊天接口失败: %w", err)
|
|
|
+ }
|
|
|
+ defer resp.Body.Close()
|
|
|
+
|
|
|
+ result, parseErr := parseStreamChatResponse(resp.Body)
|
|
|
+ if resp.StatusCode != http.StatusOK {
|
|
|
+ return result, fmt.Errorf("聊天接口返回异常状态: %d", resp.StatusCode)
|
|
|
+ }
|
|
|
+ if parseErr != nil {
|
|
|
+ return result, parseErr
|
|
|
+ }
|
|
|
+ if strings.TrimSpace(result.Content) == "" {
|
|
|
+ return result, fmt.Errorf("聊天接口未返回有效内容")
|
|
|
+ }
|
|
|
+
|
|
|
+ return result, nil
|
|
|
+}
|
|
|
+
|
|
|
+func (c *ReportCompatController) forwardAuthHeaders(req *http.Request) {
|
|
|
+ for _, headerName := range []string{"Authorization", "Token", "token"} {
|
|
|
+ if headerValue := strings.TrimSpace(c.Ctx.Request.Header.Get(headerName)); headerValue != "" {
|
|
|
+ req.Header.Set(headerName, headerValue)
|
|
|
+ }
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+func parseStreamChatResponse(reader io.Reader) (streamChatAggregateResult, error) {
|
|
|
+ scanner := bufio.NewScanner(reader)
|
|
|
+ scanner.Buffer(make([]byte, 0, 64*1024), 10*1024*1024)
|
|
|
+
|
|
|
+ var result streamChatAggregateResult
|
|
|
+ var contentBuilder strings.Builder
|
|
|
+
|
|
|
+ for scanner.Scan() {
|
|
|
+ line := strings.TrimRight(scanner.Text(), "\r")
|
|
|
+ if strings.TrimSpace(line) == "" {
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ if strings.HasPrefix(line, "data: ") {
|
|
|
+ data := strings.TrimPrefix(line, "data: ")
|
|
|
+ if data == "[DONE]" {
|
|
|
+ break
|
|
|
+ }
|
|
|
+
|
|
|
+ var payload map[string]interface{}
|
|
|
+ if err := json.Unmarshal([]byte(data), &payload); err == nil {
|
|
|
+ if result.AIConversationID == 0 {
|
|
|
+ result.AIConversationID = getUint64FromMap(payload, "ai_conversation_id")
|
|
|
+ }
|
|
|
+ if result.AIMessageID == 0 {
|
|
|
+ result.AIMessageID = getUint64FromMap(payload, "ai_message_id")
|
|
|
+ }
|
|
|
+ if errorMessage, ok := payload["error"].(string); ok && strings.TrimSpace(errorMessage) != "" {
|
|
|
+ result.Content = contentBuilder.String()
|
|
|
+ return result, fmt.Errorf("%s", errorMessage)
|
|
|
+ }
|
|
|
+ if content, ok := payload["content"].(string); ok {
|
|
|
+ contentBuilder.WriteString(strings.ReplaceAll(content, "\\n", "\n"))
|
|
|
+ }
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ if errorMessage, ok := extractRawErrorMessage(data); ok {
|
|
|
+ result.Content = contentBuilder.String()
|
|
|
+ return result, fmt.Errorf("%s", errorMessage)
|
|
|
+ }
|
|
|
+
|
|
|
+ contentBuilder.WriteString(strings.ReplaceAll(data, "\\n", "\n"))
|
|
|
+ continue
|
|
|
+ }
|
|
|
+
|
|
|
+ contentBuilder.WriteString(strings.ReplaceAll(line, "\\n", "\n"))
|
|
|
+ }
|
|
|
+
|
|
|
+ if err := scanner.Err(); err != nil {
|
|
|
+ result.Content = contentBuilder.String()
|
|
|
+ return result, fmt.Errorf("读取聊天流失败: %w", err)
|
|
|
+ }
|
|
|
+
|
|
|
+ result.Content = contentBuilder.String()
|
|
|
+ return result, nil
|
|
|
+}
|
|
|
+
|
|
|
+func extractRawErrorMessage(data string) (string, bool) {
|
|
|
+ if !strings.HasPrefix(data, "{\"error\":") {
|
|
|
+ return "", false
|
|
|
+ }
|
|
|
+
|
|
|
+ errorMessage := strings.TrimPrefix(data, "{\"error\":")
|
|
|
+ errorMessage = strings.TrimSuffix(errorMessage, "}")
|
|
|
+ errorMessage = strings.TrimSpace(errorMessage)
|
|
|
+ errorMessage = strings.Trim(errorMessage, "\"")
|
|
|
+ if errorMessage == "" {
|
|
|
+ return "", false
|
|
|
+ }
|
|
|
+
|
|
|
+ return errorMessage, true
|
|
|
+}
|
|
|
+
|
|
|
+func getUint64FromMap(data map[string]interface{}, key string) uint64 {
|
|
|
+ rawValue, ok := data[key]
|
|
|
+ if !ok {
|
|
|
+ return 0
|
|
|
+ }
|
|
|
+
|
|
|
+ switch value := rawValue.(type) {
|
|
|
+ case float64:
|
|
|
+ return uint64(value)
|
|
|
+ case int:
|
|
|
+ return uint64(value)
|
|
|
+ case int64:
|
|
|
+ return uint64(value)
|
|
|
+ case uint64:
|
|
|
+ return value
|
|
|
+ default:
|
|
|
+ return 0
|
|
|
+ }
|
|
|
+}
|