| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448 |
- package controllers
- import (
- "bufio"
- "bytes"
- "encoding/json"
- "fmt"
- "io"
- "net/http"
- "strings"
- "time"
- "shudao-chat-go/models"
- "shudao-chat-go/utils"
- beego "github.com/beego/beego/v2/server/web"
- )
- // ReportCompatController keeps the old AI chat frontend contract working
- // while routing the request to the service that actually implements it.
- type ReportCompatController struct {
- beego.Controller
- }
type (
	// reportCompleteFlowRequest is the JSON body posted to CompleteFlow. When
	// the request is proxied to aichat the raw body is relayed unchanged; the
	// local fallback forwards only UserQuestion and AIConversationID.
	reportCompleteFlowRequest struct {
		UserQuestion           string `json:"user_question"`
		WindowSize             int    `json:"window_size"`
		NResults               int    `json:"n_results"`
		AIConversationID       uint64 `json:"ai_conversation_id"`
		IsNetworkSearchEnabled bool   `json:"is_network_search_enabled"`
		EnableOnlineModel      bool   `json:"enable_online_model"`
	}

	// updateAIMessageRequest is the JSON body posted to UpdateAIMessage.
	updateAIMessageRequest struct {
		AIMessageID uint64 `json:"ai_message_id"`
		Content     string `json:"content"`
	}

	// stopSSERequest is the JSON body posted to StopSSE.
	stopSSERequest struct {
		AIConversationID uint64 `json:"ai_conversation_id"`
	}

	// streamChatAggregateResult is the full answer collected from one
	// chat-with-db stream, together with the ids that stream reported.
	streamChatAggregateResult struct {
		AIConversationID uint64
		AIMessageID      uint64
		Content          string
	}
)
- func (c *ReportCompatController) CompleteFlow() {
- c.setSSEHeaders()
- var requestData reportCompleteFlowRequest
- if err := json.Unmarshal(c.Ctx.Input.RequestBody, &requestData); err != nil {
- c.writeSSEJSON(map[string]interface{}{
- "type": "online_error",
- "message": fmt.Sprintf("请求参数解析失败: %s", err.Error()),
- })
- c.writeSSEJSON(map[string]interface{}{"type": "completed"})
- return
- }
- userQuestion := strings.TrimSpace(requestData.UserQuestion)
- if userQuestion == "" {
- c.writeSSEJSON(map[string]interface{}{
- "type": "online_error",
- "message": "问题不能为空",
- })
- c.writeSSEJSON(map[string]interface{}{"type": "completed"})
- return
- }
- if c.shouldProxyToAIChat() {
- if err := c.proxyAIChatSSE("/report/complete-flow", c.Ctx.Input.RequestBody); err == nil {
- return
- } else {
- fmt.Printf("[report-compat] proxy to aichat failed, fallback to local stream: %v\n", err)
- }
- }
- result, err := c.callStreamChatWithDB(requestData)
- if err != nil {
- c.writeSSEJSON(map[string]interface{}{
- "type": "online_error",
- "ai_conversation_id": result.AIConversationID,
- "ai_message_id": result.AIMessageID,
- "message": err.Error(),
- })
- c.writeSSEJSON(map[string]interface{}{
- "type": "completed",
- "ai_conversation_id": result.AIConversationID,
- "ai_message_id": result.AIMessageID,
- })
- return
- }
- c.writeSSEJSON(map[string]interface{}{
- "type": "online_answer",
- "ai_conversation_id": result.AIConversationID,
- "ai_message_id": result.AIMessageID,
- "content": result.Content,
- })
- c.writeSSEJSON(map[string]interface{}{
- "type": "completed",
- "ai_conversation_id": result.AIConversationID,
- "ai_message_id": result.AIMessageID,
- })
- }
- func (c *ReportCompatController) UpdateAIMessage() {
- if c.shouldProxyToAIChat() {
- if err := c.proxyAIChatJSON("/report/update-ai-message", c.Ctx.Input.RequestBody); err == nil {
- return
- } else {
- fmt.Printf("[report-compat] proxy update-ai-message failed, fallback to local update: %v\n", err)
- }
- }
- var requestData updateAIMessageRequest
- if err := json.Unmarshal(c.Ctx.Input.RequestBody, &requestData); err != nil {
- c.Data["json"] = map[string]interface{}{
- "success": false,
- "message": fmt.Sprintf("请求参数解析失败: %s", err.Error()),
- }
- c.ServeJSON()
- return
- }
- if requestData.AIMessageID == 0 {
- c.Data["json"] = map[string]interface{}{
- "success": false,
- "message": "ai_message_id 不能为空",
- }
- c.ServeJSON()
- return
- }
- if err := models.DB.Model(&models.AIMessage{}).
- Where("id = ? AND is_deleted = ?", requestData.AIMessageID, 0).
- Update("content", requestData.Content).Error; err != nil {
- c.Data["json"] = map[string]interface{}{
- "success": false,
- "message": fmt.Sprintf("更新 AI 消息失败: %s", err.Error()),
- }
- c.ServeJSON()
- return
- }
- c.Data["json"] = map[string]interface{}{
- "success": true,
- "message": "AI 消息已更新",
- }
- c.ServeJSON()
- }
- func (c *ReportCompatController) StopSSE() {
- if c.shouldProxyToAIChat() {
- if err := c.proxyAIChatJSON("/sse/stop", c.Ctx.Input.RequestBody); err == nil {
- return
- } else {
- fmt.Printf("[report-compat] proxy sse/stop failed, fallback to local success response: %v\n", err)
- }
- }
- var requestData stopSSERequest
- _ = json.Unmarshal(c.Ctx.Input.RequestBody, &requestData)
- c.Data["json"] = map[string]interface{}{
- "success": true,
- "message": "已接收停止请求",
- "ai_conversation_id": requestData.AIConversationID,
- }
- c.ServeJSON()
- }
- func (c *ReportCompatController) setSSEHeaders() {
- c.Ctx.ResponseWriter.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
- c.Ctx.ResponseWriter.Header().Set("Cache-Control", "no-cache")
- c.Ctx.ResponseWriter.Header().Set("Connection", "keep-alive")
- c.Ctx.ResponseWriter.Header().Set("Access-Control-Allow-Origin", "*")
- c.Ctx.ResponseWriter.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS")
- c.Ctx.ResponseWriter.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization, Token, token")
- }
- func (c *ReportCompatController) writeSSEJSON(payload map[string]interface{}) {
- responseJSON, _ := json.Marshal(payload)
- fmt.Fprintf(c.Ctx.ResponseWriter, "data: %s\n\n", responseJSON)
- c.Ctx.ResponseWriter.Flush()
- }
- func (c *ReportCompatController) shouldProxyToAIChat() bool {
- token := c.getRequestToken()
- if token == "" {
- return false
- }
- if _, err := utils.VerifyLocalToken(token); err == nil {
- return false
- }
- return true
- }
- func (c *ReportCompatController) getRequestToken() string {
- for _, headerName := range []string{"token", "Token", "Authorization"} {
- headerValue := strings.TrimSpace(c.Ctx.Request.Header.Get(headerName))
- if headerValue == "" {
- continue
- }
- if headerName == "Authorization" && strings.HasPrefix(headerValue, "Bearer ") {
- return strings.TrimPrefix(headerValue, "Bearer ")
- }
- return headerValue
- }
- return ""
- }
- func (c *ReportCompatController) getAIChatBaseURL() string {
- baseURL, err := beego.AppConfig.String("aichat_api_url")
- if err != nil || strings.TrimSpace(baseURL) == "" {
- baseURL = "http://127.0.0.1:28002/api/v1"
- }
- return strings.TrimRight(baseURL, "/")
- }
- func (c *ReportCompatController) proxyAIChatSSE(path string, requestBody []byte) error {
- upstreamReq, err := http.NewRequest(
- http.MethodPost,
- c.getAIChatBaseURL()+path,
- bytes.NewBuffer(requestBody),
- )
- if err != nil {
- return fmt.Errorf("创建 aichat SSE 请求失败: %w", err)
- }
- upstreamReq.Header.Set("Content-Type", "application/json")
- c.forwardAuthHeaders(upstreamReq)
- client := &http.Client{Timeout: 10 * time.Minute}
- resp, err := client.Do(upstreamReq)
- if err != nil {
- return fmt.Errorf("调用 aichat SSE 失败: %w", err)
- }
- defer resp.Body.Close()
- if resp.StatusCode != http.StatusOK {
- responseBody, _ := io.ReadAll(resp.Body)
- return fmt.Errorf("aichat SSE 返回异常状态: %d %s", resp.StatusCode, strings.TrimSpace(string(responseBody)))
- }
- buffer := make([]byte, 4096)
- for {
- n, readErr := resp.Body.Read(buffer)
- if n > 0 {
- if _, err := c.Ctx.ResponseWriter.Write(buffer[:n]); err != nil {
- return fmt.Errorf("写入前端 SSE 响应失败: %w", err)
- }
- c.Ctx.ResponseWriter.Flush()
- }
- if readErr == io.EOF {
- return nil
- }
- if readErr != nil {
- return fmt.Errorf("读取 aichat SSE 响应失败: %w", readErr)
- }
- }
- }
- func (c *ReportCompatController) proxyAIChatJSON(path string, requestBody []byte) error {
- upstreamReq, err := http.NewRequest(
- http.MethodPost,
- c.getAIChatBaseURL()+path,
- bytes.NewBuffer(requestBody),
- )
- if err != nil {
- return fmt.Errorf("创建 aichat JSON 请求失败: %w", err)
- }
- upstreamReq.Header.Set("Content-Type", "application/json")
- c.forwardAuthHeaders(upstreamReq)
- client := &http.Client{Timeout: 30 * time.Second}
- resp, err := client.Do(upstreamReq)
- if err != nil {
- return fmt.Errorf("调用 aichat JSON 接口失败: %w", err)
- }
- defer resp.Body.Close()
- responseBody, err := io.ReadAll(resp.Body)
- if err != nil {
- return fmt.Errorf("读取 aichat JSON 响应失败: %w", err)
- }
- c.Ctx.Output.SetStatus(resp.StatusCode)
- c.Ctx.Output.Header("Content-Type", resp.Header.Get("Content-Type"))
- _, _ = c.Ctx.ResponseWriter.Write(responseBody)
- return nil
- }
- func (c *ReportCompatController) callStreamChatWithDB(requestData reportCompleteFlowRequest) (streamChatAggregateResult, error) {
- upstreamBody := map[string]interface{}{
- "message": requestData.UserQuestion,
- "ai_conversation_id": requestData.AIConversationID,
- "business_type": 0,
- }
- requestBody, err := json.Marshal(upstreamBody)
- if err != nil {
- return streamChatAggregateResult{}, fmt.Errorf("构建内部请求失败: %w", err)
- }
- httpPort, err := beego.AppConfig.Int("httpport")
- if err != nil || httpPort == 0 {
- httpPort = 22001
- }
- upstreamURL := fmt.Sprintf("http://127.0.0.1:%d/apiv1/stream/chat-with-db", httpPort)
- upstreamReq, err := http.NewRequest(http.MethodPost, upstreamURL, bytes.NewBuffer(requestBody))
- if err != nil {
- return streamChatAggregateResult{}, fmt.Errorf("创建内部请求失败: %w", err)
- }
- upstreamReq.Header.Set("Content-Type", "application/json")
- c.forwardAuthHeaders(upstreamReq)
- client := &http.Client{Timeout: 10 * time.Minute}
- resp, err := client.Do(upstreamReq)
- if err != nil {
- return streamChatAggregateResult{}, fmt.Errorf("调用聊天接口失败: %w", err)
- }
- defer resp.Body.Close()
- result, parseErr := parseStreamChatResponse(resp.Body)
- if resp.StatusCode != http.StatusOK {
- return result, fmt.Errorf("聊天接口返回异常状态: %d", resp.StatusCode)
- }
- if parseErr != nil {
- return result, parseErr
- }
- if strings.TrimSpace(result.Content) == "" {
- return result, fmt.Errorf("聊天接口未返回有效内容")
- }
- return result, nil
- }
- func (c *ReportCompatController) forwardAuthHeaders(req *http.Request) {
- for _, headerName := range []string{"Authorization", "Token", "token"} {
- if headerValue := strings.TrimSpace(c.Ctx.Request.Header.Get(headerName)); headerValue != "" {
- req.Header.Set(headerName, headerValue)
- }
- }
- }
- func parseStreamChatResponse(reader io.Reader) (streamChatAggregateResult, error) {
- scanner := bufio.NewScanner(reader)
- scanner.Buffer(make([]byte, 0, 64*1024), 10*1024*1024)
- var result streamChatAggregateResult
- var contentBuilder strings.Builder
- for scanner.Scan() {
- line := strings.TrimRight(scanner.Text(), "\r")
- if strings.TrimSpace(line) == "" {
- continue
- }
- if strings.HasPrefix(line, "data: ") {
- data := strings.TrimPrefix(line, "data: ")
- if data == "[DONE]" {
- break
- }
- var payload map[string]interface{}
- if err := json.Unmarshal([]byte(data), &payload); err == nil {
- if result.AIConversationID == 0 {
- result.AIConversationID = getUint64FromMap(payload, "ai_conversation_id")
- }
- if result.AIMessageID == 0 {
- result.AIMessageID = getUint64FromMap(payload, "ai_message_id")
- }
- if errorMessage, ok := payload["error"].(string); ok && strings.TrimSpace(errorMessage) != "" {
- result.Content = contentBuilder.String()
- return result, fmt.Errorf("%s", errorMessage)
- }
- if content, ok := payload["content"].(string); ok {
- contentBuilder.WriteString(strings.ReplaceAll(content, "\\n", "\n"))
- }
- continue
- }
- if errorMessage, ok := extractRawErrorMessage(data); ok {
- result.Content = contentBuilder.String()
- return result, fmt.Errorf("%s", errorMessage)
- }
- contentBuilder.WriteString(strings.ReplaceAll(data, "\\n", "\n"))
- continue
- }
- contentBuilder.WriteString(strings.ReplaceAll(line, "\\n", "\n"))
- }
- if err := scanner.Err(); err != nil {
- result.Content = contentBuilder.String()
- return result, fmt.Errorf("读取聊天流失败: %w", err)
- }
- result.Content = contentBuilder.String()
- return result, nil
- }
- func extractRawErrorMessage(data string) (string, bool) {
- if !strings.HasPrefix(data, "{\"error\":") {
- return "", false
- }
- errorMessage := strings.TrimPrefix(data, "{\"error\":")
- errorMessage = strings.TrimSuffix(errorMessage, "}")
- errorMessage = strings.TrimSpace(errorMessage)
- errorMessage = strings.Trim(errorMessage, "\"")
- if errorMessage == "" {
- return "", false
- }
- return errorMessage, true
- }
- func getUint64FromMap(data map[string]interface{}, key string) uint64 {
- rawValue, ok := data[key]
- if !ok {
- return 0
- }
- switch value := rawValue.(type) {
- case float64:
- return uint64(value)
- case int:
- return uint64(value)
- case int64:
- return uint64(value)
- case uint64:
- return value
- default:
- return 0
- }
- }
|