chroma.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. package controllers
  2. import (
  3. "encoding/json"
  4. "fmt"
  5. "io"
  6. "net/http"
  7. "net/url"
  8. "shudao-chat-go/utils"
  9. "strconv"
  10. "time"
  11. "github.com/beego/beego/v2/server/web"
  12. )
  13. // ChromaController 知识库搜索控制器
  14. type ChromaController struct {
  15. web.Controller
  16. }
  17. // SearchRequest 搜索请求结构体
  18. type SearchRequest struct {
  19. QueryStr string `json:"query_str"`
  20. NResults int `json:"n_results"`
  21. }
  22. // SearchResponse 搜索响应结构体
  23. type SearchResponse struct {
  24. Total int `json:"total"`
  25. Files []FileInfo `json:"files"`
  26. }
  27. // FileInfo 文件信息结构体
  28. type FileInfo struct {
  29. DocumentID string `json:"document_id"`
  30. FileName string `json:"file_name"`
  31. FilePath string `json:"file_path"`
  32. SourceFile string `json:"source_file"`
  33. HasTitle bool `json:"has_title"`
  34. }
  35. // AdvancedSearch 知识库文件高级搜索接口
  36. func (c *ChromaController) AdvancedSearch() {
  37. // 获取查询参数
  38. queryStr := c.GetString("query_str")
  39. nResultsStr := c.GetString("n_results", "50")
  40. // 参数验证
  41. if queryStr == "" {
  42. c.Data["json"] = map[string]interface{}{
  43. "code": 400,
  44. "message": "查询字符串不能为空",
  45. "data": nil,
  46. }
  47. c.ServeJSON()
  48. return
  49. }
  50. // 转换n_results参数
  51. nResults, err := strconv.Atoi(nResultsStr)
  52. if err != nil || nResults <= 0 {
  53. nResults = 50 // 默认值
  54. }
  55. // 构建请求参数
  56. searchReq := SearchRequest{
  57. QueryStr: queryStr,
  58. NResults: nResults,
  59. }
  60. // 调用外部API
  61. result, err := c.callAdvancedSearchAPI(searchReq)
  62. if err != nil {
  63. c.Data["json"] = map[string]interface{}{
  64. "code": 500,
  65. "message": fmt.Sprintf("搜索失败: %v", err),
  66. "data": nil,
  67. }
  68. c.ServeJSON()
  69. return
  70. }
  71. fmt.Println("知识库:", result)
  72. // 返回成功结果
  73. c.Data["json"] = map[string]interface{}{
  74. "code": 200,
  75. "message": "搜索成功",
  76. "data": result,
  77. }
  78. c.ServeJSON()
  79. }
  80. // callAdvancedSearchAPI 调用外部高级搜索API
  81. func (c *ChromaController) callAdvancedSearchAPI(req SearchRequest) (*SearchResponse, error) {
  82. // 从配置文件获取API URL
  83. apiURL := utils.GetKnowledgeSearchURL()
  84. // 构建查询参数
  85. params := url.Values{}
  86. params.Add("query_str", req.QueryStr)
  87. params.Add("n_results", strconv.Itoa(req.NResults))
  88. // 创建HTTP请求
  89. fullURL := apiURL + "?" + params.Encode()
  90. httpReq, err := http.NewRequest("GET", fullURL, nil)
  91. if err != nil {
  92. return nil, fmt.Errorf("创建请求失败: %v", err)
  93. }
  94. // 设置请求头
  95. httpReq.Header.Set("Content-Type", "application/json")
  96. httpReq.Header.Set("User-Agent", "shudao-chat-go/1.0")
  97. // 创建HTTP客户端
  98. client := &http.Client{
  99. Timeout: 30 * time.Second,
  100. }
  101. // 发送请求
  102. resp, err := client.Do(httpReq)
  103. if err != nil {
  104. return nil, fmt.Errorf("请求失败: %v", err)
  105. }
  106. defer resp.Body.Close()
  107. // 检查响应状态码
  108. if resp.StatusCode != http.StatusOK {
  109. return nil, fmt.Errorf("API返回错误状态码: %d", resp.StatusCode)
  110. }
  111. // 读取响应体
  112. body, err := io.ReadAll(resp.Body)
  113. if err != nil {
  114. return nil, fmt.Errorf("读取响应失败: %v", err)
  115. }
  116. // 解析响应JSON
  117. var searchResp SearchResponse
  118. err = json.Unmarshal(body, &searchResp)
  119. if err != nil {
  120. return nil, fmt.Errorf("解析响应JSON失败: %v", err)
  121. }
  122. return &searchResp, nil
  123. }