| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437 |
- """标注平台 API 客户端服务。
- 对接标注平台的对外 API(HMAC-SHA256 签名认证)。
- 参考文档:标注平台对外API接口文档.md
- 功能:列出项目、获取项目详情、数据集导出与下载。
- """
- import asyncio
- import hashlib
- import hmac
- import json
- import time
- import uuid
- from datetime import datetime
- from pathlib import Path
- from typing import Any
- import httpx
- from app.config import get_settings
- from app.core.db import async_session, DatasetRecord
- from app.core.logging import logger
settings = get_settings()

# Process-local cache for the platform access token. Once a token has been
# obtained it holds "access_token" and "expires_at" (Unix time, seconds).
_token_cache: dict[str, Any] = dict()
def _get_base_url() -> str:
    """Return the configured annotation-platform base URL without trailing slashes.

    Raises:
        ValueError: if ANNOTATION_PLATFORM_BASE_URL is not configured (empty/None).
    """
    configured = settings.annotation_platform_base_url
    if configured:
        return configured.rstrip("/")
    raise ValueError("标注平台地址未配置,请检查 ANNOTATION_PLATFORM_BASE_URL")
def _get_credentials() -> tuple[str, str]:
    """Return ``(app_id, app_secret)`` for the annotation platform from settings.

    Raises:
        ValueError: if either the app id or the app secret is missing/empty.
    """
    credentials = (
        settings.annotation_platform_app_id,
        settings.annotation_platform_app_secret,
    )
    if not all(credentials):
        raise ValueError("标注平台凭证未配置,请检查 ANNOTATION_PLATFORM_APP_ID 和 ANNOTATION_PLATFORM_APP_SECRET")
    return credentials
def _build_token_headers() -> dict[str, str]:
    """Build the HMAC-SHA256 signed headers for the token endpoint.

    The signature is the hex digest of HMAC-SHA256 keyed with the app secret
    over ``app_id + timestamp + nonce``. The timestamp is the current Unix time
    in whole seconds and the nonce is a random UUID4 in hex form.

    Raises:
        ValueError: propagated from ``_get_credentials`` when credentials are missing.
    """
    key_id, secret = _get_credentials()
    ts = str(int(time.time()))
    nonce = uuid.uuid4().hex
    signed_payload = key_id + ts + nonce
    digest = hmac.new(
        secret.encode("utf-8"),
        signed_payload.encode("utf-8"),
        hashlib.sha256,
    ).hexdigest()

    headers = {"Content-Type": "application/json"}
    headers.update(
        {
            "X-Api-Key": key_id,
            "X-Timestamp": ts,
            "X-Nonce": nonce,
            "X-Signature": digest,
        }
    )
    return headers
def _is_token_valid() -> bool:
    """Return True when a cached token exists and expires more than 5 minutes from now.

    The 5-minute margin makes callers renew the token early instead of using
    one that is about to expire.
    """
    refresh_margin = 300  # seconds
    if not _token_cache.get("access_token"):
        return False
    return _token_cache.get("expires_at", 0) - refresh_margin > time.time()
async def _refresh_token() -> str:
    """Exchange the cached access token for a new one via ``/api/v1/open/auth/refresh``.

    The currently cached token is sent as a Bearer credential; an empty string
    is used when nothing is cached. On success the new token and its expiry are
    written to the cache. The expiry is now plus ``expires_in`` seconds,
    defaulting to 7200.

    Returns:
        The new access token.

    Raises:
        httpx.HTTPStatusError: the refresh endpoint answered 4xx/5xx.
        RuntimeError: the response envelope's ``code`` is not 0.
    """
    current = _token_cache.get("access_token", "")
    refresh_url = f"{_get_base_url()}/api/v1/open/auth/refresh"
    async with httpx.AsyncClient(timeout=30) as http:
        response = await http.post(
            refresh_url,
            headers={"Authorization": f"Bearer {current}"},
        )
    response.raise_for_status()

    payload = response.json()
    if payload.get("code") != 0:
        raise RuntimeError(f"刷新标注平台 Token 失败: {payload.get('message', payload)}")

    token_info = payload.get("data", {})
    fresh_token = token_info["access_token"]
    _token_cache["access_token"] = fresh_token
    _token_cache["expires_at"] = time.time() + token_info.get("expires_in", 7200)
    return fresh_token
async def get_token() -> str:
    """Return a usable access token, refreshing or re-issuing it when necessary.

    Resolution order:
      1. The cached token is still valid (see ``_is_token_valid``): return it.
      2. A cached token exists: try ``_refresh_token``. On any error, log a
         warning and fall through.
      3. Request a brand-new token from ``/api/v1/open/auth/token`` using the
         HMAC-signed headers, then cache it. The expiry is now plus
         ``expires_in`` seconds, defaulting to 7200.

    Raises:
        ValueError: base URL or credentials are not configured.
        httpx.HTTPStatusError: the token endpoint answered 4xx/5xx.
        RuntimeError: the token endpoint's ``code`` is not 0.
    """
    if _is_token_valid():
        return _token_cache["access_token"]

    if _token_cache.get("access_token"):
        try:
            return await _refresh_token()
        except Exception as exc:
            logger.warning(f"Token 刷新失败,回退到重新获取: {exc}")

    signed_headers = _build_token_headers()
    token_url = f"{_get_base_url()}/api/v1/open/auth/token"
    async with httpx.AsyncClient(timeout=30) as http:
        response = await http.post(token_url, json={}, headers=signed_headers)
    response.raise_for_status()

    payload = response.json()
    if payload.get("code") != 0:
        raise RuntimeError(f"获取标注平台 Token 失败: {payload.get('message', payload)}")

    token_info = payload.get("data", {})
    issued_token = token_info["access_token"]
    _token_cache["access_token"] = issued_token
    _token_cache["expires_at"] = time.time() + token_info.get("expires_in", 7200)
    return issued_token
def _auth_headers() -> dict[str, str]:
    """Return headers for business endpoints.

    The cached access token is sent as a Bearer credential (an empty token if
    none is cached), together with a JSON content type.
    """
    bearer = _token_cache.get("access_token", "")
    headers = {"Authorization": f"Bearer {bearer}"}
    headers["Content-Type"] = "application/json"
    return headers
async def _request(
    method: str, path: str, *, raise_on_error: bool = True, **kwargs
) -> dict[str, Any]:
    """Send an authenticated request to the platform's business API.

    A valid token is ensured first via ``get_token`` and sent via ``_auth_headers``.
    Successful responses use the envelope ``{"code": 0, "data": ...}``, and the
    ``data`` payload is returned. A missing or null payload is returned as ``{}``.

    Args:
        method: HTTP method, e.g. "GET" or "POST".
        path: Path appended to the platform base URL.
        raise_on_error: When False, an HTTP 4xx/5xx response does not raise.
            Instead the parsed error body is returned with ``_error=True`` and
            ``_status_code`` added. A non-object body is wrapped as
            ``{"message": body}``.
        **kwargs: Forwarded to ``httpx.AsyncClient.request`` (``params``, ``json``, ...).

    Raises:
        httpx.HTTPStatusError: HTTP 4xx/5xx when ``raise_on_error`` is True.
        RuntimeError: the body is not an envelope object or its ``code`` is not 0.
    """
    await get_token()
    base_url = _get_base_url()
    async with httpx.AsyncClient(timeout=60) as client:
        resp = await client.request(
            method,
            f"{base_url}{path}",
            headers=_auth_headers(),
            **kwargs,
        )

    if not raise_on_error and resp.status_code >= 400:
        try:
            body = resp.json()
        except ValueError:  # covers JSONDecodeError / UnicodeDecodeError
            body = {"message": resp.text}
        # Error bodies are not guaranteed to be JSON objects (could be a list or
        # string); wrap them so the merge below cannot raise TypeError.
        if not isinstance(body, dict):
            body = {"message": body}
        # Markers go last so a same-named field in the body cannot override them.
        return {**body, "_error": True, "_status_code": resp.status_code}

    resp.raise_for_status()
    body = resp.json()
    if not isinstance(body, dict) or body.get("code") != 0:
        message = body.get("message", body) if isinstance(body, dict) else body
        raise RuntimeError(f"标注平台请求失败: {message}")
    data = body.get("data")
    # "data": null must not leak out as None — callers immediately call .get().
    return {} if data is None else data
- # ---------- 项目列表 ----------
async def list_projects(
    page: int = 1,
    page_size: int = 20,
    name: str | None = None,
    project_type: str | None = None,
    status: str | None = None,
) -> dict[str, Any]:
    """List projects on the annotation platform (paginated).

    Optional filters are only sent when truthy. ``project_type`` is sent as the
    ``type`` query parameter. Returns the ``data`` payload of
    ``GET /api/v1/open/projects``.
    """
    optional_filters = {"name": name, "type": project_type, "status": status}
    query: dict[str, Any] = {"page": page, "page_size": page_size}
    query.update((key, value) for key, value in optional_filters.items() if value)
    return await _request("GET", "/api/v1/open/projects", params=query)
- # ---------- 项目详情 ----------
async def get_project_detail(project_id: str) -> dict[str, Any]:
    """Return the ``data`` payload of ``GET /api/v1/open/projects/{project_id}``."""
    endpoint = f"/api/v1/open/projects/{project_id}"
    detail = await _request("GET", endpoint)
    return detail
- # ---------- 数据集导出与下载 ----------
# Export formats supported for text projects and for image projects, listed in
# default trial order (per the annotation platform's API documentation).
_TEXT_FORMATS: list[str] = ["alpaca", "sharegpt"]
_IMAGE_FORMATS: list[str] = ["json", "csv", "coco"]
async def _request_export(project_id: str, fmt: str) -> dict[str, Any] | None:
    """Ask the platform to export *project_id* in *fmt*, completed items only.

    Returns the export info only when all three hold: the request succeeded,
    the reported status is "completed", and a ``file_url`` is present.
    Otherwise the reason is logged at info level and None is returned.
    """
    result = await _request(
        "POST",
        f"/api/v1/open/projects/{project_id}/datasets/download",
        raise_on_error=False,
        json={"format": fmt, "completed_only": True},
    )

    problem = None
    if result.get("_error"):
        problem = f"导出格式 '{fmt}' 不可用: {result.get('message', result.get('detail', ''))}"
    elif result.get("status") != "completed":
        problem = f"导出格式 '{fmt}' 状态异常: {result.get('status')}"
    elif not result.get("file_url"):
        problem = f"导出格式 '{fmt}' 未返回下载链接"

    if problem is not None:
        logger.info(problem)
        return None
    return result
async def _download_file(download_token: str, max_retries: int = 3) -> bytes:
    """Download an export file by its download token, retrying while it is (near-)empty.

    The platform may still be generating the file, so a body of 10 bytes or
    fewer is treated as "not ready yet". Such a response is retried with
    exponential backoff of 1 s, 2 s, 4 s, ... between attempts, up to
    *max_retries* attempts in total. HTTP 4xx/5xx responses raise immediately.

    Args:
        download_token: Token identifying the export file.
        max_retries: Total number of download attempts. A value <= 0 performs
            no request and returns ``b""``.

    Returns:
        The first body larger than 10 bytes. Otherwise the last body received,
        which may still be tiny, so callers must check its size.

    Raises:
        httpx.HTTPStatusError: the download endpoint answered 4xx/5xx.
    """
    await get_token()
    url = f"{_get_base_url()}/api/v1/open/datasets/downloads/{download_token}"
    # Initialised so that max_retries <= 0 returns b"" instead of raising
    # UnboundLocalError.
    content = b""
    # One client for every attempt, so pooled connections can be reused between retries.
    async with httpx.AsyncClient(timeout=120) as client:
        for attempt in range(max_retries):
            resp = await client.get(url, headers=_auth_headers(), follow_redirects=True)
            resp.raise_for_status()
            content = resp.content
            if len(content) > 10:
                return content
            if attempt < max_retries - 1:
                wait = 2 ** attempt  # 1, 2, 4, ... seconds
                logger.info(f"文件尚未就绪 ({len(content)} bytes),{wait}s 后重试...")
                await asyncio.sleep(wait)
    return content
def _extract_download_token(file_url: str) -> str:
    """Return the download token embedded in an export ``file_url``.

    If the URL contains ``/datasets/downloads/``, the text after its last
    occurrence is returned, with surrounding slashes removed. Otherwise the
    last path segment, ignoring trailing slashes, is returned.
    """
    _, marker, tail = file_url.rpartition("/datasets/downloads/")
    if marker:
        return tail.strip("/")
    return file_url.rstrip("/").rsplit("/", 1)[-1]
def _order_export_formats(project_type: str, preferred: str) -> list[str]:
    """Return the export formats to try for *project_type*, with *preferred* (if any) first."""
    defaults = _TEXT_FORMATS if project_type == "text" else _IMAGE_FORMATS
    ordered = [fmt for fmt in defaults if fmt != preferred]
    if preferred:
        ordered.insert(0, preferred)
    return ordered


def _csv_export_to_jsonl_bytes(content: bytes) -> bytes:
    """Convert CSV export bytes into UTF-8 JSONL bytes, one JSON object per data row."""
    import csv
    import io

    # utf-8-sig also accepts a leading BOM if one is present.
    reader = csv.DictReader(io.StringIO(content.decode("utf-8-sig"), newline=""))
    return "".join(
        json.dumps(row, ensure_ascii=False) + "\n" for row in reader
    ).encode("utf-8")


async def import_project_dataset(
    project_id: str,
    project_name: str = "",
    format: str = "alpaca",
) -> dict[str, Any]:
    """Export a project's dataset from the annotation platform, store it locally and register it.

    Steps:
      1. Look up the project's ``project_type``, falling back to "text" if the
         lookup fails or the type is empty.
      2. Try that type's export formats in order, with ``format`` always tried first.
      3. Download the first export whose file is larger than 10 bytes.
      4. Save it under ``settings.uploads_dir``. CSV exports are first converted
         to one JSON object per row.
      5. Normalise the file to JSONL, count its records and insert a ``DatasetRecord``.

    Args:
        project_id: Annotation-platform project id.
        project_name: Display name, also used (sanitised) as the file name.
            Defaults to ``project_id`` when empty.
        format: Preferred export format, tried before the type's defaults.

    Returns:
        dict with ``project_id``, ``project_name``, ``format`` ("jsonl"),
        ``total_exported``, ``dataset_id`` and ``dataset_name``.

    Raises:
        RuntimeError: no format produced a non-empty export file.
    """
    # 1. The project type decides which export formats are worth trying.
    try:
        detail = await get_project_detail(project_id)
        project_type = detail.get("project_type") or "text"
    except Exception as e:
        # Best effort: carry on as a text project, but leave a trace of why.
        logger.warning(f"获取项目 {project_id} 详情失败,按 text 类型处理: {e}")
        project_type = "text"

    # 2. Candidate formats; the caller's preferred format is always tried first,
    #    even when it is also one of the type's default formats.
    formats_to_try = _order_export_formats(project_type, format)

    # 3. Request an export and download it; the first non-empty file wins.
    file_content = b""
    total_exported = 0
    used_format = ""
    for fmt in formats_to_try:
        export_data = await _request_export(project_id, fmt)
        if not export_data:
            continue
        total_exported = export_data.get("total_exported", 0)
        download_token = _extract_download_token(export_data["file_url"])
        file_content = await _download_file(download_token)
        if len(file_content) > 10:
            used_format = fmt
            logger.info(
                f"标注平台导出成功: format={fmt}, {len(file_content)} bytes, "
                f"total_exported={total_exported}"
            )
            break
        logger.info(f"格式 '{fmt}' 导出文件为空 ({len(file_content)} bytes),尝试下一格式")

    if len(file_content) <= 10:
        raise RuntimeError(
            f"标注平台所有导出格式均返回空文件(project_type={project_type},"
            f"尝试格式: {formats_to_try}),请检查标注平台该项目的数据是否支持导出"
        )

    # CSV lines are not valid JSON and would be discarded as invalid lines by
    # the JSONL normalisation; turn each row into a JSON object first.
    if used_format == "csv":
        file_content = _csv_export_to_jsonl_bytes(file_content)

    # 4. Save under the uploads directory, with a file name derived from the project.
    upload_dir = settings.uploads_dir
    upload_dir.mkdir(parents=True, exist_ok=True)
    safe_name = "".join(
        c if c.isalnum() or c in "._-" else "_" for c in (project_name or project_id)
    )
    file_path = upload_dir / f"{safe_name}.jsonl"
    if file_path.exists():
        # Avoid overwriting an existing file: add a short random prefix.
        file_path = upload_dir / f"{uuid.uuid4().hex[:8]}_{safe_name}.jsonl"
    file_path.write_bytes(file_content)

    # 5. Normalise to JSONL and count records.
    jsonl_path = _convert_to_jsonl(file_path)
    record_count = _count_records(jsonl_path, "jsonl")

    # 6. Register the dataset.
    display_name = project_name or project_id
    record_id = str(uuid.uuid4())
    record = DatasetRecord(
        id=record_id,
        name=display_name,
        format="jsonl",
        record_count=record_count,
        file_path=str(jsonl_path),
        created_at=datetime.utcnow(),
    )
    async with async_session() as session:
        session.add(record)
        await session.commit()

    logger.info(
        f"Imported annotation dataset: {display_name} "
        f"({record_count} records, format={used_format})"
    )
    return {
        "project_id": project_id,
        "project_name": display_name,
        "format": "jsonl",
        "total_exported": total_exported,
        "dataset_id": record_id,
        "dataset_name": jsonl_path.name,
    }
def _convert_to_jsonl(file_path: Path) -> Path:
    """Normalise a JSON or JSONL file into JSONL (one JSON value per line).

    The output path is *file_path* with its suffix replaced by ``.jsonl``. This
    is the same file when it already has that suffix. A differently named
    source file is deleted once the output has been written.

    - A single JSON document is parsed and the records found by
      ``_extract_items`` are written one per line.
    - Anything else is treated as JSONL. Blank lines and lines that are not
      valid JSON are dropped, and a warning reports how many were dropped.
    - Empty (or whitespace-only) input yields an empty output file.

    Returns:
        Path of the JSONL file.
    """
    jsonl_path = file_path.with_suffix(".jsonl")
    # utf-8-sig strips a leading BOM. Otherwise json.loads rejects the document
    # and the first JSONL line looks invalid.
    content = file_path.read_text(encoding="utf-8-sig").strip()

    lines: list[str] = []
    if content:
        items = None
        try:
            items = _extract_items(json.loads(content))
        except json.JSONDecodeError:
            pass  # not one JSON document; fall back to line-by-line JSONL below
        if items is not None:
            lines = [json.dumps(item, ensure_ascii=False) for item in items]
        else:
            # Split on "\n" only: str.splitlines() would also break on U+2028/U+2029,
            # which may legally appear unescaped inside JSON strings.
            candidates = [line.strip() for line in content.split("\n") if line.strip()]
            lines = [line for line in candidates if _is_valid_json(line)]
            dropped = len(candidates) - len(lines)
            if dropped:
                logger.warning(f"{file_path.name}: 跳过 {dropped} 行无效 JSON")

    jsonl_path.write_text("".join(line + "\n" for line in lines), encoding="utf-8")
    if jsonl_path != file_path and file_path.exists():
        file_path.unlink()
    return jsonl_path
- def _extract_items(data) -> list | None:
- """从 JSON 数据中提取记录列表。"""
- if isinstance(data, list):
- return data
- if isinstance(data, dict):
- for key in ("data", "items", "results", "records", "annotations", "samples"):
- if key in data and isinstance(data[key], list):
- return data[key]
- return [data]
- return None
def _is_valid_json(s: str) -> bool:
    """Return True if *s* parses as a single JSON value, False otherwise."""
    try:
        json.loads(s)
    except json.JSONDecodeError:
        return False
    return True
def _count_records(file_path: Path, fmt: str) -> int:
    """Return how many records *file_path* holds when read as *fmt*.

    - "jsonl": number of non-blank lines.
    - "json": length of a top-level list, or 1 for any other JSON value.
    - "csv": number of data rows (header excluded).

    A missing file, an unknown *fmt*, or any read/parse error yields 0.
    """
    if not file_path.exists():
        return 0

    def count_jsonl(handle) -> int:
        return sum(1 for raw in handle if raw.strip())

    def count_json(handle) -> int:
        payload = json.load(handle)
        return len(payload) if isinstance(payload, list) else 1

    def count_csv(handle) -> int:
        import csv
        return sum(1 for _ in csv.DictReader(handle))

    counter = {"jsonl": count_jsonl, "json": count_json, "csv": count_csv}.get(fmt)
    if counter is None:
        return 0
    try:
        with open(file_path, "r", encoding="utf-8") as handle:
            return counter(handle)
    except Exception:
        return 0
|