Prechádzať zdrojové kódy

优化AI问答的回复

zkn 1 týždeň pred
rodič
commit
93a158ed16

+ 2 - 0
shudao-go-backend/conf/app.conf

@@ -28,6 +28,8 @@ heartbeat_api_url = http://127.0.0.1:24000/api/health
 
 # Token验证API (认证网关) - 本地环境
 auth_api_url = http://127.0.0.1:28004/api/auth/verify
+# AI问答服务代理地址
+aichat_api_url = http://127.0.0.1:28002/api/v1
 
 # ==================== OSS存储配置 ====================
 oss_access_key_id = fnyfi2f368pbic74d8ll

+ 2 - 0
shudao-go-backend/conf/app.conf.prod

@@ -28,6 +28,8 @@ heartbeat_api_url = http://127.0.0.1:24000/api/health
 
 # Token验证API (认证网关) - 生产环境
 auth_api_url = http://127.0.0.1:28004/api/auth/verify
+# AI问答服务代理地址
+aichat_api_url = http://127.0.0.1:28002/api/v1
 
 # ==================== OSS存储配置 ====================
 oss_access_key_id = fnyfi2f368pbic74d8ll

+ 2 - 0
shudao-go-backend/conf/app.conf.test

@@ -28,6 +28,8 @@ heartbeat_api_url = http://127.0.0.1:24000/api/health
 
 # Token验证API (认证网关) - 测试环境
 auth_api_url = http://127.0.0.1:28004/api/auth/verify
+# AI问答服务代理地址
+aichat_api_url = http://127.0.0.1:28002/api/v1
 
 # ==================== OSS存储配置 ====================
 oss_access_key_id = fnyfi2f368pbic74d8ll

+ 4 - 1
shudao-go-backend/controllers/chat.go

@@ -1,4 +1,4 @@
-// Package controllers - chat.go
+// Package controllers - chat.go
 //
 // ⚠️ DEPRECATED NOTICE (弃用说明)
 // ================================================================================
@@ -2683,6 +2683,9 @@ func (c *ChatController) GuessYouWant() {
 	// 构建带有专业问题判断规则的提示词
 	promptWithRules := fmt.Sprintf(`你是蜀道安全管理AI智能助手,请根据用户的问题生成3个相关的后续问题建议(猜你想问)。
 
+## 用户问题
+%s
+
 ## 生成问题规则(最高优先级)
 1. 严禁生成任何政治敏感信息,包含重要国家领导人,重要国际事件等
 2. 严禁在生成的问题中包含人名信息,任何人名都不行

+ 448 - 0
shudao-go-backend/controllers/report_compat.go

@@ -0,0 +1,448 @@
+package controllers
+
+import (
+	"bufio"
+	"bytes"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+	"time"
+
+	"shudao-chat-go/models"
+	"shudao-chat-go/utils"
+
+	beego "github.com/beego/beego/v2/server/web"
+)
+
+// ReportCompatController keeps the old AI chat frontend contract working
+// while routing the request to the service that actually implements it.
+type ReportCompatController struct {
+	beego.Controller
+}
+
+type reportCompleteFlowRequest struct {
+	UserQuestion           string `json:"user_question"`
+	WindowSize             int    `json:"window_size"`
+	NResults               int    `json:"n_results"`
+	AIConversationID       uint64 `json:"ai_conversation_id"`
+	IsNetworkSearchEnabled bool   `json:"is_network_search_enabled"`
+	EnableOnlineModel      bool   `json:"enable_online_model"`
+}
+
+type updateAIMessageRequest struct {
+	AIMessageID uint64 `json:"ai_message_id"`
+	Content     string `json:"content"`
+}
+
+type stopSSERequest struct {
+	AIConversationID uint64 `json:"ai_conversation_id"`
+}
+
+type streamChatAggregateResult struct {
+	AIConversationID uint64
+	AIMessageID      uint64
+	Content          string
+}
+
+func (c *ReportCompatController) CompleteFlow() {
+	c.setSSEHeaders()
+
+	var requestData reportCompleteFlowRequest
+	if err := json.Unmarshal(c.Ctx.Input.RequestBody, &requestData); err != nil {
+		c.writeSSEJSON(map[string]interface{}{
+			"type":    "online_error",
+			"message": fmt.Sprintf("请求参数解析失败: %s", err.Error()),
+		})
+		c.writeSSEJSON(map[string]interface{}{"type": "completed"})
+		return
+	}
+
+	userQuestion := strings.TrimSpace(requestData.UserQuestion)
+	if userQuestion == "" {
+		c.writeSSEJSON(map[string]interface{}{
+			"type":    "online_error",
+			"message": "问题不能为空",
+		})
+		c.writeSSEJSON(map[string]interface{}{"type": "completed"})
+		return
+	}
+
+	if c.shouldProxyToAIChat() {
+		if err := c.proxyAIChatSSE("/report/complete-flow", c.Ctx.Input.RequestBody); err == nil {
+			return
+		} else {
+			fmt.Printf("[report-compat] proxy to aichat failed, fallback to local stream: %v\n", err)
+		}
+	}
+
+	result, err := c.callStreamChatWithDB(requestData)
+	if err != nil {
+		c.writeSSEJSON(map[string]interface{}{
+			"type":               "online_error",
+			"ai_conversation_id": result.AIConversationID,
+			"ai_message_id":      result.AIMessageID,
+			"message":            err.Error(),
+		})
+		c.writeSSEJSON(map[string]interface{}{
+			"type":               "completed",
+			"ai_conversation_id": result.AIConversationID,
+			"ai_message_id":      result.AIMessageID,
+		})
+		return
+	}
+
+	c.writeSSEJSON(map[string]interface{}{
+		"type":               "online_answer",
+		"ai_conversation_id": result.AIConversationID,
+		"ai_message_id":      result.AIMessageID,
+		"content":            result.Content,
+	})
+	c.writeSSEJSON(map[string]interface{}{
+		"type":               "completed",
+		"ai_conversation_id": result.AIConversationID,
+		"ai_message_id":      result.AIMessageID,
+	})
+}
+
+func (c *ReportCompatController) UpdateAIMessage() {
+	if c.shouldProxyToAIChat() {
+		if err := c.proxyAIChatJSON("/report/update-ai-message", c.Ctx.Input.RequestBody); err == nil {
+			return
+		} else {
+			fmt.Printf("[report-compat] proxy update-ai-message failed, fallback to local update: %v\n", err)
+		}
+	}
+
+	var requestData updateAIMessageRequest
+	if err := json.Unmarshal(c.Ctx.Input.RequestBody, &requestData); err != nil {
+		c.Data["json"] = map[string]interface{}{
+			"success": false,
+			"message": fmt.Sprintf("请求参数解析失败: %s", err.Error()),
+		}
+		c.ServeJSON()
+		return
+	}
+
+	if requestData.AIMessageID == 0 {
+		c.Data["json"] = map[string]interface{}{
+			"success": false,
+			"message": "ai_message_id 不能为空",
+		}
+		c.ServeJSON()
+		return
+	}
+
+	if err := models.DB.Model(&models.AIMessage{}).
+		Where("id = ? AND is_deleted = ?", requestData.AIMessageID, 0).
+		Update("content", requestData.Content).Error; err != nil {
+		c.Data["json"] = map[string]interface{}{
+			"success": false,
+			"message": fmt.Sprintf("更新 AI 消息失败: %s", err.Error()),
+		}
+		c.ServeJSON()
+		return
+	}
+
+	c.Data["json"] = map[string]interface{}{
+		"success": true,
+		"message": "AI 消息已更新",
+	}
+	c.ServeJSON()
+}
+
+func (c *ReportCompatController) StopSSE() {
+	if c.shouldProxyToAIChat() {
+		if err := c.proxyAIChatJSON("/sse/stop", c.Ctx.Input.RequestBody); err == nil {
+			return
+		} else {
+			fmt.Printf("[report-compat] proxy sse/stop failed, fallback to local success response: %v\n", err)
+		}
+	}
+
+	var requestData stopSSERequest
+	_ = json.Unmarshal(c.Ctx.Input.RequestBody, &requestData)
+
+	c.Data["json"] = map[string]interface{}{
+		"success":            true,
+		"message":            "已接收停止请求",
+		"ai_conversation_id": requestData.AIConversationID,
+	}
+	c.ServeJSON()
+}
+
+func (c *ReportCompatController) setSSEHeaders() {
+	c.Ctx.ResponseWriter.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
+	c.Ctx.ResponseWriter.Header().Set("Cache-Control", "no-cache")
+	c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive")
+	c.Ctx.ResponseWriter.Header().Set("Access-Control-Allow-Origin", "*")
+	c.Ctx.ResponseWriter.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
+	c.Ctx.ResponseWriter.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Token, token")
+}
+
+func (c *ReportCompatController) writeSSEJSON(payload map[string]interface{}) {
+	responseJSON, _ := json.Marshal(payload)
+	fmt.Fprintf(c.Ctx.ResponseWriter, "data: %s\n\n", responseJSON)
+	c.Ctx.ResponseWriter.Flush()
+}
+
+func (c *ReportCompatController) shouldProxyToAIChat() bool {
+	token := c.getRequestToken()
+	if token == "" {
+		return false
+	}
+
+	if _, err := utils.VerifyLocalToken(token); err == nil {
+		return false
+	}
+
+	return true
+}
+
+func (c *ReportCompatController) getRequestToken() string {
+	for _, headerName := range []string{"token", "Token", "Authorization"} {
+		headerValue := strings.TrimSpace(c.Ctx.Request.Header.Get(headerName))
+		if headerValue == "" {
+			continue
+		}
+		if headerName == "Authorization" && strings.HasPrefix(headerValue, "Bearer ") {
+			return strings.TrimPrefix(headerValue, "Bearer ")
+		}
+		return headerValue
+	}
+
+	return ""
+}
+
+func (c *ReportCompatController) getAIChatBaseURL() string {
+	baseURL, err := beego.AppConfig.String("aichat_api_url")
+	if err != nil || strings.TrimSpace(baseURL) == "" {
+		baseURL = "http://127.0.0.1:28002/api/v1"
+	}
+
+	return strings.TrimRight(baseURL, "/")
+}
+
+func (c *ReportCompatController) proxyAIChatSSE(path string, requestBody []byte) error {
+	upstreamReq, err := http.NewRequest(
+		http.MethodPost,
+		c.getAIChatBaseURL()+path,
+		bytes.NewBuffer(requestBody),
+	)
+	if err != nil {
+		return fmt.Errorf("创建 aichat SSE 请求失败: %w", err)
+	}
+
+	upstreamReq.Header.Set("Content-Type", "application/json")
+	c.forwardAuthHeaders(upstreamReq)
+
+	client := &http.Client{Timeout: 10 * time.Minute}
+	resp, err := client.Do(upstreamReq)
+	if err != nil {
+		return fmt.Errorf("调用 aichat SSE 失败: %w", err)
+	}
+	defer resp.Body.Close()
+
+	if resp.StatusCode != http.StatusOK {
+		responseBody, _ := io.ReadAll(resp.Body)
+		return fmt.Errorf("aichat SSE 返回异常状态: %d %s", resp.StatusCode, strings.TrimSpace(string(responseBody)))
+	}
+
+	buffer := make([]byte, 4096)
+	for {
+		n, readErr := resp.Body.Read(buffer)
+		if n > 0 {
+			if _, err := c.Ctx.ResponseWriter.Write(buffer[:n]); err != nil {
+				return fmt.Errorf("写入前端 SSE 响应失败: %w", err)
+			}
+			c.Ctx.ResponseWriter.Flush()
+		}
+
+		if readErr == io.EOF {
+			return nil
+		}
+		if readErr != nil {
+			return fmt.Errorf("读取 aichat SSE 响应失败: %w", readErr)
+		}
+	}
+}
+
+func (c *ReportCompatController) proxyAIChatJSON(path string, requestBody []byte) error {
+	upstreamReq, err := http.NewRequest(
+		http.MethodPost,
+		c.getAIChatBaseURL()+path,
+		bytes.NewBuffer(requestBody),
+	)
+	if err != nil {
+		return fmt.Errorf("创建 aichat JSON 请求失败: %w", err)
+	}
+
+	upstreamReq.Header.Set("Content-Type", "application/json")
+	c.forwardAuthHeaders(upstreamReq)
+
+	client := &http.Client{Timeout: 30 * time.Second}
+	resp, err := client.Do(upstreamReq)
+	if err != nil {
+		return fmt.Errorf("调用 aichat JSON 接口失败: %w", err)
+	}
+	defer resp.Body.Close()
+
+	responseBody, err := io.ReadAll(resp.Body)
+	if err != nil {
+		return fmt.Errorf("读取 aichat JSON 响应失败: %w", err)
+	}
+
+	c.Ctx.Output.SetStatus(resp.StatusCode)
+	c.Ctx.Output.Header("Content-Type", resp.Header.Get("Content-Type"))
+	_, _ = c.Ctx.ResponseWriter.Write(responseBody)
+	return nil
+}
+
+func (c *ReportCompatController) callStreamChatWithDB(requestData reportCompleteFlowRequest) (streamChatAggregateResult, error) {
+	upstreamBody := map[string]interface{}{
+		"message":            requestData.UserQuestion,
+		"ai_conversation_id": requestData.AIConversationID,
+		"business_type":      0,
+	}
+
+	requestBody, err := json.Marshal(upstreamBody)
+	if err != nil {
+		return streamChatAggregateResult{}, fmt.Errorf("构建内部请求失败: %w", err)
+	}
+
+	httpPort, err := beego.AppConfig.Int("httpport")
+	if err != nil || httpPort == 0 {
+		httpPort = 22001
+	}
+	upstreamURL := fmt.Sprintf("http://127.0.0.1:%d/apiv1/stream/chat-with-db", httpPort)
+
+	upstreamReq, err := http.NewRequest(http.MethodPost, upstreamURL, bytes.NewBuffer(requestBody))
+	if err != nil {
+		return streamChatAggregateResult{}, fmt.Errorf("创建内部请求失败: %w", err)
+	}
+	upstreamReq.Header.Set("Content-Type", "application/json")
+	c.forwardAuthHeaders(upstreamReq)
+
+	client := &http.Client{Timeout: 10 * time.Minute}
+	resp, err := client.Do(upstreamReq)
+	if err != nil {
+		return streamChatAggregateResult{}, fmt.Errorf("调用聊天接口失败: %w", err)
+	}
+	defer resp.Body.Close()
+
+	result, parseErr := parseStreamChatResponse(resp.Body)
+	if resp.StatusCode != http.StatusOK {
+		return result, fmt.Errorf("聊天接口返回异常状态: %d", resp.StatusCode)
+	}
+	if parseErr != nil {
+		return result, parseErr
+	}
+	if strings.TrimSpace(result.Content) == "" {
+		return result, fmt.Errorf("聊天接口未返回有效内容")
+	}
+
+	return result, nil
+}
+
+func (c *ReportCompatController) forwardAuthHeaders(req *http.Request) {
+	for _, headerName := range []string{"Authorization", "Token", "token"} {
+		if headerValue := strings.TrimSpace(c.Ctx.Request.Header.Get(headerName)); headerValue != "" {
+			req.Header.Set(headerName, headerValue)
+		}
+	}
+}
+
+func parseStreamChatResponse(reader io.Reader) (streamChatAggregateResult, error) {
+	scanner := bufio.NewScanner(reader)
+	scanner.Buffer(make([]byte, 0, 64*1024), 10*1024*1024)
+
+	var result streamChatAggregateResult
+	var contentBuilder strings.Builder
+
+	for scanner.Scan() {
+		line := strings.TrimRight(scanner.Text(), "\r")
+		if strings.TrimSpace(line) == "" {
+			continue
+		}
+
+		if strings.HasPrefix(line, "data: ") {
+			data := strings.TrimPrefix(line, "data: ")
+			if data == "[DONE]" {
+				break
+			}
+
+			var payload map[string]interface{}
+			if err := json.Unmarshal([]byte(data), &payload); err == nil {
+				if result.AIConversationID == 0 {
+					result.AIConversationID = getUint64FromMap(payload, "ai_conversation_id")
+				}
+				if result.AIMessageID == 0 {
+					result.AIMessageID = getUint64FromMap(payload, "ai_message_id")
+				}
+				if errorMessage, ok := payload["error"].(string); ok && strings.TrimSpace(errorMessage) != "" {
+					result.Content = contentBuilder.String()
+					return result, fmt.Errorf("%s", errorMessage)
+				}
+				if content, ok := payload["content"].(string); ok {
+					contentBuilder.WriteString(strings.ReplaceAll(content, "\\n", "\n"))
+				}
+				continue
+			}
+
+			if errorMessage, ok := extractRawErrorMessage(data); ok {
+				result.Content = contentBuilder.String()
+				return result, fmt.Errorf("%s", errorMessage)
+			}
+
+			contentBuilder.WriteString(strings.ReplaceAll(data, "\\n", "\n"))
+			continue
+		}
+
+		contentBuilder.WriteString(strings.ReplaceAll(line, "\\n", "\n"))
+	}
+
+	if err := scanner.Err(); err != nil {
+		result.Content = contentBuilder.String()
+		return result, fmt.Errorf("读取聊天流失败: %w", err)
+	}
+
+	result.Content = contentBuilder.String()
+	return result, nil
+}
+
+func extractRawErrorMessage(data string) (string, bool) {
+	if !strings.HasPrefix(data, "{\"error\":") {
+		return "", false
+	}
+
+	errorMessage := strings.TrimPrefix(data, "{\"error\":")
+	errorMessage = strings.TrimSuffix(errorMessage, "}")
+	errorMessage = strings.TrimSpace(errorMessage)
+	errorMessage = strings.Trim(errorMessage, "\"")
+	if errorMessage == "" {
+		return "", false
+	}
+
+	return errorMessage, true
+}
+
+func getUint64FromMap(data map[string]interface{}, key string) uint64 {
+	rawValue, ok := data[key]
+	if !ok {
+		return 0
+	}
+
+	switch value := rawValue.(type) {
+	case float64:
+		return uint64(value)
+	case int:
+		return uint64(value)
+	case int64:
+		return uint64(value)
+	case uint64:
+		return value
+	default:
+		return 0
+	}
+}

+ 4 - 0
shudao-go-backend/routers/router.go

@@ -93,6 +93,10 @@ func init() {
 		beego.NSRouter("/get_chromadb_document", &controllers.ChatController{}, "get:GetChromaDBDocument"),
 		// 知识库文件高级搜索
 		beego.NSRouter("/knowledge/files/advanced-search", &controllers.ChromaController{}, "get:AdvancedSearch"),
+		// 兼容旧版 AI 问答页的 report / sse 协议
+		beego.NSRouter("/report/complete-flow", &controllers.ReportCompatController{}, "post:CompleteFlow"),
+		beego.NSRouter("/report/update-ai-message", &controllers.ReportCompatController{}, "post:UpdateAIMessage"),
+		beego.NSRouter("/sse/stop", &controllers.ReportCompatController{}, "post:StopSSE"),
 
 		// 流式接口路由
 		beego.NSRouter("/stream/chat", &controllers.LiushiController{}, "post:StreamChat"),

+ 3 - 3
shudao-vue-frontend/src/utils/api.js

@@ -3,7 +3,7 @@
  * 统一管理所有API请求
  */
 import request from './authRequest'
-import { buildApiUrl, REPORT_API_PREFIX } from './apiConfig'
+import { buildApiUrl, BACKEND_API_PREFIX } from './apiConfig'
 import { getToken } from './auth'
 
 /**
@@ -89,7 +89,7 @@ export async function stopSSEStream(userId, aiConversationId) {
       headers['Authorization'] = `Bearer ${token}`
     }
     
-    const response = await fetch(buildApiUrl('/sse/stop', REPORT_API_PREFIX), {
+    const response = await fetch(buildApiUrl('/sse/stop', BACKEND_API_PREFIX), {
       method: 'POST',
       headers,
       body: JSON.stringify({
@@ -154,7 +154,7 @@ export async function updateAIMessageContent(aiMessageId, content) {
       headers['Authorization'] = `Bearer ${token}`
     }
     
-    const response = await fetch(buildApiUrl('/report/update-ai-message', REPORT_API_PREFIX), {
+    const response = await fetch(buildApiUrl('/report/update-ai-message', BACKEND_API_PREFIX), {
       method: 'POST',
       headers,
       body: JSON.stringify({

+ 2 - 2
shudao-vue-frontend/src/views/mobile/m-Chat.vue

@@ -467,7 +467,7 @@ import { apis } from '@/request/apis.js'
 // import { getUserId } from '@/utils/userManager.js'
 import { useSpeechRecognition } from '@/composables/useSpeechRecognition'
 import { createSSEConnection, closeSSEConnection } from '@/utils/sse'
-import { getApiPrefix, getReportApiPrefix, BACKEND_API_PREFIX } from '@/utils/apiConfig'
+import { getApiPrefix, BACKEND_API_PREFIX } from '@/utils/apiConfig'
 import { renderMarkdown } from '@/utils/markdown'
 import { stopSSEStream, updateAIMessageContent } from '@/utils/api.js'
 import { getToken, getTokenType, getUserName, getAccountId } from '@/utils/auth.js'
@@ -3218,7 +3218,7 @@ const handleReportGeneratorSubmit = async (data) => {
   })
 
   try {
-    const apiPrefix = getReportApiPrefix()
+    const apiPrefix = getApiPrefix()
     const url = `${apiPrefix}/report/complete-flow`
 
     // 构建 POST 请求体