"""Annotation-platform API client service.

Talks to the annotation platform's public API, which authenticates callers
with HMAC-SHA256 signed requests (reference: 标注平台对外API接口文档.md).

Features: list projects, fetch project details, export/download datasets and
import them as local JSONL datasets.
"""

import asyncio
import csv
import hashlib
import hmac
import io
import json
import time
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

import httpx

from app.config import get_settings
from app.core.db import async_session, DatasetRecord
from app.core.logging import logger

settings = get_settings()

# In-memory token cache: {"access_token": str, "expires_at": epoch seconds}.
_token_cache: dict[str, Any] = {}

# Treat the token as stale this many seconds before its real expiry.
_TOKEN_REFRESH_MARGIN = 300
# Lifetime (seconds) assumed when the token response omits ``expires_in``.
_DEFAULT_TOKEN_TTL = 7200
# Downloads of at most this many bytes are treated as empty / not ready yet.
_MIN_FILE_BYTES = 10


def _get_base_url() -> str:
    """Return the configured platform base URL without a trailing slash.

    Raises:
        ValueError: if ANNOTATION_PLATFORM_BASE_URL is not configured.
    """
    base_url = settings.annotation_platform_base_url
    if not base_url:
        raise ValueError("标注平台地址未配置,请检查 ANNOTATION_PLATFORM_BASE_URL")
    return base_url.rstrip("/")


def _get_credentials() -> tuple[str, str]:
    """Return ``(app_id, app_secret)`` from settings.

    Raises:
        ValueError: if either credential is missing.
    """
    app_id = settings.annotation_platform_app_id
    app_secret = settings.annotation_platform_app_secret
    if not app_id or not app_secret:
        raise ValueError("标注平台凭证未配置,请检查 ANNOTATION_PLATFORM_APP_ID 和 ANNOTATION_PLATFORM_APP_SECRET")
    return app_id, app_secret


def _build_token_headers() -> dict[str, str]:
    """Build the HMAC-SHA256 signed headers for the token endpoint.

    The signature is ``HMAC_SHA256(app_secret, app_id + timestamp + nonce)``
    as a hex digest; a fresh timestamp and random nonce are used per call.
    """
    app_id, app_secret = _get_credentials()
    timestamp = str(int(time.time()))
    nonce = uuid.uuid4().hex
    message = app_id + timestamp + nonce
    signature = hmac.new(
        key=app_secret.encode("utf-8"),
        msg=message.encode("utf-8"),
        digestmod=hashlib.sha256,
    ).hexdigest()
    return {
        "Content-Type": "application/json",
        "X-Api-Key": app_id,
        "X-Timestamp": timestamp,
        "X-Nonce": nonce,
        "X-Signature": signature,
    }


def _is_token_valid() -> bool:
    """Return True if a cached token exists and is not within the refresh margin."""
    if not _token_cache.get("access_token"):
        return False
    expires_at = _token_cache.get("expires_at", 0)
    return time.time() < expires_at - _TOKEN_REFRESH_MARGIN


def _store_token(body: Any, error_prefix: str) -> str:
    """Validate a token-endpoint response body, cache the token and return it.

    Shared by the refresh and the signed-fetch paths so both parse the
    response identically.

    Raises:
        RuntimeError: if ``code`` is not 0 or no ``access_token`` is present.
    """
    if not isinstance(body, dict) or body.get("code") != 0:
        detail = body.get("message", body) if isinstance(body, dict) else body
        raise RuntimeError(f"{error_prefix}: {detail}")
    data = body.get("data") or {}
    token = data.get("access_token")
    if not token:
        raise RuntimeError(f"{error_prefix}: 响应中缺少 access_token")
    expires_in = data.get("expires_in")
    if expires_in is None:  # also covers an explicit JSON null
        expires_in = _DEFAULT_TOKEN_TTL
    _token_cache["access_token"] = token
    _token_cache["expires_at"] = time.time() + float(expires_in)
    return token


async def _refresh_token() -> str:
    """Exchange the currently cached access token for a new one via /auth/refresh."""
    old_token = _token_cache.get("access_token", "")
    base_url = _get_base_url()
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(
            f"{base_url}/api/v1/open/auth/refresh",
            headers={"Authorization": f"Bearer {old_token}"},
        )
        resp.raise_for_status()
        body = resp.json()
    return _store_token(body, "刷新标注平台 Token 失败")


async def _fetch_new_token() -> str:
    """Obtain a brand-new access token using the HMAC-signed /auth/token call."""
    headers = _build_token_headers()
    base_url = _get_base_url()
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(
            f"{base_url}/api/v1/open/auth/token",
            json={},
            headers=headers,
        )
        resp.raise_for_status()
        body = resp.json()
    return _store_token(body, "获取标注平台 Token 失败")


async def get_token() -> str:
    """Return a usable access token, using the cache and refreshing as needed.

    Order: cached token if still valid -> refresh an existing token ->
    sign a fresh token request (also used when the refresh fails).
    """
    if _is_token_valid():
        return _token_cache["access_token"]

    # A token exists but is about to expire (or has expired): try a refresh first.
    if _token_cache.get("access_token"):
        try:
            return await _refresh_token()
        except Exception as e:
            logger.warning(f"Token 刷新失败,回退到重新获取: {e}")

    # No token, or the refresh failed: sign a new token request.
    return await _fetch_new_token()


def _auth_headers() -> dict[str, str]:
    """Build the auth headers for business endpoints from the cached token."""
    return {
        "Authorization": f"Bearer {_token_cache.get('access_token', '')}",
        "Content-Type": "application/json",
    }


async def _send_authed(
    client: httpx.AsyncClient, method: str, url: str, **kwargs: Any
) -> httpx.Response:
    """Send an authenticated request; on HTTP 401, re-authenticate once and retry.

    The local cache only checks ``expires_at``, so a token the server has
    stopped accepting would otherwise keep failing until it expires locally.
    """
    await get_token()
    resp = await client.request(method, url, headers=_auth_headers(), **kwargs)
    if resp.status_code == 401:
        logger.warning(f"标注平台返回 401,重新获取 Token 后重试: {method} {url}")
        _token_cache.clear()
        await get_token()
        resp = await client.request(method, url, headers=_auth_headers(), **kwargs)
    return resp


def _response_body(resp: httpx.Response) -> Any:
    """Return the parsed JSON body, or ``{"message": <raw text>}`` if it is not JSON."""
    try:
        return resp.json()
    except ValueError:  # JSONDecodeError / UnicodeDecodeError are ValueErrors
        return {"message": resp.text}


def _error_result(status_code: int, body: Any) -> dict[str, Any]:
    """Build the ``_error``-tagged dict returned when ``raise_on_error=False``."""
    if not isinstance(body, dict):
        body = {"message": body}
    return {"_error": True, "_status_code": status_code, **body}


async def _request(
    method: str, path: str, *, raise_on_error: bool = True, **kwargs
) -> dict[str, Any]:
    """Send a business request with the Bearer token and return the ``data`` field.

    Args:
        method: HTTP method.
        path: Path appended to the configured base URL.
        raise_on_error: When False, neither HTTP 4xx/5xx nor a non-zero
            business ``code`` raises; a dict tagged with ``_error``/``_status_code``
            (merged with the response body) is returned instead.
        **kwargs: Forwarded to ``httpx.AsyncClient.request`` (``params``, ``json``...).

    Raises:
        httpx.HTTPStatusError: on HTTP errors when ``raise_on_error`` is True.
        RuntimeError: on a non-zero business ``code`` when ``raise_on_error`` is True.
    """
    base_url = _get_base_url()
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await _send_authed(client, method, f"{base_url}{path}", **kwargs)

    if not raise_on_error and resp.status_code >= 400:
        return _error_result(resp.status_code, _response_body(resp))
    resp.raise_for_status()
    body = resp.json()
    if body.get("code") != 0:
        if not raise_on_error:
            return _error_result(resp.status_code, body)
        raise RuntimeError(f"标注平台请求失败: {body.get('message', body)}")
    data = body.get("data")
    return {} if data is None else data


# ---------- Project list ----------


async def list_projects(
    page: int = 1,
    page_size: int = 20,
    name: str | None = None,
    project_type: str | None = None,
    status: str | None = None,
) -> dict[str, Any]:
    """List annotation-platform projects; empty filters are not sent."""
    params: dict[str, Any] = {"page": page, "page_size": page_size}
    if name:
        params["name"] = name
    if project_type:
        params["type"] = project_type
    if status:
        params["status"] = status
    return await _request("GET", "/api/v1/open/projects", params=params)


# ---------- Project detail ----------


async def get_project_detail(project_id: str) -> dict[str, Any]:
    """Fetch the detail record of one project."""
    return await _request("GET", f"/api/v1/open/projects/{project_id}")


# ---------- Dataset export & download ----------

# Export formats supported by text and image projects (per the platform API docs).
_TEXT_FORMATS = ["alpaca", "sharegpt"]
_IMAGE_FORMATS = ["json", "csv", "coco"]


async def _request_export(project_id: str, fmt: str) -> dict[str, Any] | None:
    """Request a dataset export in ``fmt``.

    Returns the export info, or None when the format is rejected, the export
    is not ``completed``, or no ``file_url`` is returned.
    """
    data = await _request(
        "POST",
        f"/api/v1/open/projects/{project_id}/datasets/download",
        raise_on_error=False,
        json={"format": fmt, "completed_only": True},
    )
    if data.get("_error"):
        logger.info(f"导出格式 '{fmt}' 不可用: {data.get('message', data.get('detail', ''))}")
        return None
    if data.get("status") != "completed":
        logger.info(f"导出格式 '{fmt}' 状态异常: {data.get('status')}")
        return None
    if not data.get("file_url"):
        logger.info(f"导出格式 '{fmt}' 未返回下载链接")
        return None
    return data


async def _download_file(download_token: str, max_retries: int = 3) -> bytes:
    """Download an export file by ``download_token``, retrying while it is (near-)empty.

    The platform may need time to generate the file, so responses of at most
    ``_MIN_FILE_BYTES`` bytes are retried with exponential backoff (1s, 2s, ...).
    At least one attempt is always made. Returns the last content received.
    """
    base_url = _get_base_url()
    url = f"{base_url}/api/v1/open/datasets/downloads/{download_token}"
    attempts = max(1, max_retries)
    content = b""
    async with httpx.AsyncClient(timeout=120) as client:
        for attempt in range(attempts):
            resp = await _send_authed(client, "GET", url, follow_redirects=True)
            resp.raise_for_status()
            content = resp.content
            if len(content) > _MIN_FILE_BYTES:
                return content
            if attempt < attempts - 1:
                wait = 2 ** attempt
                logger.info(f"文件尚未就绪 ({len(content)} bytes),{wait}s 后重试...")
                await asyncio.sleep(wait)
    return content


def _extract_download_token(file_url: str) -> str:
    """Extract the download token from ``file_url``.

    Uses the text after the last ``/datasets/downloads/`` if present, otherwise
    the last path segment.
    """
    marker = "/datasets/downloads/"
    if marker in file_url:
        return file_url.split(marker)[-1].strip("/")
    return file_url.rstrip("/").split("/")[-1]


async def _resolve_project_type(project_id: str) -> str:
    """Return the project's ``project_type``, defaulting to "text" if the lookup fails."""
    try:
        detail = await get_project_detail(project_id)
        return detail.get("project_type", "text")
    except Exception as e:
        logger.warning(f"获取项目详情失败,按 text 项目处理: {e}")
        return "text"


def _candidate_formats(project_type: str, preferred: str) -> list[str]:
    """Return export formats to try, with ``preferred`` (if given) moved to the front."""
    defaults = _TEXT_FORMATS if project_type == "text" else _IMAGE_FORMATS
    if not preferred:
        return list(defaults)
    # The preferred format goes first even when it is one of the defaults.
    return [preferred] + [f for f in defaults if f != preferred]


async def _export_first_available(
    project_id: str, formats: list[str]
) -> tuple[bytes, int, str]:
    """Try each format in order; return ``(content, total_exported, format)``.

    ``format`` is "" when no format produced a non-empty file.
    """
    file_content = b""
    total_exported = 0
    for fmt in formats:
        export_data = await _request_export(project_id, fmt)
        if not export_data:
            continue
        total_exported = export_data.get("total_exported", 0)
        download_token = _extract_download_token(export_data["file_url"])
        file_content = await _download_file(download_token)
        if len(file_content) > _MIN_FILE_BYTES:
            logger.info(
                f"标注平台导出成功: format={fmt}, {len(file_content)} bytes, "
                f"total_exported={total_exported}"
            )
            return file_content, total_exported, fmt
        logger.info(f"格式 '{fmt}' 导出文件为空 ({len(file_content)} bytes),尝试下一格式")
    return file_content, total_exported, ""


def _save_export_as_jsonl(content: bytes, base_name: str, source_format: str) -> tuple[Path, int]:
    """Write the raw export into the uploads dir, convert it to JSONL and count records.

    Synchronous (blocking file I/O); callers in async code run it in a thread.
    The raw file is removed again if conversion fails.
    """
    upload_dir = settings.uploads_dir
    upload_dir.mkdir(parents=True, exist_ok=True)
    # Replace path separators and other unsafe characters in the file name.
    safe_name = "".join(c if c.isalnum() or c in "._-" else "_" for c in base_name)
    file_path = upload_dir / f"{safe_name}.jsonl"
    if file_path.exists():
        file_path = upload_dir / f"{uuid.uuid4().hex[:8]}_{safe_name}.jsonl"
    file_path.write_bytes(content)
    try:
        jsonl_path = _convert_to_jsonl(file_path, source_format)
    except Exception:
        file_path.unlink(missing_ok=True)
        raise
    return jsonl_path, _count_records(jsonl_path, "jsonl")


async def import_project_dataset(
    project_id: str,
    project_name: str = "",
    format: str = "alpaca",
) -> dict[str, Any]:
    """Export and download a project's dataset, save it locally and record it in the DB.

    Steps:
        1. Look up the project type (falls back to "text").
        2. Try the export formats for that type, user-specified format first.
        3. Download the file, convert it to JSONL (CSV rows become JSON objects).
        4. Insert a ``DatasetRecord``.

    Raises:
        RuntimeError: if every export format yields an empty file, or the
            downloaded file cannot be decoded as UTF-8 text.
    """
    project_type = await _resolve_project_type(project_id)
    formats_to_try = _candidate_formats(project_type, format)

    file_content, total_exported, used_format = await _export_first_available(
        project_id, formats_to_try
    )
    if not used_format:
        raise RuntimeError(
            f"标注平台所有导出格式均返回空文件(project_type={project_type},"
            f"尝试格式: {formats_to_try}),请检查标注平台该项目的数据是否支持导出"
        )

    display_name = project_name or project_id
    # File writing/conversion can be large; keep it off the event loop.
    jsonl_path, record_count = await asyncio.to_thread(
        _save_export_as_jsonl, file_content, display_name, used_format
    )
    if record_count == 0:
        logger.warning(f"标注平台导出文件转换后无有效记录: {jsonl_path.name} (format={used_format})")

    record_id = str(uuid.uuid4())
    record = DatasetRecord(
        id=record_id,
        name=display_name,
        format="jsonl",
        record_count=record_count,
        file_path=str(jsonl_path),
        # Naive UTC, equivalent to the deprecated datetime.utcnow().
        created_at=datetime.now(timezone.utc).replace(tzinfo=None),
    )
    try:
        async with async_session() as session:
            session.add(record)
            await session.commit()
    except Exception:
        # Do not leave an unreferenced file behind when the DB insert fails.
        jsonl_path.unlink(missing_ok=True)
        raise

    logger.info(f"Imported annotation dataset: {project_name} ({record_count} records, format={used_format})")
    return {
        "project_id": project_id,
        "project_name": display_name,
        "format": "jsonl",
        "total_exported": total_exported,
        "dataset_id": record_id,
        "dataset_name": jsonl_path.name,
    }


def _csv_text_to_lines(content: str) -> list[str]:
    """Convert CSV text (header row + rows) into JSON-object lines."""
    reader = csv.DictReader(io.StringIO(content))
    return [json.dumps(row, ensure_ascii=False) for row in reader]


def _json_text_to_lines(content: str) -> list[str]:
    """Convert JSON or JSONL text into JSON lines.

    A whole-document JSON value is unpacked via ``_extract_items``; otherwise
    the text is treated as JSONL and lines that are not valid JSON are dropped.
    """
    try:
        data = json.loads(content)
    except json.JSONDecodeError:
        pass
    else:
        items = _extract_items(data)
        if items is not None:
            return [json.dumps(item, ensure_ascii=False) for item in items]

    # Split on "\n" only: str.splitlines() would also split on U+2028 etc.,
    # which may legally appear unescaped inside JSON strings.
    candidates = [line.strip() for line in content.split("\n") if line.strip()]
    valid = [line for line in candidates if _is_valid_json(line)]
    dropped = len(candidates) - len(valid)
    if dropped:
        logger.warning(f"JSONL 转换丢弃了 {dropped} 行无效 JSON")
    return valid


def _convert_to_jsonl(file_path: Path, source_format: str | None = None) -> Path:
    """Normalize a JSON/JSONL (or, with ``source_format="csv"``, CSV) file to JSONL.

    The result is written to ``file_path`` with a ``.jsonl`` suffix; if that
    differs from ``file_path`` the source file is deleted.

    Raises:
        RuntimeError: if the file is not UTF-8 text (e.g. a binary archive).
    """
    jsonl_path = file_path.with_suffix(".jsonl")
    try:
        # utf-8-sig strips a leading BOM, which json.loads would otherwise reject.
        content = file_path.read_text(encoding="utf-8-sig").strip()
    except UnicodeDecodeError as e:
        raise RuntimeError(f"导出文件不是 UTF-8 文本,无法转换为 JSONL: {file_path.name}") from e

    if not content:
        lines: list[str] = []
    elif source_format == "csv":
        lines = _csv_text_to_lines(content)
    else:
        lines = _json_text_to_lines(content)

    jsonl_path.write_text("".join(line + "\n" for line in lines), encoding="utf-8")
    if jsonl_path != file_path and file_path.exists():
        file_path.unlink()
    return jsonl_path


def _extract_items(data: Any) -> list | None:
    """Extract the record list from parsed JSON.

    Returns a list as-is; for a dict, the first list found under a known
    container key, else ``[data]``; None for any other (scalar) value.
    """
    if isinstance(data, list):
        return data
    if isinstance(data, dict):
        for key in ("data", "items", "results", "records", "annotations", "samples"):
            if key in data and isinstance(data[key], list):
                return data[key]
        return [data]
    return None


def _is_valid_json(s: str) -> bool:
    """Return True if ``s`` parses as JSON."""
    try:
        json.loads(s)
        return True
    except json.JSONDecodeError:
        return False


def _count_records(file_path: Path, fmt: str) -> int:
    """Count records in a ``jsonl``/``json``/``csv`` file; 0 if missing, unreadable or unknown."""
    if not file_path.exists():
        return 0
    try:
        if fmt == "jsonl":
            with open(file_path, "r", encoding="utf-8") as f:
                return sum(1 for line in f if line.strip())
        if fmt == "json":
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
            return len(data) if isinstance(data, list) else 1
        if fmt == "csv":
            # newline="" lets the csv module handle quoted embedded newlines.
            with open(file_path, "r", encoding="utf-8-sig", newline="") as f:
                return sum(1 for _ in csv.DictReader(f))
    except Exception as e:
        logger.warning(f"统计记录数失败: {file_path}: {e}")
        return 0
    return 0