sample_center_client.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138
  1. import time
  2. import threading
  3. import logging
  4. import requests
  5. from datetime import datetime, timedelta
  6. logger = logging.getLogger(__name__)
  class SampleCenterError(Exception):
      """Error raised when an interaction with the sample center fails."""
  class SampleCenterClient:
      """Sample-center API client: caches the access token and wraps the HTTP endpoints.

      Every business method returns the decoded JSON body of a successful
      response (HTTP 200 and ``code == "000000"``) and raises
      ``SampleCenterError`` otherwise. Transport errors from the token endpoint
      are wrapped in ``SampleCenterError``; transport errors from the business
      endpoints propagate as ``requests`` exceptions.
      """

      # Refresh the token up to this many seconds (5 minutes) before it expires.
      TOKEN_REFRESH_THRESHOLD = 300
      # Token lifetime in seconds assumed when the token response omits or garbles expires_in.
      DEFAULT_TOKEN_TTL = 7200

      def __init__(self, base_url, app_id, app_secret):
          """Store connection settings; no network call is made here.

          Args:
              base_url: Service root URL; a trailing "/" is stripped.
              app_id: Application id, sent to the token endpoint and as ``X-App-Id``.
              app_secret: Application secret, sent only to the token endpoint.
          """
          self.base_url = base_url.rstrip("/")
          self.app_id = app_id
          self.app_secret = app_secret
          self._access_token = None
          # time.monotonic() value after which the cached token is treated as stale.
          # Monotonic time is immune to wall-clock jumps (NTP sync, manual changes).
          self._token_refresh_at = None
          self._lock = threading.Lock()

      @staticmethod
      def _refresh_deadline(expires_in, threshold, now):
          """Return the monotonic time at which a token obtained at *now* should be refreshed.

          The early-refresh margin is ``threshold`` capped at half the token
          lifetime, so a short-lived token (``expires_in <= threshold``) is still
          reused for part of its life instead of being refreshed on every call.
          """
          margin = min(threshold, expires_in / 2)
          return now + expires_in - margin

      def _cached_token(self):
          """Return the cached access token if it is still fresh, otherwise None."""
          token = self._access_token
          refresh_at = self._token_refresh_at
          if token and refresh_at is not None and time.monotonic() < refresh_at:
              return token
          return None

      def _ensure_token(self):
          """Return a valid access token, refreshing it first if the cached one is stale.

          The fast path reads the cache without locking. A refresh happens under
          ``self._lock`` after a second freshness check, so concurrent callers that
          all find the token stale trigger only one token request.
          """
          token = self._cached_token()
          if token:
              return token
          with self._lock:
              # Another thread may have refreshed while we waited for the lock.
              token = self._cached_token()
              if token:
                  return token
              self._refresh_token()
              return self._access_token

      def _refresh_token(self):
          """Exchange app_id/app_secret for an access token and cache it.

          Raises:
              SampleCenterError: on a transport error, non-200 status, non-JSON
                  body, business code other than "000000", or a response that
                  lacks ``data.access_token``.
          """
          url = f"{self.base_url}/api/v1/auth/token"
          logger.info(f"请求 Token: url={url}, app_id={self.app_id}")
          try:
              resp = requests.post(url, json={
                  "app_id": self.app_id,
                  "app_secret": self.app_secret,
              }, timeout=15)
          except requests.exceptions.RequestException as e:
              logger.error(f"Token 请求异常: {e}")
              raise SampleCenterError(f"Token request exception: {e}") from e
          # Do not log the body: a successful response carries the access token.
          logger.info(f"Token 响应: status={resp.status_code}")
          if resp.status_code != 200:
              raise SampleCenterError(f"Token request failed: {resp.status_code} {resp.text}")
          try:
              result = resp.json()
          except ValueError as e:  # json/requests JSONDecodeError both subclass ValueError
              raise SampleCenterError(f"Token response is not valid JSON: {e}") from e
          if not isinstance(result, dict):
              raise SampleCenterError(f"Token error: {result}")
          if result.get("code") != "000000":
              raise SampleCenterError(f"Token error: {result.get('message', result)}")
          data = result.get("data")
          token = data.get("access_token") if isinstance(data, dict) else None
          if not token:
              raise SampleCenterError("Token response missing data.access_token")
          raw_ttl = data.get("expires_in")
          try:
              expires_in = self.DEFAULT_TOKEN_TTL if raw_ttl is None else max(int(raw_ttl), 0)
          except (TypeError, ValueError):
              logger.warning(
                  f"Invalid token expires_in={raw_ttl!r}; using default {self.DEFAULT_TOKEN_TTL}s"
              )
              expires_in = self.DEFAULT_TOKEN_TTL
          self._access_token = token
          self._token_refresh_at = self._refresh_deadline(
              expires_in, self.TOKEN_REFRESH_THRESHOLD, time.monotonic()
          )
          logger.info(f"Token 获取成功: expires_in={expires_in}")

      def _headers(self):
          """Build request headers with a valid bearer token (may trigger a token refresh)."""
          return {
              "Authorization": f"Bearer {self._ensure_token()}",
              "X-App-Id": self.app_id,
          }

      def _parse(self, resp):
          """Validate a sample-center response and return its decoded JSON body.

          Raises:
              SampleCenterError: on a non-200 status, a body that is not a JSON
                  object, or a business code other than "000000".
          """
          if resp.status_code != 200:
              logger.error(f"HTTP 错误: status={resp.status_code}, body={resp.text[:500]}")
              raise SampleCenterError(f"HTTP {resp.status_code}: {resp.text}")
          try:
              body = resp.json()
          except ValueError as e:  # json/requests JSONDecodeError both subclass ValueError
              logger.error(f"Invalid JSON response: body={resp.text[:500]}")
              raise SampleCenterError(f"Invalid JSON response: {e}") from e
          if not isinstance(body, dict):
              raise SampleCenterError(f"Unexpected response body type: {type(body).__name__}")
          if body.get("code") != "000000":
              logger.error(f"业务错误: code={body.get('code')}, message={body.get('message')}")
              raise SampleCenterError(
                  f"SampleCenter error [{body.get('code')}]: {body.get('message', 'unknown')}"
              )
          return body

      # -- Business endpoints --

      def list_knowledge_bases(self, page=1, page_size=20):
          """GET /api/v1/knowledge-bases -- list knowledge bases (paged).

          Returns the decoded response body; raises SampleCenterError on failure.
          """
          url = f"{self.base_url}/api/v1/knowledge-bases"
          logger.info(f"请求知识库列表: url={url}, page={page}, page_size={page_size}")
          resp = requests.get(
              url,
              headers=self._headers(),
              params={"page": page, "page_size": page_size},
              timeout=30,
          )
          logger.info(f"知识库列表响应: status={resp.status_code}, body={resp.text[:500]}")
          return self._parse(resp)

      def get_knowledge_base(self, kb_id):
          """GET /api/v1/knowledge-bases/{id} -- knowledge-base details.

          Returns the decoded response body; raises SampleCenterError on failure.
          """
          url = f"{self.base_url}/api/v1/knowledge-bases/{kb_id}"
          logger.info(f"请求知识库详情: url={url}")
          resp = requests.get(url, headers=self._headers(), timeout=15)
          logger.info(f"知识库详情响应: status={resp.status_code}, body={resp.text[:500]}")
          return self._parse(resp)

      def batch_import(self, kb_id, task_no, parents=None, children=None, callback_url=None):
          """POST /api/v1/knowledge-bases/{kb_id}/batch-import -- batch import records.

          Empty or None ``parents``/``children``/``callback_url`` are omitted from
          the payload. Returns the decoded response body; raises
          SampleCenterError on failure.
          """
          url = f"{self.base_url}/api/v1/knowledge-bases/{kb_id}/batch-import"
          payload = {"task_no": task_no}
          if parents:
              payload["parents"] = parents
          if children:
              payload["children"] = children
          if callback_url:
              payload["callback_url"] = callback_url
          logger.info(
              f"请求批量入库: url={url}, task_no={task_no}, "
              f"parents_count={len(parents) if parents else 0}, children_count={len(children) if children else 0}, "
              f"callback_url={callback_url}"
          )
          # Full records can be very large; keep them out of INFO logs and only
          # pay for repr() when DEBUG is actually enabled.
          if logger.isEnabledFor(logging.DEBUG):
              logger.debug(f"批量入库数据: task_no={task_no}, parents={parents!r}, children={children!r}")
          resp = requests.post(
              url,
              headers=self._headers(),
              json=payload,
              timeout=60,
          )
          logger.info(f"批量入库响应: status={resp.status_code}, body={resp.text[:1000]}")
          return self._parse(resp)

      def get_import_task(self, task_id):
          """GET /api/v1/knowledge-bases/batch-import/{task_id} -- query an import task.

          Returns the decoded response body; raises SampleCenterError on failure.
          """
          url = f"{self.base_url}/api/v1/knowledge-bases/batch-import/{task_id}"
          logger.info(f"请求入库任务: url={url}")
          resp = requests.get(url, headers=self._headers(), timeout=15)
          logger.info(f"入库任务响应: url={url}, status={resp.status_code}, body={resp.text[:1000]}")
          return self._parse(resp)