chroma.go 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. package controllers
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "net/url"
  8. "shudao-chat-go/utils"
  9. "strconv"
  10. "time"
  11. "github.com/beego/beego/v2/server/web"
  12. )
  // chromaSearchHeartbeatTask holds the state of the background heartbeat
// loop that periodically issues GET requests to a configured URL
// (driven by run and sendHeartbeat).
type chromaSearchHeartbeatTask struct {
	// url is the endpoint requested on every heartbeat.
	url string
	// interval is the period of the ticker that triggers heartbeats.
	interval time.Duration
	// httpClient is the client used for the heartbeat requests.
	httpClient *http.Client
	// stopChan makes run return once it is closed. Nothing visible in this
	// file closes it, so the loop effectively lives as long as the process.
	stopChan chan struct{}
}
  20. // ChromaController 知识库搜索控制器
  21. type ChromaController struct {
  22. web.Controller
  23. }
  // Payload types exchanged with the external knowledge-base search API.
// Field order is significant: it fixes the key order of the JSON output.
type (
	// SearchRequest holds the parameters forwarded to the search API.
	SearchRequest struct {
		QueryStr string `json:"query_str"` // search text
		NResults int    `json:"n_results"` // requested result count
	}

	// SearchResponse is the decoded JSON body returned by the search API.
	SearchResponse struct {
		Total int        `json:"total"`
		Files []FileInfo `json:"files"`
	}

	// FileInfo is one file entry inside a SearchResponse.
	FileInfo struct {
		DocumentID string `json:"document_id"`
		FileName   string `json:"file_name"`
		FilePath   string `json:"file_path"`
		SourceFile string `json:"source_file"`
		HasTitle   bool   `json:"has_title"`
	}
)
  42. // AdvancedSearch 知识库文件高级搜索接口
  43. func (c *ChromaController) AdvancedSearch() {
  44. // 获取查询参数
  45. queryStr := c.GetString("query_str")
  46. nResultsStr := c.GetString("n_results", "50")
  47. // 参数验证
  48. if queryStr == "" {
  49. c.Data["json"] = map[string]interface{}{
  50. "code": 400,
  51. "message": "查询字符串不能为空",
  52. "data": nil,
  53. }
  54. c.ServeJSON()
  55. return
  56. }
  57. // 转换n_results参数
  58. nResults, err := strconv.Atoi(nResultsStr)
  59. if err != nil || nResults <= 0 {
  60. nResults = 50 // 默认值
  61. }
  62. // 构建请求参数
  63. searchReq := SearchRequest{
  64. QueryStr: queryStr,
  65. NResults: nResults,
  66. }
  67. // 调用外部API
  68. result, err := c.callAdvancedSearchAPI(searchReq)
  69. if err != nil {
  70. c.Data["json"] = map[string]interface{}{
  71. "code": 500,
  72. "message": fmt.Sprintf("搜索失败: %v", err),
  73. "data": nil,
  74. }
  75. c.ServeJSON()
  76. return
  77. }
  78. fmt.Println("知识库:", result)
  79. // 返回成功结果
  80. c.Data["json"] = map[string]interface{}{
  81. "code": 200,
  82. "message": "搜索成功",
  83. "data": result,
  84. }
  85. c.ServeJSON()
  86. }
  87. // callAdvancedSearchAPI 调用外部高级搜索API
  88. func (c *ChromaController) callAdvancedSearchAPI(req SearchRequest) (*SearchResponse, error) {
  89. // 从配置文件获取API URL
  90. apiURL := utils.GetKnowledgeSearchURL()
  91. // 构建查询参数
  92. params := url.Values{}
  93. params.Add("query_str", req.QueryStr)
  94. params.Add("n_results", strconv.Itoa(req.NResults))
  95. // 创建HTTP请求
  96. fullURL := apiURL + "?" + params.Encode()
  97. httpReq, err := http.NewRequest("GET", fullURL, nil)
  98. if err != nil {
  99. return nil, fmt.Errorf("创建请求失败: %v", err)
  100. }
  101. // 设置请求头
  102. httpReq.Header.Set("Content-Type", "application/json")
  103. httpReq.Header.Set("User-Agent", "shudao-chat-go/1.0")
  104. // 创建HTTP客户端
  105. client := &http.Client{
  106. Timeout: 30 * time.Second,
  107. }
  108. // 发送请求
  109. resp, err := client.Do(httpReq)
  110. if err != nil {
  111. return nil, fmt.Errorf("请求失败: %v", err)
  112. }
  113. defer resp.Body.Close()
  114. // 检查响应状态码
  115. if resp.StatusCode != http.StatusOK {
  116. return nil, fmt.Errorf("API返回错误状态码: %d", resp.StatusCode)
  117. }
  118. // 读取响应体
  119. body, err := io.ReadAll(resp.Body)
  120. if err != nil {
  121. return nil, fmt.Errorf("读取响应失败: %v", err)
  122. }
  123. // 解析响应JSON
  124. var searchResp SearchResponse
  125. err = json.Unmarshal(body, &searchResp)
  126. if err != nil {
  127. return nil, fmt.Errorf("解析响应JSON失败: %v", err)
  128. }
  129. return &searchResp, nil
  130. }
  131. // StartChromaSearchHeartbeatTask 启动ChromaDB搜索服务心跳任务
  132. func StartChromaSearchHeartbeatTask() {
  133. heartbeatURL := utils.GetConfigString("heartbeat_api_url", "")
  134. if heartbeatURL == "" {
  135. return
  136. }
  137. task := &chromaSearchHeartbeatTask{
  138. url: heartbeatURL,
  139. interval: 10 * time.Minute,
  140. httpClient: &http.Client{Timeout: 30 * time.Second},
  141. stopChan: make(chan struct{}),
  142. }
  143. go task.run()
  144. }
  145. func (h *chromaSearchHeartbeatTask) run() {
  146. ticker := time.NewTicker(h.interval)
  147. defer ticker.Stop()
  148. h.sendHeartbeat()
  149. for {
  150. select {
  151. case <-h.stopChan:
  152. return
  153. case <-ticker.C:
  154. h.sendHeartbeat()
  155. }
  156. }
  157. }
  158. func (h *chromaSearchHeartbeatTask) sendHeartbeat() {
  159. resp, err := h.httpClient.Get(h.url)
  160. if err != nil {
  161. return
  162. }
  163. defer resp.Body.Close()
  164. io.ReadAll(resp.Body)
  165. }