| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138 |
- import time
- import threading
- import logging
- import requests
- from datetime import datetime, timedelta
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
class SampleCenterError(Exception):
    """Raised when an interaction with the Sample Center fails."""
class SampleCenterClient:
    """Sample Center API client that wraps token management and endpoint calls.

    A bearer token is fetched on first use, cached on the instance, and
    renewed once it is within ``TOKEN_REFRESH_THRESHOLD`` seconds of expiry.
    Token renewal is serialized with a lock.

    Failures reported by the service (non-200 status, non-success business
    code, unparsable or malformed body) are raised as ``SampleCenterError``.
    Transport errors raised by ``requests`` while calling a business endpoint
    are not wrapped and propagate unchanged; transport errors while fetching
    the token are wrapped in ``SampleCenterError``.
    """

    # Renew the token once it is within this many seconds (5 minutes) of expiry.
    TOKEN_REFRESH_THRESHOLD = 300

    def __init__(self, base_url, app_id, app_secret):
        """Store the endpoint and credentials; no request is made here.

        Args:
            base_url: Service root URL; trailing slashes are stripped.
            app_id: Application id, sent with the token request and as the
                ``X-App-Id`` header.
            app_secret: Application secret, sent with the token request.
        """
        self.base_url = base_url.rstrip("/")
        self.app_id = app_id
        self.app_secret = app_secret
        self._access_token = None
        # Expiry deadline on the time.monotonic() clock, or None before the
        # first successful token fetch.
        self._token_expires_at = None
        self._lock = threading.Lock()

    def _fresh_token(self):
        """Return the cached token if it is not yet due for renewal, else None."""
        # Read the deadline before the token: _refresh_token writes the token
        # first and the deadline second, so a new deadline is never paired
        # with the previous token on the lock-free path.
        expires_at = self._token_expires_at
        token = self._access_token
        if (token and expires_at is not None
                and time.monotonic() < expires_at - self.TOKEN_REFRESH_THRESHOLD):
            return token
        return None

    def _ensure_token(self):
        """Return a usable access token, fetching a new one when required.

        The cached token is checked once without the lock and once more after
        acquiring it, so only one caller performs the refresh.

        Returns:
            The access token string.

        Raises:
            SampleCenterError: If a new token cannot be obtained.
        """
        token = self._fresh_token()
        if token is not None:
            return token
        with self._lock:
            # Another thread may have refreshed while we waited for the lock.
            token = self._fresh_token()
            if token is None:
                self._refresh_token()
                token = self._access_token
            return token

    def _refresh_token(self):
        """Request a new access token and cache it with its expiry deadline.

        Raises:
            SampleCenterError: On transport failure, non-200 status,
                non-success business code, or a malformed response body.
        """
        url = f"{self.base_url}/api/v1/auth/token"
        logger.info("请求 Token: url=%s, app_id=%s", url, self.app_id)
        # Timestamp taken before the request so that network latency can only
        # shorten, never extend, the cached lifetime.
        requested_at = time.monotonic()
        try:
            resp = requests.post(url, json={
                "app_id": self.app_id,
                "app_secret": self.app_secret,
            }, timeout=15)
        except requests.exceptions.RequestException as e:
            logger.error("Token 请求异常: %s", e)
            raise SampleCenterError(f"Token request exception: {e}") from e

        # The response body carries the access token, so only the status is logged.
        logger.info("Token 响应: status=%s", resp.status_code)
        if resp.status_code != 200:
            raise SampleCenterError(f"Token request failed: {resp.status_code} {resp.text}")

        try:
            result = resp.json()
        except ValueError as e:  # also covers requests' JSONDecodeError
            logger.error("Token 请求异常: %s", e)
            raise SampleCenterError(f"Token request exception: {e}") from e

        if not isinstance(result, dict):
            raise SampleCenterError(f"Token error: {result}")
        if result.get("code") != "000000":
            raise SampleCenterError(f"Token error: {result.get('message', result)}")

        data = result.get("data")
        try:
            token = data["access_token"]
            expires_in = data.get("expires_in", 7200)
            lifetime = float(expires_in)
        except (KeyError, TypeError, AttributeError, ValueError) as e:
            raise SampleCenterError(f"Token error: malformed token response ({e!r})") from e
        if not token:
            raise SampleCenterError("Token error: empty access_token in response")

        # Token first, deadline second; see _fresh_token for the read order.
        self._access_token = token
        self._token_expires_at = requested_at + lifetime
        logger.info("Token 获取成功: expires_in=%s", expires_in)

    def _headers(self):
        """Build the request headers, obtaining a valid token if needed."""
        return {
            "Authorization": f"Bearer {self._ensure_token()}",
            "X-App-Id": self.app_id,
        }

    def _parse(self, resp):
        """Validate a Sample Center response and return its JSON body.

        Args:
            resp: Response object exposing ``status_code``, ``text`` and
                ``json()``.

        Returns:
            The decoded JSON object (a dict) when the HTTP status is 200 and
            the business ``code`` is ``"000000"``.

        Raises:
            SampleCenterError: On non-200 status, a body that is not a JSON
                object, or a non-success business code.
        """
        if resp.status_code != 200:
            logger.error("HTTP 错误: status=%s, body=%s", resp.status_code, resp.text[:500])
            raise SampleCenterError(f"HTTP {resp.status_code}: {resp.text}")
        try:
            body = resp.json()
        except ValueError as e:  # also covers requests' JSONDecodeError
            logger.error("响应解析失败: status=%s, body=%s", resp.status_code, resp.text[:500])
            raise SampleCenterError(f"Invalid JSON response: {e}") from e
        if not isinstance(body, dict):
            logger.error("响应解析失败: status=%s, body=%s", resp.status_code, resp.text[:500])
            raise SampleCenterError(f"Unexpected response body: {resp.text[:500]}")
        if body.get("code") != "000000":
            logger.error("业务错误: code=%s, message=%s", body.get("code"), body.get("message"))
            raise SampleCenterError(
                f"SampleCenter error [{body.get('code')}]: {body.get('message', 'unknown')}"
            )
        return body

    # -- Business endpoints --

    def list_knowledge_bases(self, page=1, page_size=20):
        """List knowledge bases (``GET /api/v1/knowledge-bases``).

        Args:
            page: Sent as the ``page`` query parameter.
            page_size: Sent as the ``page_size`` query parameter.

        Returns:
            The parsed response body.

        Raises:
            SampleCenterError: If the token cannot be obtained or the service
                reports an error.
            requests.exceptions.RequestException: On transport failure of the
                list request itself.
        """
        url = f"{self.base_url}/api/v1/knowledge-bases"
        logger.info("请求知识库列表: url=%s, page=%s, page_size=%s", url, page, page_size)
        resp = requests.get(
            url,
            headers=self._headers(),
            params={"page": page, "page_size": page_size},
            timeout=30,
        )
        logger.info("知识库列表响应: status=%s, body=%s", resp.status_code, resp.text[:500])
        return self._parse(resp)

    def get_knowledge_base(self, kb_id):
        """Fetch one knowledge base (``GET /api/v1/knowledge-bases/{id}``).

        Args:
            kb_id: Knowledge base id, interpolated into the URL path as-is.

        Returns:
            The parsed response body.

        Raises:
            SampleCenterError: If the token cannot be obtained or the service
                reports an error.
            requests.exceptions.RequestException: On transport failure of the
                detail request itself.
        """
        url = f"{self.base_url}/api/v1/knowledge-bases/{kb_id}"
        logger.info("请求知识库详情: url=%s", url)
        resp = requests.get(url, headers=self._headers(), timeout=15)
        return self._parse(resp)

    def batch_import(self, kb_id, task_no, parents=None, children=None, callback_url=None):
        """Submit a batch import (``POST /api/v1/knowledge-bases/{kb_id}/batch-import``).

        Args:
            kb_id: Knowledge base id, interpolated into the URL path as-is.
            task_no: Sent as ``task_no`` in the JSON payload.
            parents: Optional; included as ``parents`` only when truthy.
            children: Optional; included as ``children`` only when truthy.
            callback_url: Optional; included as ``callback_url`` only when truthy.

        Returns:
            The parsed response body.

        Raises:
            SampleCenterError: If the token cannot be obtained or the service
                reports an error.
            requests.exceptions.RequestException: On transport failure of the
                import request itself.
        """
        url = f"{self.base_url}/api/v1/knowledge-bases/{kb_id}/batch-import"
        payload = {"task_no": task_no}
        if parents:
            payload["parents"] = parents
        if children:
            payload["children"] = children
        if callback_url:
            payload["callback_url"] = callback_url
        logger.info(
            "请求批量入库: url=%s, task_no=%s, parents_count=%s, children_count=%s, callback_url=%s",
            url,
            task_no,
            len(parents) if parents else 0,
            len(children) if children else 0,
            callback_url,
        )
        # The full payload can be arbitrarily large, so it is logged at DEBUG
        # only; lazy %r formatting skips the repr cost when DEBUG is disabled.
        logger.debug(
            "请求批量入库载荷: task_no=%s, parents=%r, children=%r",
            task_no, parents, children,
        )
        resp = requests.post(
            url,
            headers=self._headers(),
            json=payload,
            timeout=60,
        )
        logger.info("批量入库响应: status=%s, body=%s", resp.status_code, resp.text[:1000])
        return self._parse(resp)

    def get_import_task(self, task_id):
        """Query an import task (``GET /api/v1/knowledge-bases/batch-import/{task_id}``).

        Args:
            task_id: Import task id, interpolated into the URL path as-is.

        Returns:
            The parsed response body.

        Raises:
            SampleCenterError: If the token cannot be obtained or the service
                reports an error.
            requests.exceptions.RequestException: On transport failure of the
                query request itself.
        """
        url = f"{self.base_url}/api/v1/knowledge-bases/batch-import/{task_id}"
        logger.info("请求入库任务: url=%s", url)
        resp = requests.get(url, headers=self._headers(), timeout=15)
        logger.info(
            "入库任务响应: url=%s, status=%s, body=%s",
            url, resp.status_code, resp.text[:1000],
        )
        return self._parse(resp)
|