"""标注平台 API 客户端服务。 对接标注平台的对外 API(HMAC-SHA256 签名认证)。 参考文档:标注平台对外API接口文档.md 功能:列出项目、获取项目详情、数据集导出与下载。 """ import hashlib import hmac import secrets import time import uuid from datetime import datetime from pathlib import Path from typing import Any import httpx from app.config import get_settings from app.core.db import async_session, DatasetRecord from app.core.logging import logger settings = get_settings() # Token 缓存(内存中) _token_cache: dict[str, Any] = {} def _get_base_url() -> str: base_url = settings.annotation_platform_base_url if not base_url: raise ValueError("标注平台地址未配置,请检查 ANNOTATION_PLATFORM_BASE_URL") return base_url.rstrip("/") def _get_credentials() -> tuple[str, str]: app_id = settings.annotation_platform_app_id app_secret = settings.annotation_platform_app_secret if not app_id or not app_secret: raise ValueError("标注平台凭证未配置,请检查 ANNOTATION_PLATFORM_APP_ID 和 ANNOTATION_PLATFORM_APP_SECRET") return app_id, app_secret def _build_token_headers() -> dict[str, str]: """构建获取 Token 的 HMAC-SHA256 签名请求头。""" app_id, app_secret = _get_credentials() timestamp = str(int(time.time())) nonce = uuid.uuid4().hex # 使用 uuid4 保证每次唯一,16+ 位随机字符串 message = app_id + timestamp + nonce signature = hmac.new( key=app_secret.encode("utf-8"), msg=message.encode("utf-8"), digestmod=hashlib.sha256, ).hexdigest() return { "Content-Type": "application/json", "X-Api-Key": app_id, "X-Timestamp": timestamp, "X-Nonce": nonce, "X-Signature": signature, } def _is_token_valid() -> bool: """检查缓存的 Token 是否仍然有效(提前 5 分钟刷新)。""" if not _token_cache.get("access_token"): return False expires_at = _token_cache.get("expires_at", 0) return time.time() < expires_at - 300 async def _refresh_token() -> str: """使用 Bearer Token 刷新 Access Token。 POST /api/v1/open/auth/refresh 比重新签名更高效,仅在 Token 存在但即将过期时调用。 """ old_token = _token_cache.get("access_token", "") base_url = _get_base_url() async with httpx.AsyncClient(timeout=30) as client: resp = await client.post( f"{base_url}/api/v1/open/auth/refresh", headers={"Authorization": f"Bearer {old_token}"}, ) resp.raise_for_status() body = resp.json() if body.get("code") != 0: raise RuntimeError(f"刷新标注平台 Token 失败: {body.get('message', body)}") data = body.get("data", {}) _token_cache["access_token"] = data["access_token"] _token_cache["expires_in"] = data.get("expires_in", 7200) _token_cache["expires_at"] = time.time() + data.get("expires_in", 7200) return data["access_token"] async def get_token() -> str: """获取 Access Token,带缓存和自动刷新。 优先级: 1. Token 有效 → 直接返回 2. Token 存在但快过期 → 调用 /auth/refresh 刷新 3. 无 Token → 调用 /auth/token 重新签名获取 """ if _is_token_valid(): return _token_cache["access_token"] # Token 存在但即将过期,尝试刷新 if _token_cache.get("access_token"): try: return await _refresh_token() except Exception as e: logger.warning(f"Token 刷新失败,回退到重新获取: {e}") # 无 Token 或刷新失败,重新签名获取 headers = _build_token_headers() base_url = _get_base_url() async with httpx.AsyncClient(timeout=30) as client: resp = await client.post( f"{base_url}/api/v1/open/auth/token", json={}, headers=headers, ) resp.raise_for_status() body = resp.json() # 标注平台返回 code: 0 表示成功 if body.get("code") != 0: raise RuntimeError(f"获取标注平台 Token 失败: {body.get('message', body)}") data = body.get("data", {}) _token_cache["access_token"] = data["access_token"] _token_cache["expires_in"] = data.get("expires_in", 7200) _token_cache["expires_at"] = time.time() + data.get("expires_in", 7200) return data["access_token"] def _auth_headers() -> dict[str, str]: """构建业务接口的认证请求头。""" return { "Authorization": f"Bearer {_token_cache.get('access_token', '')}", "Content-Type": "application/json", } async def _request(method: str, path: str, **kwargs) -> dict[str, Any]: """统一的业务请求方法,自动携带 Token。""" await get_token() base_url = _get_base_url() async with httpx.AsyncClient(timeout=60) as client: resp = await client.request( method, f"{base_url}{path}", headers=_auth_headers(), **kwargs, ) resp.raise_for_status() body = resp.json() if body.get("code") != 0: raise RuntimeError(f"标注平台请求失败: {body.get('message', body)}") return body.get("data", {}) # ---------- 项目列表 ---------- async def list_projects( page: int = 1, page_size: int = 20, name: str | None = None, project_type: str | None = None, status: str | None = None, ) -> dict[str, Any]: """获取标注平台项目列表。 GET /api/v1/open/projects """ params: dict[str, Any] = {"page": page, "page_size": page_size} if name: params["name"] = name if project_type: params["type"] = project_type if status: params["status"] = status return await _request("GET", "/api/v1/open/projects", params=params) # ---------- 项目详情 ---------- async def get_project_detail(project_id: str) -> dict[str, Any]: """获取项目详情。 GET /api/v1/open/projects/{project_id} """ return await _request("GET", f"/api/v1/open/projects/{project_id}") # ---------- 数据集导出与下载 ---------- async def import_project_dataset( project_id: str, project_name: str = "", format: str = "alpaca", ) -> dict[str, Any]: """导出并下载项目数据集,保存到本地并写入数据库。 流程: 1. POST /api/v1/open/projects/{project_id}/datasets/download → 获取 file_url 2. GET /api/v1/open/datasets/downloads/{download_token} → 下载文件 3. 保存到 uploads 目录 4. 写入 DatasetRecord 数据库 """ # 1. 请求导出 export_data = await _request( "POST", f"/api/v1/open/projects/{project_id}/datasets/download", json={"format": format, "completed_only": True}, ) file_url = export_data.get("file_url", "") file_name = export_data.get("file_name", f"{project_id}_{format}.json") total_exported = export_data.get("total_exported", 0) if not file_url: raise RuntimeError("标注平台未返回下载链接") # 2. 从 file_url 中提取 download_token # file_url 格式如: /api/v1/open/datasets/downloads/dl_abc123 if "/datasets/downloads/" in file_url: download_token = file_url.split("/datasets/downloads/")[-1].strip("/") else: # 兜底:直接使用 file_url 的最后一段 download_token = file_url.rstrip("/").split("/")[-1] # 3. 通过独立的下载接口获取文件(文档 4.6 节) await get_token() base_url = _get_base_url() async with httpx.AsyncClient(timeout=120) as client: resp = await client.get( f"{base_url}/api/v1/open/datasets/downloads/{download_token}", headers=_auth_headers(), follow_redirects=True, ) resp.raise_for_status() file_content = resp.content # 4. 保存到 uploads 目录 upload_dir = settings.uploads_dir upload_dir.mkdir(parents=True, exist_ok=True) safe_name = f"{project_name or project_id}_{file_name}" if project_name else file_name # 清理文件名中的非法字符 safe_name = "".join(c if c.isalnum() or c in "._-" else "_" for c in safe_name) file_path = upload_dir / safe_name if file_path.exists(): file_path = upload_dir / f"{uuid.uuid4().hex[:8]}_{safe_name}" file_path.write_bytes(file_content) # 5. 统一转为 JSONL 格式(和 ModelScope/HF 下载的数据格式一致) jsonl_path = _convert_to_jsonl(file_path) record_count = _count_records(jsonl_path, "jsonl") # 6. 写入数据库(格式统一为 jsonl) record_id = str(uuid.uuid4()) record = DatasetRecord( id=record_id, name=jsonl_path.name, format="jsonl", record_count=record_count, file_path=str(jsonl_path), created_at=datetime.utcnow(), ) async with async_session() as session: session.add(record) await session.commit() logger.info(f"Imported dataset from annotation platform: {project_id} -> {jsonl_path.name} ({record_count} records)") return { "project_id": project_id, "project_name": project_name or project_id, "format": "jsonl", "total_exported": total_exported, "dataset_id": record_id, "dataset_name": jsonl_path.name, } def _convert_to_jsonl(file_path: Path) -> Path: """将 JSON/JSONL 文件统一转为 JSONL 格式。""" import json as _json jsonl_path = file_path.with_suffix(".jsonl") with open(file_path, "r", encoding="utf-8") as f: content = f.read().strip() if not content: jsonl_path.touch() return jsonl_path try: # 尝试作为 JSON 数组解析 data = _json.loads(content) if isinstance(data, list): with open(jsonl_path, "w", encoding="utf-8") as out: for item in data: out.write(_json.dumps(item, ensure_ascii=False) + "\n") # 删除原始 JSON 文件 file_path.unlink() return jsonl_path except _json.JSONDecodeError: pass # 不是标准 JSON,可能是 JSONL,逐行验证 lines = content.split("\n") valid_lines = [] for line in lines: line = line.strip() if line: try: _json.loads(line) valid_lines.append(line) except _json.JSONDecodeError: continue # 跳过无效行 with open(jsonl_path, "w", encoding="utf-8") as out: out.write("\n".join(valid_lines) + ("\n" if valid_lines else "")) file_path.unlink() return jsonl_path def _detect_format(filename: str) -> str: """根据文件名推断格式。""" name = filename.lower() if name.endswith(".jsonl"): return "jsonl" if name.endswith(".csv"): return "csv" if name.endswith(".parquet"): return "parquet" if name.endswith(".json"): return "json" return "json" def _count_records(file_path: Path, fmt: str) -> int: """计算文件中的记录数。""" import json if not file_path.exists(): return 0 try: if fmt == "jsonl": with open(file_path, "r", encoding="utf-8") as f: return sum(1 for line in f if line.strip()) elif fmt == "json": with open(file_path, "r", encoding="utf-8") as f: data = json.load(f) if isinstance(data, list): return len(data) return 1 elif fmt == "csv": import csv with open(file_path, "r", encoding="utf-8") as f: return sum(1 for _ in csv.DictReader(f)) except Exception: return 0 return 0