| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507 |
- """标注平台 API 客户端服务。
- 对接标注平台的对外 API(HMAC-SHA256 签名认证)。
- 参考文档:标注平台对外API接口文档.md
- 功能:列出项目、获取项目详情、数据集导出与下载。
- """
- import hashlib
- import hmac
- import secrets
- import time
- import uuid
- from datetime import datetime
- from pathlib import Path
- from typing import Any
- import httpx
- from app.config import get_settings
- from app.core.db import async_session, DatasetRecord
- from app.core.logging import logger
# Settings object resolved once, when this module is imported.
settings = get_settings()
# In-memory access-token cache. The token helpers in this module store the
# keys "access_token", "expires_in" and "expires_at" in it.
_token_cache: dict[str, Any] = {}
def _get_base_url() -> str:
    """Return the configured annotation-platform base URL without trailing slashes.

    Raises:
        ValueError: If ``settings.annotation_platform_base_url`` is empty or unset.
    """
    configured = settings.annotation_platform_base_url
    if configured:
        return configured.rstrip("/")
    raise ValueError("标注平台地址未配置,请检查 ANNOTATION_PLATFORM_BASE_URL")
def _get_credentials() -> tuple[str, str]:
    """Return the ``(app_id, app_secret)`` pair read from settings.

    Raises:
        ValueError: If either value is empty or unset.
    """
    credentials = (
        settings.annotation_platform_app_id,
        settings.annotation_platform_app_secret,
    )
    if not all(credentials):
        raise ValueError("标注平台凭证未配置,请检查 ANNOTATION_PLATFORM_APP_ID 和 ANNOTATION_PLATFORM_APP_SECRET")
    return credentials
def _build_token_headers() -> dict[str, str]:
    """Build the signed headers sent when requesting a new token.

    The signature is the hex HMAC-SHA256 digest of the concatenation
    ``app_id + timestamp + nonce``, keyed with the app secret.
    """
    app_id, app_secret = _get_credentials()
    timestamp = str(int(time.time()))
    # uuid4().hex yields a fresh 32-character hex string on every call.
    nonce = uuid.uuid4().hex
    payload = "".join((app_id, timestamp, nonce)).encode("utf-8")
    digest = hmac.new(app_secret.encode("utf-8"), payload, hashlib.sha256)
    signed_headers = {
        "Content-Type": "application/json",
        "X-Api-Key": app_id,
        "X-Timestamp": timestamp,
        "X-Nonce": nonce,
        "X-Signature": digest.hexdigest(),
    }
    return signed_headers
def _is_token_valid() -> bool:
    """Report whether the cached token can still be used as-is.

    The token is treated as expired 300 seconds (5 minutes) before its
    recorded ``expires_at`` so that callers renew it ahead of time.
    """
    if _token_cache.get("access_token"):
        renew_deadline = _token_cache.get("expires_at", 0) - 300
        return time.time() < renew_deadline
    return False
async def _refresh_token() -> str:
    """Exchange the cached bearer token for a new access token.

    Sends ``POST /api/v1/open/auth/refresh`` with the cached token in the
    ``Authorization`` header, then stores the new token, its ``expires_in``
    (7200 when the response omits it) and the computed ``expires_at`` in the
    module cache.

    Returns:
        The new access token.

    Raises:
        RuntimeError: If the response body's ``code`` is not 0.
    """
    current_token = _token_cache.get("access_token", "")
    endpoint = f"{_get_base_url()}/api/v1/open/auth/refresh"
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(
            endpoint,
            headers={"Authorization": f"Bearer {current_token}"},
        )
    resp.raise_for_status()

    payload = resp.json()
    if payload.get("code") != 0:
        raise RuntimeError(f"刷新标注平台 Token 失败: {payload.get('message', payload)}")

    token_data = payload.get("data", {})
    new_token = token_data["access_token"]
    lifetime = token_data.get("expires_in", 7200)
    _token_cache["access_token"] = new_token
    _token_cache["expires_in"] = lifetime
    _token_cache["expires_at"] = time.time() + lifetime
    return new_token
async def get_token() -> str:
    """Return an access token, using the cache and renewing when needed.

    Resolution order:
      1. The cached token is still valid: return it.
      2. A cached token exists but is stale: try ``/auth/refresh``.
      3. No cached token, or the refresh failed: request a new token from
         ``/auth/token`` with freshly signed headers.

    Raises:
        RuntimeError: If the token endpoint answers with a non-zero ``code``.
    """
    if _is_token_valid():
        return _token_cache["access_token"]

    has_stale_token = bool(_token_cache.get("access_token"))
    if has_stale_token:
        try:
            return await _refresh_token()
        except Exception as exc:
            # Any refresh failure falls through to a full signed request.
            logger.warning(f"Token 刷新失败,回退到重新获取: {exc}")

    signed_headers = _build_token_headers()
    endpoint = f"{_get_base_url()}/api/v1/open/auth/token"
    async with httpx.AsyncClient(timeout=30) as client:
        resp = await client.post(endpoint, json={}, headers=signed_headers)
    resp.raise_for_status()

    payload = resp.json()
    # The platform signals success with code == 0.
    if payload.get("code") != 0:
        raise RuntimeError(f"获取标注平台 Token 失败: {payload.get('message', payload)}")

    token_data = payload.get("data", {})
    new_token = token_data["access_token"]
    lifetime = token_data.get("expires_in", 7200)
    _token_cache["access_token"] = new_token
    _token_cache["expires_in"] = lifetime
    _token_cache["expires_at"] = time.time() + lifetime
    return new_token
def _auth_headers() -> dict[str, str]:
    """Return the bearer-auth JSON headers built from the cached token.

    When no token is cached the bearer value is an empty string.
    """
    bearer = _token_cache.get("access_token", "")
    return {"Authorization": f"Bearer {bearer}", "Content-Type": "application/json"}
async def _request(method: str, path: str, raise_on_error: bool = True, **kwargs) -> dict[str, Any]:
    """Send an authenticated request to a business endpoint.

    A token is obtained (or renewed) first, then the request is sent with the
    bearer headers. Extra keyword arguments are passed through to
    ``httpx.AsyncClient.request``.

    Args:
        method: HTTP method.
        path: Path appended to the configured base URL.
        raise_on_error: When False, an HTTP status >= 400 does not raise;
            instead a dict carrying ``_error=True`` and ``_status_code`` plus
            the response body's fields is returned.
        **kwargs: Forwarded to the HTTP client (``params``, ``json``, ...).

    Returns:
        The ``data`` field of the response body (``{}`` when absent), or the
        error dict described above.

    Raises:
        httpx.HTTPStatusError: On an error status when ``raise_on_error`` is True.
        RuntimeError: If the response body's ``code`` is not 0.
    """
    await get_token()
    base_url = _get_base_url()
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.request(
            method,
            f"{base_url}{path}",
            headers=_auth_headers(),
            **kwargs,
        )

    if not raise_on_error and resp.status_code >= 400:
        try:
            error_body = resp.json()
        except ValueError:
            # Covers JSON decode errors and undecodable bytes.
            error_body = None
        if not isinstance(error_body, dict):
            # Non-JSON, or JSON that is not an object (list, string, null):
            # it cannot be merged into the result, so report the raw text.
            error_body = {"status_code": resp.status_code, "text": resp.text}
        # Markers go last so fields in the body can never overwrite them.
        return {**error_body, "_error": True, "_status_code": resp.status_code}

    resp.raise_for_status()
    body = resp.json()
    if body.get("code") != 0:
        raise RuntimeError(f"标注平台请求失败: {body.get('message', body)}")
    return body.get("data", {})
# ---------- Project list ----------
async def list_projects(
    page: int = 1,
    page_size: int = 20,
    name: str | None = None,
    project_type: str | None = None,
    status: str | None = None,
) -> dict[str, Any]:
    """Fetch one page of annotation-platform projects.

    Calls ``GET /api/v1/open/projects``. ``name``, ``project_type`` (sent as
    ``type``) and ``status`` are only included in the query when truthy.
    """
    optional_filters = {"name": name, "type": project_type, "status": status}
    params: dict[str, Any] = {"page": page, "page_size": page_size}
    params.update({key: value for key, value in optional_filters.items() if value})
    return await _request("GET", "/api/v1/open/projects", params=params)
# ---------- Project detail ----------
async def get_project_detail(project_id: str) -> dict[str, Any]:
    """Fetch the detail record of a single project.

    Calls ``GET /api/v1/open/projects/{project_id}``.
    """
    detail_path = f"/api/v1/open/projects/{project_id}"
    detail = await _request("GET", detail_path)
    return detail
# ---------- Dataset export and download ----------
# Downloads of this many bytes or fewer are treated as an empty export.
_MIN_EXPORT_BYTES = 10


def _candidate_export_formats(project_type: str, preferred: str) -> list[str]:
    """Return the export formats to try, with the caller's preferred format first."""
    # Text projects: alpaca / sharegpt. Other (image) projects: json / csv / coco.
    if project_type == "text":
        candidates = ["alpaca", "sharegpt"]
    else:
        candidates = ["json", "csv", "coco"]
    if preferred:
        # Move (not just add) the requested format to the front, so asking for
        # e.g. "sharegpt" is honoured even though it is already a candidate.
        if preferred in candidates:
            candidates.remove(preferred)
        candidates.insert(0, preferred)
    return candidates


async def _download_export_file(download_token: str, try_format: str) -> bytes:
    """Download an exported file, retrying with backoff while the body is still tiny."""
    import asyncio

    await get_token()
    base_url = _get_base_url()
    download_url = f"{base_url}/api/v1/open/datasets/downloads/{download_token}"
    file_content = b""
    max_retries = 4
    for attempt in range(max_retries):
        async with httpx.AsyncClient(timeout=120) as client:
            resp = await client.get(
                download_url,
                headers=_auth_headers(),
                follow_redirects=False,
            )
            # Redirects are followed by hand (at most 5 hops) so that the auth
            # headers are re-sent on every hop.
            # NOTE(review): this forwards the bearer token to whatever host the
            # redirect points at - confirm redirect targets are always trusted.
            redirect_count = 0
            while resp.is_redirect and redirect_count < 5:
                redirect_url = resp.next_request.url
                logger.info(f"Download redirect to: {redirect_url}")
                resp = await client.get(
                    str(redirect_url),
                    headers=_auth_headers(),
                    follow_redirects=False,
                )
                redirect_count += 1
        resp.raise_for_status()
        file_content = resp.content
        if len(file_content) > _MIN_EXPORT_BYTES:
            break
        if attempt < max_retries - 1:
            # The file may still be generating; back off 1, 2, 4 seconds.
            wait = 2 ** attempt
            logger.info(
                f"Download attempt {attempt + 1}/{max_retries} (format={try_format}): "
                f"file too small ({len(file_content)} bytes), retrying in {wait}s..."
            )
            await asyncio.sleep(wait)
    logger.info(
        f"Downloaded (format={try_format}): {len(file_content)} bytes, "
        f"content_type={resp.headers.get('content-type', 'unknown')}"
    )
    return file_content


async def import_project_dataset(
    project_id: str,
    project_name: str = "",
    format: str = "alpaca",
) -> dict[str, Any]:
    """Export a project's dataset, save it locally and register it in the database.

    Steps:
      1. Query the project detail for ``task_type`` / ``project_type``
         (falls back to a text project if the lookup fails).
      2. For each candidate format, ``POST
         /api/v1/open/projects/{project_id}/datasets/download`` to get ``file_url``.
      3. ``GET /api/v1/open/datasets/downloads/{download_token}`` to fetch the file;
         the first format yielding more than 10 bytes wins.
      4. Save the file under the uploads directory and normalise it to JSONL.
      5. Insert a ``DatasetRecord`` row.

    Args:
        project_id: Platform project identifier.
        project_name: Display name; also used for the file name. Falls back to
            ``project_id`` when empty.
        format: Preferred export format; it is tried before the defaults.

    Returns:
        A dict with ``project_id``, ``project_name``, ``format`` (always
        ``"jsonl"``), ``total_exported``, ``dataset_id`` and ``dataset_name``.

    Raises:
        RuntimeError: If no format produced a non-empty file.
    """
    # 1. Project detail; a failure here is tolerated and defaults are used.
    try:
        project_detail = await get_project_detail(project_id)
        task_type = project_detail.get("task_type", "")
        project_type = project_detail.get("project_type", "text")
        logger.info(f"Project {project_id}: task_type={task_type}, project_type={project_type}")
    except Exception as e:
        logger.warning(f"Failed to get project detail: {e}, using default formats")
        task_type = ""
        project_type = "text"

    # 2. Candidate formats, caller's preference first.
    formats_to_try = _candidate_export_formats(project_type, format)

    file_content = b""
    total_exported = 0
    used_format = ""
    last_error = ""
    for try_format in formats_to_try:
        export_data = await _request(
            "POST",
            f"/api/v1/open/projects/{project_id}/datasets/download",
            raise_on_error=False,
            json={"format": try_format, "completed_only": True},
        )
        if export_data.get("_error"):
            status_code = export_data.get("_status_code", 0)
            error_msg = export_data.get("message", export_data)
            last_error = f"HTTP {status_code}: {error_msg}"
            logger.warning(f"Format '{try_format}' failed: {last_error}, trying next...")
            continue

        logger.info(f"Annotation export response (format={try_format}): {export_data}")
        file_url = export_data.get("file_url", "")
        total_exported = export_data.get("total_exported", 0)
        export_status = export_data.get("status", "completed")

        if export_status != "completed":
            logger.warning(f"Export status={export_status} for format={try_format}, trying next...")
            last_error = f"导出状态: {export_status}"
            continue
        if not file_url:
            logger.warning(f"No file_url for format={try_format}, trying next...")
            last_error = "未返回下载链接"
            continue

        # The download token is whatever follows /datasets/downloads/ in the
        # URL, or the last path segment when that marker is absent.
        if "/datasets/downloads/" in file_url:
            download_token = file_url.split("/datasets/downloads/")[-1].strip("/")
        else:
            download_token = file_url.rstrip("/").split("/")[-1]

        # 3. Download, polling while the platform is still generating the file.
        file_content = await _download_export_file(download_token, try_format)
        if len(file_content) > _MIN_EXPORT_BYTES:
            used_format = try_format
            break
        logger.warning(
            f"Format '{try_format}' returned empty file ({len(file_content)} bytes), "
            f"trying next format..."
        )
        last_error = f"格式 {try_format} 导出文件为空"

    if len(file_content) <= _MIN_EXPORT_BYTES:
        logger.warning(f"Annotation file content: {file_content!r}")
        raise RuntimeError(
            f"标注平台导出文件为空(task_type={task_type}, 尝试了格式: {formats_to_try}),"
            f"total_exported={total_exported}。最后错误: {last_error}"
        )

    # 4. Save under the uploads directory; characters other than alphanumerics
    # and "._-" are replaced so the name is safe as a single path component.
    upload_dir = settings.uploads_dir
    upload_dir.mkdir(parents=True, exist_ok=True)
    safe_name = "".join(c if c.isalnum() or c in "._-" else "_" for c in (project_name or project_id))
    file_path = upload_dir / f"{safe_name}.jsonl"
    if file_path.exists():
        # Avoid overwriting an earlier import with the same name.
        file_path = upload_dir / f"{uuid.uuid4().hex[:8]}_{safe_name}.jsonl"
    file_path.write_bytes(file_content)

    # Normalise whatever was downloaded to JSONL.
    jsonl_path = _convert_to_jsonl(file_path)
    record_count = _count_records(jsonl_path, "jsonl")
    logger.info(f"Annotation file converted: {jsonl_path.name}, record_count={record_count}, format={used_format}")

    # 5. Register the dataset (stored format is always jsonl).
    record_id = str(uuid.uuid4())
    record = DatasetRecord(
        id=record_id,
        name=project_name or project_id,
        format="jsonl",
        record_count=record_count,
        file_path=str(jsonl_path),
        created_at=datetime.utcnow(),
    )
    async with async_session() as session:
        session.add(record)
        await session.commit()

    logger.info(f"Imported dataset from annotation platform: {project_id} -> {jsonl_path.name} ({record_count} records)")
    return {
        "project_id": project_id,
        "project_name": project_name or project_id,
        "format": "jsonl",
        "total_exported": total_exported,
        "dataset_id": record_id,
        "dataset_name": jsonl_path.name,
    }
- def _convert_to_jsonl(file_path: Path) -> Path:
- """将 JSON/JSONL 文件统一转为 JSONL 格式。"""
- import json as _json
- jsonl_path = file_path.with_suffix(".jsonl")
- with open(file_path, "r", encoding="utf-8") as f:
- content = f.read().strip()
- if not content:
- jsonl_path.touch()
- return jsonl_path
- try:
- # 尝试作为 JSON 解析
- data = _json.loads(content)
- if isinstance(data, list):
- # JSON 数组
- items = data
- elif isinstance(data, dict):
- # JSON 对象:查找嵌套的数组字段
- items = None
- for key in ("data", "items", "results", "records", "annotations", "samples"):
- if key in data and isinstance(data[key], list):
- items = data[key]
- break
- if items is None:
- # 单个对象,包装为数组
- items = [data]
- else:
- items = None
- if items is not None:
- with open(jsonl_path, "w", encoding="utf-8") as out:
- for item in items:
- out.write(_json.dumps(item, ensure_ascii=False) + "\n")
- # 只有当原始文件与新文件不同时才删除(避免删除刚写入的文件)
- if jsonl_path != file_path and file_path.exists():
- file_path.unlink()
- return jsonl_path
- except _json.JSONDecodeError:
- pass
- # 不是标准 JSON,可能是 JSONL,逐行验证
- lines = content.split("\n")
- valid_lines = []
- for line in lines:
- line = line.strip()
- if line:
- try:
- _json.loads(line)
- valid_lines.append(line)
- except _json.JSONDecodeError:
- continue # 跳过无效行
- with open(jsonl_path, "w", encoding="utf-8") as out:
- out.write("\n".join(valid_lines) + ("\n" if valid_lines else ""))
- if jsonl_path != file_path and file_path.exists():
- file_path.unlink()
- return jsonl_path
- def _detect_format(filename: str) -> str:
- """根据文件名推断格式。"""
- name = filename.lower()
- if name.endswith(".jsonl"):
- return "jsonl"
- if name.endswith(".csv"):
- return "csv"
- if name.endswith(".parquet"):
- return "parquet"
- if name.endswith(".json"):
- return "json"
- return "json"
- def _count_records(file_path: Path, fmt: str) -> int:
- """计算文件中的记录数。"""
- import json
- if not file_path.exists():
- return 0
- try:
- if fmt == "jsonl":
- with open(file_path, "r", encoding="utf-8") as f:
- return sum(1 for line in f if line.strip())
- elif fmt == "json":
- with open(file_path, "r", encoding="utf-8") as f:
- data = json.load(f)
- if isinstance(data, list):
- return len(data)
- return 1
- elif fmt == "csv":
- import csv
- with open(file_path, "r", encoding="utf-8") as f:
- return sum(1 for _ in csv.DictReader(f))
- except Exception:
- return 0
- return 0
|