| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381 |
- """标注平台 API 客户端服务。
- 对接标注平台的对外 API(HMAC-SHA256 签名认证)。
- 参考文档:标注平台对外API接口文档.md
- 功能:列出项目、获取项目详情、数据集导出与下载。
- """
- import hashlib
- import hmac
- import secrets
- import time
- import uuid
- from datetime import datetime
- from pathlib import Path
- from typing import Any
- import httpx
- from app.config import get_settings
- from app.core.db import async_session, DatasetRecord
- from app.core.logging import logger
# Application settings, resolved once when this module is imported.
settings = get_settings()
# In-memory token cache (module-level dict). get_token() and _refresh_token()
# store the keys "access_token", "expires_in" and "expires_at" in it.
_token_cache: dict[str, Any] = {}
def _get_base_url() -> str:
    """Return the configured annotation-platform base URL without trailing slashes.

    Raises:
        ValueError: If ``settings.annotation_platform_base_url`` is empty or unset.
    """
    configured_url = settings.annotation_platform_base_url
    if configured_url:
        return configured_url.rstrip("/")
    raise ValueError("标注平台地址未配置,请检查 ANNOTATION_PLATFORM_BASE_URL")
def _get_credentials() -> tuple[str, str]:
    """Return the configured ``(app_id, app_secret)`` pair.

    Raises:
        ValueError: If either the app id or the app secret is empty or unset.
    """
    key_id = settings.annotation_platform_app_id
    key_secret = settings.annotation_platform_app_secret
    if key_id and key_secret:
        return key_id, key_secret
    raise ValueError("标注平台凭证未配置,请检查 ANNOTATION_PLATFORM_APP_ID 和 ANNOTATION_PLATFORM_APP_SECRET")
def _build_token_headers() -> dict[str, str]:
    """Build the signed request headers used to obtain an access token.

    The signature is the hex-encoded HMAC-SHA256, keyed with the app secret,
    of the concatenation ``app_id + timestamp + nonce``.

    Returns:
        A dict with ``Content-Type``, ``X-Api-Key``, ``X-Timestamp``,
        ``X-Nonce`` and ``X-Signature`` entries.
    """
    api_key, api_secret = _get_credentials()
    issued_at = str(int(time.time()))
    # A fresh uuid4 hex string (32 characters) is generated for every call.
    one_time_nonce = uuid.uuid4().hex
    signed_payload = "".join((api_key, issued_at, one_time_nonce)).encode("utf-8")
    digest = hmac.new(api_secret.encode("utf-8"), signed_payload, hashlib.sha256)

    headers: dict[str, str] = {"Content-Type": "application/json"}
    headers["X-Api-Key"] = api_key
    headers["X-Timestamp"] = issued_at
    headers["X-Nonce"] = one_time_nonce
    headers["X-Signature"] = digest.hexdigest()
    return headers
def _is_token_valid() -> bool:
    """Tell whether the cached token can still be used as is.

    A token counts as valid only while the current time is more than
    300 seconds (5 minutes) before its cached ``expires_at`` value.
    Returns False when no token is cached.
    """
    if _token_cache.get("access_token"):
        # Subtracting 300 s makes the token "expire" five minutes early.
        refresh_deadline = _token_cache.get("expires_at", 0) - 300
        return time.time() < refresh_deadline
    return False
async def _refresh_token() -> str:
    """Refresh the access token using the currently cached token.

    Sends ``POST /api/v1/open/auth/refresh`` with the cached token as a Bearer
    credential (no HMAC signing involved), stores the new token and its
    expiry in the module cache, and returns the new token.

    Raises:
        httpx.HTTPStatusError: If the response status is an error status.
        RuntimeError: If the response body's ``code`` is not 0.
        KeyError: If the response data has no ``access_token``.
    """
    stale_token = _token_cache.get("access_token", "")
    refresh_url = f"{_get_base_url()}/api/v1/open/auth/refresh"
    bearer_header = {"Authorization": f"Bearer {stale_token}"}

    async with httpx.AsyncClient(timeout=30) as client:
        response = await client.post(refresh_url, headers=bearer_header)
        response.raise_for_status()
        payload = response.json()

    if payload.get("code") != 0:
        raise RuntimeError(f"刷新标注平台 Token 失败: {payload.get('message', payload)}")

    token_data = payload.get("data", {})
    new_token = token_data["access_token"]
    # The lifetime defaults to 7200 seconds when the platform omits it.
    lifetime = token_data.get("expires_in", 7200)
    _token_cache["access_token"] = new_token
    _token_cache["expires_in"] = lifetime
    _token_cache["expires_at"] = time.time() + lifetime
    return new_token
async def get_token() -> str:
    """Return an access token, using the module cache where possible.

    Resolution order:
      1. The cached token is still valid -> return it unchanged.
      2. A token is cached but no longer counts as valid -> try the
         refresh endpoint.
      3. Nothing is cached, or the refresh attempt raised -> request a new
         token from ``POST /api/v1/open/auth/token`` with signed headers.

    Raises:
        httpx.HTTPStatusError: If the token endpoint returns an error status.
        RuntimeError: If the token response's ``code`` is not 0.
        KeyError: If the token response data has no ``access_token``.
    """
    if _is_token_valid():
        return _token_cache["access_token"]

    has_cached_token = bool(_token_cache.get("access_token"))
    if has_cached_token:
        try:
            return await _refresh_token()
        except Exception as exc:
            # Any refresh failure falls through to a full signed request.
            logger.warning(f"Token 刷新失败,回退到重新获取: {exc}")

    signed_headers = _build_token_headers()
    token_url = f"{_get_base_url()}/api/v1/open/auth/token"
    async with httpx.AsyncClient(timeout=30) as client:
        response = await client.post(token_url, json={}, headers=signed_headers)
        response.raise_for_status()
        payload = response.json()

    # The platform signals success with code == 0.
    if payload.get("code") != 0:
        raise RuntimeError(f"获取标注平台 Token 失败: {payload.get('message', payload)}")

    token_data = payload.get("data", {})
    fresh_token = token_data["access_token"]
    # The lifetime defaults to 7200 seconds when the platform omits it.
    lifetime = token_data.get("expires_in", 7200)
    _token_cache["access_token"] = fresh_token
    _token_cache["expires_in"] = lifetime
    _token_cache["expires_at"] = time.time() + lifetime
    return fresh_token
def _auth_headers() -> dict[str, str]:
    """Build the headers for business endpoints from the cached token.

    The Bearer value is whatever ``access_token`` the module cache holds,
    or an empty string when nothing is cached.
    """
    cached_token = _token_cache.get("access_token", "")
    headers: dict[str, str] = {"Authorization": f"Bearer {cached_token}"}
    headers["Content-Type"] = "application/json"
    return headers
async def _request(method: str, path: str, **kwargs) -> dict[str, Any]:
    """Send an authenticated request to a business endpoint.

    Ensures a token is available, issues the request with a 60-second
    timeout, and returns the ``data`` member of the response body.

    Args:
        method: HTTP method name.
        path: Path appended to the configured base URL.
        **kwargs: Forwarded unchanged to ``httpx.AsyncClient.request``.

    Raises:
        httpx.HTTPStatusError: If the response status is an error status.
        RuntimeError: If the response body's ``code`` is not 0.
    """
    # Called for its side effect: it fills the cache _auth_headers() reads.
    await get_token()
    target_url = f"{_get_base_url()}{path}"
    async with httpx.AsyncClient(timeout=60) as client:
        response = await client.request(
            method,
            target_url,
            headers=_auth_headers(),
            **kwargs,
        )
        response.raise_for_status()
        payload = response.json()

    if payload.get("code") == 0:
        return payload.get("data", {})
    raise RuntimeError(f"标注平台请求失败: {payload.get('message', payload)}")
- # ---------- 项目列表 ----------
async def list_projects(
    page: int = 1,
    page_size: int = 20,
    name: str | None = None,
    project_type: str | None = None,
    status: str | None = None,
) -> dict[str, Any]:
    """List projects on the annotation platform.

    Calls ``GET /api/v1/open/projects``. ``page`` and ``page_size`` are always
    sent; ``name``, ``type`` (from ``project_type``) and ``status`` are sent
    only when the corresponding argument is truthy.
    """
    optional_filters = {"name": name, "type": project_type, "status": status}
    query: dict[str, Any] = {
        "page": page,
        "page_size": page_size,
        **{key: value for key, value in optional_filters.items() if value},
    }
    return await _request("GET", "/api/v1/open/projects", params=query)
- # ---------- 项目详情 ----------
async def get_project_detail(project_id: str) -> dict[str, Any]:
    """Fetch one project's details via ``GET /api/v1/open/projects/{project_id}``."""
    detail_path = f"/api/v1/open/projects/{project_id}"
    detail = await _request("GET", detail_path)
    return detail
- # ---------- 数据集导出与下载 ----------
def _extract_download_token(file_url: str) -> str:
    """Return the download token carried at the end of *file_url*."""
    marker = "/datasets/downloads/"
    if marker in file_url:
        # Expected shape: /api/v1/open/datasets/downloads/dl_abc123
        return file_url.split(marker)[-1].strip("/")
    # Fallback: use the last path segment of whatever URL was returned.
    return file_url.rstrip("/").split("/")[-1]


def _sanitize_filename(name: str) -> str:
    """Replace every character that is not alphanumeric or one of ``._-`` with ``_``."""
    cleaned = "".join(ch if ch.isalnum() or ch in "._-" else "_" for ch in name)
    # A name made only of dots ("", ".", "..") would resolve to a directory.
    return cleaned if cleaned.strip(".") else "dataset"


def _unique_upload_path(upload_dir: Path, file_name: str) -> Path:
    """Pick a path in *upload_dir* that clashes with no existing raw or .jsonl file."""
    candidate = upload_dir / file_name
    # The raw download is later replaced by "<stem>.jsonl", so an earlier
    # import leaves only the .jsonl behind; both names must be checked or a
    # repeated import would overwrite the previous dataset file.
    if candidate.exists() or candidate.with_suffix(".jsonl").exists():
        candidate = upload_dir / f"{uuid.uuid4().hex[:8]}_{file_name}"
    return candidate


async def import_project_dataset(
    project_id: str,
    project_name: str = "",
    format: str = "alpaca",
) -> dict[str, Any]:
    """Export a project's dataset, download it, store it locally and record it.

    Steps:
      1. ``POST /api/v1/open/projects/{project_id}/datasets/download`` to
         obtain ``file_url`` (only completed items are requested).
      2. ``GET /api/v1/open/datasets/downloads/{download_token}`` to fetch
         the file content.
      3. Save the content under ``settings.uploads_dir``.
      4. Convert the file to JSONL and count its records.
      5. Insert a ``DatasetRecord`` row.

    Args:
        project_id: Platform project identifier.
        project_name: Optional name; when given it prefixes the stored file
            name and is echoed in the result.
        format: Export format sent to the platform (the parameter name
            shadows the builtin and is kept for caller compatibility).

    Returns:
        A dict with ``project_id``, ``project_name``, ``format`` (always
        ``"jsonl"``), ``total_exported``, ``dataset_id`` and ``dataset_name``.

    Raises:
        RuntimeError: If the platform returns no download link, or a link
            from which no download token can be extracted.
        httpx.HTTPStatusError: If the download request returns an error status.
    """
    # 1. Ask the platform to export the project.
    export_data = await _request(
        "POST",
        f"/api/v1/open/projects/{project_id}/datasets/download",
        json={"format": format, "completed_only": True},
    )
    file_url = export_data.get("file_url", "")
    if not file_url:
        raise RuntimeError("标注平台未返回下载链接")
    # `or` also covers a present-but-empty/null file_name from the platform.
    file_name = export_data.get("file_name") or f"{project_id}_{format}.json"
    total_exported = export_data.get("total_exported", 0)

    # 2. Extract the download token from file_url.
    download_token = _extract_download_token(file_url)
    if not download_token:
        raise RuntimeError(f"标注平台返回的下载链接无效: {file_url}")

    # 3. Fetch the file through the dedicated download endpoint.
    await get_token()
    base_url = _get_base_url()
    async with httpx.AsyncClient(timeout=120) as client:
        resp = await client.get(
            f"{base_url}/api/v1/open/datasets/downloads/{download_token}",
            headers=_auth_headers(),
            follow_redirects=True,
        )
        resp.raise_for_status()
        file_content = resp.content

    # 4. Save into the uploads directory under a sanitized, non-clashing name.
    upload_dir = settings.uploads_dir
    upload_dir.mkdir(parents=True, exist_ok=True)
    raw_name = f"{project_name}_{file_name}" if project_name else file_name
    file_path = _unique_upload_path(upload_dir, _sanitize_filename(raw_name))
    file_path.write_bytes(file_content)

    # 5. Normalize to JSONL, the single format stored in the database.
    jsonl_path = _convert_to_jsonl(file_path)
    record_count = _count_records(jsonl_path, "jsonl")

    # 6. Persist the dataset record.
    record_id = str(uuid.uuid4())
    record = DatasetRecord(
        id=record_id,
        name=jsonl_path.name,
        format="jsonl",
        record_count=record_count,
        file_path=str(jsonl_path),
        # NOTE(review): naive UTC timestamp kept as in the original; confirm
        # the column type before switching to a timezone-aware datetime.
        created_at=datetime.utcnow(),
    )
    async with async_session() as session:
        session.add(record)
        await session.commit()

    logger.info(f"Imported dataset from annotation platform: {project_id} -> {jsonl_path.name} ({record_count} records)")
    return {
        "project_id": project_id,
        "project_name": project_name or project_id,
        "format": "jsonl",
        "total_exported": total_exported,
        "dataset_id": record_id,
        "dataset_name": jsonl_path.name,
    }
def _convert_to_jsonl(file_path: Path) -> Path:
    """Rewrite a JSON or JSONL file as JSONL and return the resulting path.

    The output is ``file_path`` with its suffix replaced by ``.jsonl``:
      * a JSON array becomes one line per element;
      * a single JSON object (even pretty-printed over several lines)
        becomes one line;
      * anything else is treated as JSONL: blank lines and lines that are
        not valid JSON are dropped, valid lines are kept verbatim;
      * an empty input yields an empty output file.

    The source file is removed afterwards unless it is the output file
    itself (input already named ``*.jsonl``).

    Args:
        file_path: Existing UTF-8 text file to convert.

    Returns:
        Path of the JSONL file.
    """
    import json as _json

    jsonl_path = file_path.with_suffix(".jsonl")
    # utf-8-sig drops a leading BOM, which json.loads would otherwise reject.
    content = file_path.read_text(encoding="utf-8-sig").strip()

    records: list[str] | None = None
    if not content:
        records = []
    else:
        try:
            parsed = _json.loads(content)
        except _json.JSONDecodeError:
            parsed = None
        if isinstance(parsed, list):
            records = [_json.dumps(item, ensure_ascii=False) for item in parsed]
        elif isinstance(parsed, dict):
            # A lone object is a single record; validating it line by line
            # would discard a pretty-printed object entirely.
            records = [_json.dumps(parsed, ensure_ascii=False)]

    if records is None:
        # Not one JSON document: treat as JSONL and keep only valid lines.
        records = []
        for raw_line in content.split("\n"):
            line = raw_line.strip()
            if not line:
                continue
            try:
                _json.loads(line)
            except _json.JSONDecodeError:
                continue
            records.append(line)

    # The whole input was read above, so writing is safe even when the
    # output path is the input path.
    with open(jsonl_path, "w", encoding="utf-8") as out:
        out.writelines(f"{record}\n" for record in records)

    # Never delete the file that was just written: when the input is already
    # "*.jsonl" both paths refer to the same file.
    if not jsonl_path.samefile(file_path):
        file_path.unlink()
    return jsonl_path
def _detect_format(filename: str) -> str:
    """Infer a dataset format from a file name's extension.

    The match is case-insensitive. Returns ``"jsonl"``, ``"csv"`` or
    ``"parquet"`` for those extensions and ``"json"`` for everything else.
    """
    lowered = filename.lower()
    for extension in ("jsonl", "csv", "parquet"):
        if lowered.endswith(f".{extension}"):
            return extension
    # ".json" and every unrecognized name both map to "json".
    return "json"
def _count_records(file_path: Path, fmt: str) -> int:
    """Count the records stored in a file of the given format.

    * ``"jsonl"``: number of non-blank lines.
    * ``"json"``: length of the top-level list, or 1 for any other value.
    * ``"csv"``: number of rows yielded by ``csv.DictReader``.

    Returns 0 when the file does not exist, when the format is not one of
    the above, or when reading or parsing raises any exception.
    """
    import csv
    import json

    if not file_path.exists():
        return 0
    if fmt not in ("jsonl", "json", "csv"):
        return 0

    try:
        with open(file_path, "r", encoding="utf-8") as handle:
            if fmt == "jsonl":
                non_blank = 0
                for row in handle:
                    if row.strip():
                        non_blank += 1
                return non_blank
            if fmt == "json":
                document = json.load(handle)
                return len(document) if isinstance(document, list) else 1
            return sum(1 for _ in csv.DictReader(handle))
    except Exception:
        # Best-effort count: any read or parse failure is reported as 0.
        return 0
|