|
|
@@ -1,7 +1,8 @@
|
|
|
"""标注平台 API 客户端服务。
|
|
|
|
|
|
-对接标注平台的对外 API,支持 HMAC-SHA256 签名认证。
|
|
|
-功能:列出项目、获取项目详情、导出并下载数据集。
|
|
|
+对接标注平台的对外 API(HMAC-SHA256 签名认证)。
|
|
|
+参考文档:标注平台对外API接口文档.md
|
|
|
+功能:列出项目、获取项目详情、数据集导出与下载。
|
|
|
"""
|
|
|
|
|
|
import hashlib
|
|
|
@@ -26,59 +27,74 @@ _token_cache: dict[str, Any] = {}
|
|
|
|
|
|
|
|
|
def _get_base_url() -> str:
|
|
|
- if not settings.annotation_platform_base_url:
|
|
|
- raise ValueError("标注平台地址未配置,请检查 ANNOTATION_PLATFORM_BASE_URL 环境变量")
|
|
|
- return settings.annotation_platform_base_url.rstrip("/")
|
|
|
+ base_url = settings.annotation_platform_base_url
|
|
|
+ if not base_url:
|
|
|
+ raise ValueError("标注平台地址未配置,请检查 ANNOTATION_PLATFORM_BASE_URL")
|
|
|
+ return base_url.rstrip("/")
|
|
|
|
|
|
|
|
|
def _get_credentials() -> tuple[str, str]:
|
|
|
- if not settings.annotation_platform_app_id or not settings.annotation_platform_app_secret:
|
|
|
+ app_id = settings.annotation_platform_app_id
|
|
|
+ app_secret = settings.annotation_platform_app_secret
|
|
|
+ if not app_id or not app_secret:
|
|
|
raise ValueError("标注平台凭证未配置,请检查 ANNOTATION_PLATFORM_APP_ID 和 ANNOTATION_PLATFORM_APP_SECRET")
|
|
|
- return settings.annotation_platform_app_id, settings.annotation_platform_app_secret
|
|
|
+ return app_id, app_secret
|
|
|
|
|
|
|
|
|
-def _sign(app_secret: str, app_id: str, timestamp: str, nonce: str) -> str:
|
|
|
- """HMAC-SHA256 签名。"""
|
|
|
+def _build_token_headers() -> dict[str, str]:
|
|
|
+ """构建获取 Token 的 HMAC-SHA256 签名请求头。"""
|
|
|
+ app_id, app_secret = _get_credentials()
|
|
|
+ timestamp = str(int(time.time()))
|
|
|
+ nonce = uuid.uuid4().hex # 使用 uuid4 保证每次唯一,16+ 位随机字符串
|
|
|
message = app_id + timestamp + nonce
|
|
|
- return hmac.new(app_secret.encode(), message.encode(), hashlib.sha256).hexdigest()
|
|
|
+ signature = hmac.new(
|
|
|
+ key=app_secret.encode("utf-8"),
|
|
|
+ msg=message.encode("utf-8"),
|
|
|
+ digestmod=hashlib.sha256,
|
|
|
+ ).hexdigest()
|
|
|
|
|
|
+ return {
|
|
|
+ "Content-Type": "application/json",
|
|
|
+ "X-Api-Key": app_id,
|
|
|
+ "X-Timestamp": timestamp,
|
|
|
+ "X-Nonce": nonce,
|
|
|
+ "X-Signature": signature,
|
|
|
+ }
|
|
|
|
|
|
-def _check_token_valid() -> bool:
|
|
|
+
|
|
|
+def _is_token_valid() -> bool:
|
|
|
+ """检查缓存的 Token 是否仍然有效(提前 5 分钟刷新)。"""
|
|
|
if not _token_cache.get("access_token"):
|
|
|
return False
|
|
|
expires_at = _token_cache.get("expires_at", 0)
|
|
|
- return time.time() < expires_at - 300 # 提前 5 分钟刷新
|
|
|
+ return time.time() < expires_at - 300
|
|
|
|
|
|
|
|
|
async def get_token() -> str:
|
|
|
- """获取 Access Token,带缓存。"""
|
|
|
- if _check_token_valid():
|
|
|
+ """获取 Access Token,带缓存。
|
|
|
+
|
|
|
+ POST /api/v1/open/auth/token
|
|
|
+ 使用 HMAC-SHA256 签名认证,无请求体。
|
|
|
+ """
|
|
|
+ if _is_token_valid():
|
|
|
return _token_cache["access_token"]
|
|
|
|
|
|
- app_id, app_secret = _get_credentials()
|
|
|
+ headers = _build_token_headers()
|
|
|
base_url = _get_base_url()
|
|
|
|
|
|
- timestamp = str(int(time.time()))
|
|
|
- nonce = secrets.token_hex(8) # 16 位十六进制随机字符串
|
|
|
- signature = _sign(app_secret, app_id, timestamp, nonce)
|
|
|
-
|
|
|
async with httpx.AsyncClient(timeout=30) as client:
|
|
|
resp = await client.post(
|
|
|
f"{base_url}/api/v1/open/auth/token",
|
|
|
- headers={
|
|
|
- "X-Api-Key": app_id,
|
|
|
- "X-Signature": signature,
|
|
|
- "X-Timestamp": timestamp,
|
|
|
- "X-Nonce": nonce,
|
|
|
- },
|
|
|
+ headers=headers,
|
|
|
)
|
|
|
resp.raise_for_status()
|
|
|
body = resp.json()
|
|
|
|
|
|
+ # 标注平台返回 code: 0 表示成功
|
|
|
if body.get("code") != 0:
|
|
|
- raise RuntimeError(f"获取标注平台 Token 失败: {body.get('message')}")
|
|
|
+ raise RuntimeError(f"获取标注平台 Token 失败: {body.get('message', body)}")
|
|
|
|
|
|
- data = body["data"]
|
|
|
+ data = body.get("data", {})
|
|
|
_token_cache["access_token"] = data["access_token"]
|
|
|
_token_cache["expires_in"] = data.get("expires_in", 7200)
|
|
|
_token_cache["expires_at"] = time.time() + data.get("expires_in", 7200)
|
|
|
@@ -87,15 +103,15 @@ async def get_token() -> str:
|
|
|
|
|
|
|
|
|
def _auth_headers() -> dict[str, str]:
|
|
|
- token = _token_cache.get("access_token", "")
|
|
|
+ """构建业务接口的认证请求头。"""
|
|
|
return {
|
|
|
- "Authorization": f"Bearer {token}",
|
|
|
+ "Authorization": f"Bearer {_token_cache.get('access_token', '')}",
|
|
|
"Content-Type": "application/json",
|
|
|
}
|
|
|
|
|
|
|
|
|
async def _request(method: str, path: str, **kwargs) -> dict[str, Any]:
|
|
|
- """统一的请求方法,自动携带 Token 并处理错误。"""
|
|
|
+ """统一的业务请求方法,自动携带 Token。"""
|
|
|
await get_token()
|
|
|
base_url = _get_base_url()
|
|
|
|
|
|
@@ -110,7 +126,7 @@ async def _request(method: str, path: str, **kwargs) -> dict[str, Any]:
|
|
|
body = resp.json()
|
|
|
|
|
|
if body.get("code") != 0:
|
|
|
- raise RuntimeError(f"标注平台请求失败: {body.get('message')}")
|
|
|
+ raise RuntimeError(f"标注平台请求失败: {body.get('message', body)}")
|
|
|
|
|
|
return body.get("data", {})
|
|
|
|
|
|
@@ -124,7 +140,10 @@ async def list_projects(
|
|
|
project_type: str | None = None,
|
|
|
status: str | None = None,
|
|
|
) -> dict[str, Any]:
|
|
|
- """获取标注平台项目列表。"""
|
|
|
+ """获取标注平台项目列表。
|
|
|
+
|
|
|
+ GET /api/v1/open/projects
|
|
|
+ """
|
|
|
params: dict[str, Any] = {"page": page, "page_size": page_size}
|
|
|
if name:
|
|
|
params["name"] = name
|
|
|
@@ -139,7 +158,10 @@ async def list_projects(
|
|
|
# ---------- 项目详情 ----------
|
|
|
|
|
|
async def get_project_detail(project_id: str) -> dict[str, Any]:
|
|
|
- """获取项目详情。"""
|
|
|
+ """获取项目详情。
|
|
|
+
|
|
|
+ GET /api/v1/open/projects/{project_id}
|
|
|
+ """
|
|
|
return await _request("GET", f"/api/v1/open/projects/{project_id}")
|
|
|
|
|
|
|
|
|
@@ -153,9 +175,10 @@ async def import_project_dataset(
|
|
|
"""导出并下载项目数据集,保存到本地并写入数据库。
|
|
|
|
|
|
流程:
|
|
|
- 1. POST 请求导出 → 获取 file_url
|
|
|
- 2. GET 下载文件 → 保存到 uploads 目录
|
|
|
- 3. 写入 DatasetRecord 数据库
|
|
|
+ 1. POST /api/v1/open/projects/{project_id}/datasets/download → 获取 file_url
|
|
|
+ 2. GET /api/v1/open/datasets/downloads/{download_token} → 下载文件
|
|
|
+ 3. 保存到 uploads 目录
|
|
|
+ 4. 写入 DatasetRecord 数据库
|
|
|
"""
|
|
|
# 1. 请求导出
|
|
|
export_data = await _request(
|
|
|
@@ -171,21 +194,28 @@ async def import_project_dataset(
|
|
|
if not file_url:
|
|
|
raise RuntimeError("标注平台未返回下载链接")
|
|
|
|
|
|
- # 2. 下载文件
|
|
|
+ # 2. 从 file_url 中提取 download_token
|
|
|
+ # file_url 格式如: /api/v1/open/datasets/downloads/dl_abc123
|
|
|
+ if "/datasets/downloads/" in file_url:
|
|
|
+ download_token = file_url.split("/datasets/downloads/")[-1].strip("/")
|
|
|
+ else:
|
|
|
+ # 兜底:直接使用 file_url 的最后一段
|
|
|
+ download_token = file_url.rstrip("/").split("/")[-1]
|
|
|
+
|
|
|
+ # 3. 通过独立的下载接口获取文件(文档 4.6 节)
|
|
|
await get_token()
|
|
|
base_url = _get_base_url()
|
|
|
- download_url = f"{base_url}{file_url}" if file_url.startswith("/") else file_url
|
|
|
|
|
|
async with httpx.AsyncClient(timeout=120) as client:
|
|
|
resp = await client.get(
|
|
|
- download_url,
|
|
|
- headers={"Authorization": f"Bearer {_token_cache.get('access_token', '')}"},
|
|
|
+ f"{base_url}/api/v1/open/datasets/downloads/{download_token}",
|
|
|
+ headers=_auth_headers(),
|
|
|
follow_redirects=True,
|
|
|
)
|
|
|
resp.raise_for_status()
|
|
|
file_content = resp.content
|
|
|
|
|
|
- # 3. 保存到 uploads 目录
|
|
|
+ # 4. 保存到 uploads 目录
|
|
|
upload_dir = settings.uploads_dir
|
|
|
upload_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
@@ -198,11 +228,11 @@ async def import_project_dataset(
|
|
|
|
|
|
file_path.write_bytes(file_content)
|
|
|
|
|
|
- # 4. 检测格式和记录数
|
|
|
+ # 5. 检测格式和记录数
|
|
|
fmt = _detect_format(file_path.name)
|
|
|
record_count = _count_records(file_path, fmt)
|
|
|
|
|
|
- # 5. 写入数据库
|
|
|
+ # 6. 写入数据库
|
|
|
record_id = str(uuid.uuid4())
|
|
|
record = DatasetRecord(
|
|
|
id=record_id,
|