| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269 |
- """标注平台 API 客户端服务。
- 对接标注平台的对外 API,支持 HMAC-SHA256 签名认证。
- 功能:列出项目、获取项目详情、导出并下载数据集。
- """
- import hashlib
- import hmac
- import secrets
- import time
- import uuid
- from datetime import datetime
- from pathlib import Path
- from typing import Any
- import httpx
- from app.config import get_settings
- from app.core.db import async_session, DatasetRecord
- from app.core.logging import logger
- settings = get_settings()
- # In-memory token cache; get_token() stores "access_token", "expires_in" and "expires_at" in it.
- _token_cache: dict[str, Any] = {}
def _get_base_url() -> str:
    """Return the configured annotation-platform base URL without trailing slashes.

    Raises:
        ValueError: If ``settings.annotation_platform_base_url`` is empty or unset.
    """
    configured_url = settings.annotation_platform_base_url
    if configured_url:
        return configured_url.rstrip("/")
    raise ValueError("标注平台地址未配置,请检查 ANNOTATION_PLATFORM_BASE_URL 环境变量")
def _get_credentials() -> tuple[str, str]:
    """Return the configured ``(app_id, app_secret)`` pair.

    Raises:
        ValueError: If either ``settings.annotation_platform_app_id`` or
            ``settings.annotation_platform_app_secret`` is empty or unset.
    """
    app_id = settings.annotation_platform_app_id
    app_secret = settings.annotation_platform_app_secret
    if app_id and app_secret:
        return app_id, app_secret
    raise ValueError("标注平台凭证未配置,请检查 ANNOTATION_PLATFORM_APP_ID 和 ANNOTATION_PLATFORM_APP_SECRET")
def _sign(app_secret: str, app_id: str, timestamp: str, nonce: str) -> str:
    """Compute the HMAC-SHA256 request signature.

    The signed message is ``app_id``, ``timestamp`` and ``nonce`` concatenated
    in that order with no separator; ``app_secret`` is the HMAC key.

    Returns:
        The signature as a lowercase hexadecimal string.
    """
    key_bytes = app_secret.encode()
    payload = "".join((app_id, timestamp, nonce)).encode()
    mac = hmac.new(key_bytes, payload, digestmod=hashlib.sha256)
    return mac.hexdigest()
def _check_token_valid() -> bool:
    """Report whether the cached access token can still be used.

    Returns:
        ``True`` only when a non-empty token is cached and the current time is
        more than 300 seconds before its recorded ``expires_at``.
    """
    if _token_cache.get("access_token"):
        # Treat the token as expired 5 minutes early so it is refreshed in time.
        refresh_deadline = _token_cache.get("expires_at", 0) - 300
        return time.time() < refresh_deadline
    return False
async def get_token() -> str:
    """Return an access token, reusing the cached one while it is still valid.

    When the cache is stale, a signed POST is sent to
    ``/api/v1/open/auth/token`` and the returned token is stored in
    ``_token_cache`` together with its lifetime (``expires_in``, defaulting to
    7200 when the response omits it) and the computed ``expires_at`` timestamp.

    Raises:
        ValueError: If the credentials or the base URL are not configured.
        RuntimeError: If the response body's ``code`` is not ``0``.
    """
    if _check_token_valid():
        return _token_cache["access_token"]

    app_id, app_secret = _get_credentials()
    base_url = _get_base_url()

    timestamp = str(int(time.time()))
    nonce = secrets.token_hex(8)  # 16 hexadecimal characters
    signed_headers = {
        "X-Api-Key": app_id,
        "X-Signature": _sign(app_secret, app_id, timestamp, nonce),
        "X-Timestamp": timestamp,
        "X-Nonce": nonce,
    }
    token_url = f"{base_url}/api/v1/open/auth/token"

    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(token_url, headers=signed_headers)
        resp.raise_for_status()
        body = resp.json()

    if body.get("code") != 0:
        raise RuntimeError(f"获取标注平台 Token 失败: {body.get('message')}")

    payload = body["data"]
    token = payload["access_token"]
    ttl = payload.get("expires_in", 7200)
    _token_cache["access_token"] = token
    _token_cache["expires_in"] = ttl
    _token_cache["expires_at"] = time.time() + ttl
    return token
def _auth_headers() -> dict[str, str]:
    """Build the headers for an authenticated JSON request.

    The bearer token is read from ``_token_cache``; an empty string is used
    when no token has been cached.
    """
    access_token = _token_cache.get("access_token", "")
    headers: dict[str, str] = {"Authorization": f"Bearer {access_token}"}
    headers["Content-Type"] = "application/json"
    return headers
async def _request(method: str, path: str, **kwargs) -> dict[str, Any]:
    """Send an authenticated request to the annotation platform.

    A valid token is ensured via ``get_token()`` before the call, and the
    headers from ``_auth_headers()`` are attached to the request.

    Args:
        method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
        path: Request path appended to the configured base URL.
        **kwargs: Passed through to ``httpx.AsyncClient.request``.

    Returns:
        The ``data`` member of the response body, or an empty dict when it is
        missing or ``null``.

    Raises:
        RuntimeError: If the response body's ``code`` is not ``0``.
    """
    await get_token()
    url = f"{_get_base_url()}{path}"
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.request(
            method,
            url,
            headers=_auth_headers(),
            **kwargs,
        )
        resp.raise_for_status()
        body = resp.json()

    if body.get("code") != 0:
        raise RuntimeError(f"标注平台请求失败: {body.get('message')}")

    # ``dict.get("data", {})`` only covers a missing key; an explicit JSON
    # ``null`` would come back as None and break callers that call ``.get``.
    data = body.get("data")
    return {} if data is None else data
- # ---------- 项目列表 ----------
async def list_projects(
    page: int = 1,
    page_size: int = 20,
    name: str | None = None,
    project_type: str | None = None,
    status: str | None = None,
) -> dict[str, Any]:
    """Fetch one page of projects from the annotation platform.

    ``page`` and ``page_size`` are always sent. ``name``, ``project_type``
    (sent as the ``type`` query parameter) and ``status`` are sent only when
    they are truthy.
    """
    optional_filters = {"name": name, "type": project_type, "status": status}
    query: dict[str, Any] = {"page": page, "page_size": page_size}
    query.update({key: value for key, value in optional_filters.items() if value})
    return await _request("GET", "/api/v1/open/projects", params=query)
- # ---------- 项目详情 ----------
async def get_project_detail(project_id: str) -> dict[str, Any]:
    """Fetch the detail payload of the project identified by ``project_id``."""
    detail_path = f"/api/v1/open/projects/{project_id}"
    detail = await _request("GET", detail_path)
    return detail
- # ---------- 数据集导出与下载 ----------
async def import_project_dataset(
    project_id: str,
    project_name: str = "",
    format: str = "alpaca",
) -> dict[str, Any]:
    """Export a project's dataset, download it, and register it locally.

    Steps:
        1. POST the export request and read ``file_url`` from the response.
        2. GET the file and save it under ``settings.uploads_dir``.
        3. Insert a ``DatasetRecord`` row describing the saved file.

    Args:
        project_id: Identifier of the project on the annotation platform.
        project_name: Optional display name; when given it is prefixed to the
            saved file name.
        format: Export format sent to the platform.

    Returns:
        A dict with ``project_id``, ``project_name``, ``format``,
        ``total_exported``, ``dataset_id`` and ``dataset_name``.

    Raises:
        RuntimeError: If the platform does not return a download link.
    """
    # 1. Request the export.
    export_data = await _request(
        "POST",
        f"/api/v1/open/projects/{project_id}/datasets/download",
        json={"format": format, "completed_only": True},
    )
    # ``or`` also covers an explicit null / empty value, which the plain
    # ``.get(key, default)`` form would let through.
    file_url = export_data.get("file_url") or ""
    file_name = export_data.get("file_name") or f"{project_id}_{format}.json"
    total_exported = export_data.get("total_exported", 0)
    if not file_url:
        raise RuntimeError("标注平台未返回下载链接")

    # 2. Download the file.
    await get_token()
    base_url = _get_base_url()
    download_url = f"{base_url}{file_url}" if file_url.startswith("/") else file_url
    # NOTE(review): for an absolute file_url the bearer token is sent to
    # whatever host that URL names -- confirm export links always point at the
    # annotation platform itself.
    async with httpx.AsyncClient(timeout=120) as client:
        resp = await client.get(
            download_url,
            headers={"Authorization": f"Bearer {_token_cache.get('access_token', '')}"},
            follow_redirects=True,
        )
        resp.raise_for_status()
        file_content = resp.content

    # 3. Save into the uploads directory.
    upload_dir = settings.uploads_dir
    upload_dir.mkdir(parents=True, exist_ok=True)
    safe_name = f"{project_name}_{file_name}" if project_name else file_name
    # Replace every character that is not alphanumeric or one of "._-".
    safe_name = "".join(c if c.isalnum() or c in "._-" else "_" for c in safe_name)
    file_path = upload_dir / safe_name
    if file_path.exists():
        # Avoid overwriting an existing upload by prefixing a short random id.
        file_path = upload_dir / f"{uuid.uuid4().hex[:8]}_{safe_name}"
    file_path.write_bytes(file_content)

    # 4. Detect the format and count records.
    fmt = _detect_format(file_path.name)
    record_count = _count_records(file_path, fmt)

    # 5. Persist the record; remove the saved file if that fails so a file
    #    without a matching database row is not left behind.
    record_id = str(uuid.uuid4())
    try:
        record = DatasetRecord(
            id=record_id,
            name=file_path.name,
            format=fmt,
            record_count=record_count,
            file_path=str(file_path),
            created_at=datetime.utcnow(),
        )
        async with async_session() as session:
            session.add(record)
            await session.commit()
    except Exception:
        file_path.unlink(missing_ok=True)
        raise

    logger.info(f"Imported dataset from annotation platform: {project_id} -> {file_path.name} ({record_count} records)")
    return {
        "project_id": project_id,
        "project_name": project_name or project_id,
        "format": format,
        "total_exported": total_exported,
        "dataset_id": record_id,
        "dataset_name": file_path.name,
    }
def _detect_format(filename: str) -> str:
    """Infer the dataset format from a file name's extension.

    Matching is case-insensitive. ``.jsonl``, ``.csv`` and ``.parquet`` map to
    their own names; every other name, including ``.json``, maps to ``"json"``.
    """
    lowered = filename.lower()
    for suffix in (".jsonl", ".csv", ".parquet"):
        if lowered.endswith(suffix):
            return suffix[1:]
    return "json"
def _count_records(file_path: Path, fmt: str) -> int:
    """Count the records stored in a dataset file (best effort).

    Args:
        file_path: Location of the dataset file.
        fmt: One of ``"jsonl"``, ``"json"`` or ``"csv"``; any other value
            yields ``0``.

    Returns:
        - ``jsonl``: number of non-blank lines.
        - ``json``: length of the top-level list, or ``1`` for any other
          top-level value.
        - ``csv``: number of data rows (the header row is not counted).
        - ``0`` when the file is missing, unreadable, or the format is not
          handled.
    """
    import csv
    import json

    if not file_path.exists():
        return 0
    try:
        # "utf-8-sig" reads plain UTF-8 and also drops a leading BOM, which
        # would otherwise make json.load fail and report 0 records.
        if fmt == "jsonl":
            with open(file_path, "r", encoding="utf-8-sig") as f:
                return sum(1 for line in f if line.strip())
        if fmt == "json":
            with open(file_path, "r", encoding="utf-8-sig") as f:
                data = json.load(f)
            return len(data) if isinstance(data, list) else 1
        if fmt == "csv":
            # The csv module expects files opened with newline="".
            with open(file_path, "r", encoding="utf-8-sig", newline="") as f:
                return sum(1 for _ in csv.DictReader(f))
    except Exception:
        # Counting is informational only; an unparsable file counts as 0.
        return 0
    return 0
|