"""标注平台 API 客户端服务。 对接标注平台的对外 API,支持 HMAC-SHA256 签名认证。 功能:列出项目、获取项目详情、导出并下载数据集。 """ import hashlib import hmac import secrets import time import uuid from datetime import datetime from pathlib import Path from typing import Any import httpx from app.config import get_settings from app.core.db import async_session, DatasetRecord from app.core.logging import logger settings = get_settings() # Token 缓存(内存中) _token_cache: dict[str, Any] = {} def _get_base_url() -> str: if not settings.annotation_platform_base_url: raise ValueError("标注平台地址未配置,请检查 ANNOTATION_PLATFORM_BASE_URL 环境变量") return settings.annotation_platform_base_url.rstrip("/") def _get_credentials() -> tuple[str, str]: if not settings.annotation_platform_app_id or not settings.annotation_platform_app_secret: raise ValueError("标注平台凭证未配置,请检查 ANNOTATION_PLATFORM_APP_ID 和 ANNOTATION_PLATFORM_APP_SECRET") return settings.annotation_platform_app_id, settings.annotation_platform_app_secret def _sign(app_secret: str, app_id: str, timestamp: str, nonce: str) -> str: """HMAC-SHA256 签名。""" message = app_id + timestamp + nonce return hmac.new(app_secret.encode(), message.encode(), hashlib.sha256).hexdigest() def _check_token_valid() -> bool: if not _token_cache.get("access_token"): return False expires_at = _token_cache.get("expires_at", 0) return time.time() < expires_at - 300 # 提前 5 分钟刷新 async def get_token() -> str: """获取 Access Token,带缓存。""" if _check_token_valid(): return _token_cache["access_token"] app_id, app_secret = _get_credentials() base_url = _get_base_url() timestamp = str(int(time.time())) nonce = secrets.token_hex(8) # 16 位十六进制随机字符串 signature = _sign(app_secret, app_id, timestamp, nonce) async with httpx.AsyncClient(timeout=30) as client: resp = await client.post( f"{base_url}/api/v1/open/auth/token", headers={ "X-Api-Key": app_id, "X-Signature": signature, "X-Timestamp": timestamp, "X-Nonce": nonce, }, ) resp.raise_for_status() body = resp.json() if body.get("code") != 0: raise RuntimeError(f"获取标注平台 Token 失败: {body.get('message')}") data = body["data"] _token_cache["access_token"] = data["access_token"] _token_cache["expires_in"] = data.get("expires_in", 7200) _token_cache["expires_at"] = time.time() + data.get("expires_in", 7200) return data["access_token"] def _auth_headers() -> dict[str, str]: token = _token_cache.get("access_token", "") return { "Authorization": f"Bearer {token}", "Content-Type": "application/json", } async def _request(method: str, path: str, **kwargs) -> dict[str, Any]: """统一的请求方法,自动携带 Token 并处理错误。""" await get_token() base_url = _get_base_url() async with httpx.AsyncClient(timeout=60) as client: resp = await client.request( method, f"{base_url}{path}", headers=_auth_headers(), **kwargs, ) resp.raise_for_status() body = resp.json() if body.get("code") != 0: raise RuntimeError(f"标注平台请求失败: {body.get('message')}") return body.get("data", {}) # ---------- 项目列表 ---------- async def list_projects( page: int = 1, page_size: int = 20, name: str | None = None, project_type: str | None = None, status: str | None = None, ) -> dict[str, Any]: """获取标注平台项目列表。""" params: dict[str, Any] = {"page": page, "page_size": page_size} if name: params["name"] = name if project_type: params["type"] = project_type if status: params["status"] = status return await _request("GET", "/api/v1/open/projects", params=params) # ---------- 项目详情 ---------- async def get_project_detail(project_id: str) -> dict[str, Any]: """获取项目详情。""" return await _request("GET", f"/api/v1/open/projects/{project_id}") # ---------- 数据集导出与下载 ---------- async def import_project_dataset( project_id: str, project_name: str = "", format: str = "alpaca", ) -> dict[str, Any]: """导出并下载项目数据集,保存到本地并写入数据库。 流程: 1. POST 请求导出 → 获取 file_url 2. GET 下载文件 → 保存到 uploads 目录 3. 写入 DatasetRecord 数据库 """ # 1. 请求导出 export_data = await _request( "POST", f"/api/v1/open/projects/{project_id}/datasets/download", json={"format": format, "completed_only": True}, ) file_url = export_data.get("file_url", "") file_name = export_data.get("file_name", f"{project_id}_{format}.json") total_exported = export_data.get("total_exported", 0) if not file_url: raise RuntimeError("标注平台未返回下载链接") # 2. 下载文件 await get_token() base_url = _get_base_url() download_url = f"{base_url}{file_url}" if file_url.startswith("/") else file_url async with httpx.AsyncClient(timeout=120) as client: resp = await client.get( download_url, headers={"Authorization": f"Bearer {_token_cache.get('access_token', '')}"}, follow_redirects=True, ) resp.raise_for_status() file_content = resp.content # 3. 保存到 uploads 目录 upload_dir = settings.uploads_dir upload_dir.mkdir(parents=True, exist_ok=True) safe_name = f"{project_name or project_id}_{file_name}" if project_name else file_name # 清理文件名中的非法字符 safe_name = "".join(c if c.isalnum() or c in "._-" else "_" for c in safe_name) file_path = upload_dir / safe_name if file_path.exists(): file_path = upload_dir / f"{uuid.uuid4().hex[:8]}_{safe_name}" file_path.write_bytes(file_content) # 4. 检测格式和记录数 fmt = _detect_format(file_path.name) record_count = _count_records(file_path, fmt) # 5. 写入数据库 record_id = str(uuid.uuid4()) record = DatasetRecord( id=record_id, name=file_path.name, format=fmt, record_count=record_count, file_path=str(file_path), created_at=datetime.utcnow(), ) async with async_session() as session: session.add(record) await session.commit() logger.info(f"Imported dataset from annotation platform: {project_id} -> {file_path.name} ({record_count} records)") return { "project_id": project_id, "project_name": project_name or project_id, "format": format, "total_exported": total_exported, "dataset_id": record_id, "dataset_name": file_path.name, } def _detect_format(filename: str) -> str: """根据文件名推断格式。""" name = filename.lower() if name.endswith(".jsonl"): return "jsonl" if name.endswith(".csv"): return "csv" if name.endswith(".parquet"): return "parquet" if name.endswith(".json"): return "json" return "json" def _count_records(file_path: Path, fmt: str) -> int: """计算文件中的记录数。""" import json if not file_path.exists(): return 0 try: if fmt == "jsonl": with open(file_path, "r", encoding="utf-8") as f: return sum(1 for line in f if line.strip()) elif fmt == "json": with open(file_path, "r", encoding="utf-8") as f: data = json.load(f) if isinstance(data, list): return len(data) return 1 elif fmt == "csv": import csv with open(file_path, "r", encoding="utf-8") as f: return sum(1 for _ in csv.DictReader(f)) except Exception: return 0 return 0