|
|
@@ -5,9 +5,10 @@
|
|
|
功能:列出项目、获取项目详情、数据集导出与下载。
|
|
|
"""
|
|
|
|
|
|
+import asyncio
|
|
|
import hashlib
|
|
|
import hmac
|
|
|
-import secrets
|
|
|
+import json
|
|
|
import time
|
|
|
import uuid
|
|
|
from datetime import datetime
|
|
|
@@ -45,7 +46,7 @@ def _build_token_headers() -> dict[str, str]:
|
|
|
"""构建获取 Token 的 HMAC-SHA256 签名请求头。"""
|
|
|
app_id, app_secret = _get_credentials()
|
|
|
timestamp = str(int(time.time()))
|
|
|
- nonce = uuid.uuid4().hex # 使用 uuid4 保证每次唯一,16+ 位随机字符串
|
|
|
+ nonce = uuid.uuid4().hex
|
|
|
message = app_id + timestamp + nonce
|
|
|
signature = hmac.new(
|
|
|
key=app_secret.encode("utf-8"),
|
|
|
@@ -71,11 +72,7 @@ def _is_token_valid() -> bool:
|
|
|
|
|
|
|
|
|
async def _refresh_token() -> str:
|
|
|
- """使用 Bearer Token 刷新 Access Token。
|
|
|
-
|
|
|
- POST /api/v1/open/auth/refresh
|
|
|
- 比重新签名更高效,仅在 Token 存在但即将过期时调用。
|
|
|
- """
|
|
|
+ """使用 Bearer Token 刷新 Access Token。"""
|
|
|
old_token = _token_cache.get("access_token", "")
|
|
|
base_url = _get_base_url()
|
|
|
|
|
|
@@ -92,20 +89,13 @@ async def _refresh_token() -> str:
|
|
|
|
|
|
data = body.get("data", {})
|
|
|
_token_cache["access_token"] = data["access_token"]
|
|
|
- _token_cache["expires_in"] = data.get("expires_in", 7200)
|
|
|
_token_cache["expires_at"] = time.time() + data.get("expires_in", 7200)
|
|
|
|
|
|
return data["access_token"]
|
|
|
|
|
|
|
|
|
async def get_token() -> str:
|
|
|
- """获取 Access Token,带缓存和自动刷新。
|
|
|
-
|
|
|
- 优先级:
|
|
|
- 1. Token 有效 → 直接返回
|
|
|
- 2. Token 存在但快过期 → 调用 /auth/refresh 刷新
|
|
|
- 3. 无 Token → 调用 /auth/token 重新签名获取
|
|
|
- """
|
|
|
+ """获取 Access Token,带缓存和自动刷新。"""
|
|
|
if _is_token_valid():
|
|
|
return _token_cache["access_token"]
|
|
|
|
|
|
@@ -129,13 +119,11 @@ async def get_token() -> str:
|
|
|
resp.raise_for_status()
|
|
|
body = resp.json()
|
|
|
|
|
|
- # 标注平台返回 code: 0 表示成功
|
|
|
if body.get("code") != 0:
|
|
|
raise RuntimeError(f"获取标注平台 Token 失败: {body.get('message', body)}")
|
|
|
|
|
|
data = body.get("data", {})
|
|
|
_token_cache["access_token"] = data["access_token"]
|
|
|
- _token_cache["expires_in"] = data.get("expires_in", 7200)
|
|
|
_token_cache["expires_at"] = time.time() + data.get("expires_in", 7200)
|
|
|
|
|
|
return data["access_token"]
|
|
|
@@ -149,10 +137,12 @@ def _auth_headers() -> dict[str, str]:
|
|
|
}
|
|
|
|
|
|
|
|
|
-async def _request(method: str, path: str, raise_on_error: bool = True, **kwargs) -> dict[str, Any]:
|
|
|
+async def _request(
|
|
|
+ method: str, path: str, *, raise_on_error: bool = True, **kwargs
|
|
|
+) -> dict[str, Any]:
|
|
|
"""统一的业务请求方法,自动携带 Token。
|
|
|
|
|
|
- raise_on_error=False 时,400 等错误不抛异常,返回原始响应体。
|
|
|
+ raise_on_error=False 时,HTTP 4xx/5xx 不抛异常,返回带 _error 标记的 dict。
|
|
|
"""
|
|
|
await get_token()
|
|
|
base_url = _get_base_url()
|
|
|
@@ -169,7 +159,7 @@ async def _request(method: str, path: str, raise_on_error: bool = True, **kwargs
|
|
|
try:
|
|
|
body = resp.json()
|
|
|
except Exception:
|
|
|
- body = {"status_code": resp.status_code, "text": resp.text}
|
|
|
+ body = {"message": resp.text}
|
|
|
return {"_error": True, "_status_code": resp.status_code, **body}
|
|
|
|
|
|
resp.raise_for_status()
|
|
|
@@ -190,10 +180,7 @@ async def list_projects(
|
|
|
project_type: str | None = None,
|
|
|
status: str | None = None,
|
|
|
) -> dict[str, Any]:
|
|
|
- """获取标注平台项目列表。
|
|
|
-
|
|
|
- GET /api/v1/open/projects
|
|
|
- """
|
|
|
+ """获取标注平台项目列表。"""
|
|
|
params: dict[str, Any] = {"page": page, "page_size": page_size}
|
|
|
if name:
|
|
|
params["name"] = name
|
|
|
@@ -208,15 +195,71 @@ async def list_projects(
|
|
|
# ---------- 项目详情 ----------
|
|
|
|
|
|
async def get_project_detail(project_id: str) -> dict[str, Any]:
|
|
|
- """获取项目详情。
|
|
|
-
|
|
|
- GET /api/v1/open/projects/{project_id}
|
|
|
- """
|
|
|
+ """获取项目详情。"""
|
|
|
return await _request("GET", f"/api/v1/open/projects/{project_id}")
|
|
|
|
|
|
|
|
|
# ---------- 数据集导出与下载 ----------
|
|
|
|
|
|
+# 文本项目和图片项目支持的导出格式(参考标注平台 API 文档)
|
|
|
+_TEXT_FORMATS = ["alpaca", "sharegpt"]
|
|
|
+_IMAGE_FORMATS = ["json", "csv", "coco"]
|
|
|
+
|
|
|
+
|
|
|
+async def _request_export(project_id: str, fmt: str) -> dict[str, Any] | None:
|
|
|
+ """请求导出并返回导出信息,格式不兼容或失败时返回 None。"""
|
|
|
+ data = await _request(
|
|
|
+ "POST",
|
|
|
+ f"/api/v1/open/projects/{project_id}/datasets/download",
|
|
|
+ raise_on_error=False,
|
|
|
+ json={"format": fmt, "completed_only": True},
|
|
|
+ )
|
|
|
+
|
|
|
+ if data.get("_error"):
|
|
|
+ logger.info(f"导出格式 '{fmt}' 不可用: {data.get('message', data.get('detail', ''))}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ if data.get("status") != "completed":
|
|
|
+ logger.info(f"导出格式 '{fmt}' 状态异常: {data.get('status')}")
|
|
|
+ return None
|
|
|
+
|
|
|
+ if not data.get("file_url"):
|
|
|
+ logger.info(f"导出格式 '{fmt}' 未返回下载链接")
|
|
|
+ return None
|
|
|
+
|
|
|
+ return data
|
|
|
+
|
|
|
+
|
|
|
+async def _download_file(download_token: str, max_retries: int = 3) -> bytes:
|
|
|
+ """通过 download_token 下载导出文件,带重试(标注平台可能需要时间生成文件)。"""
|
|
|
+ await get_token()
|
|
|
+ base_url = _get_base_url()
|
|
|
+ url = f"{base_url}/api/v1/open/datasets/downloads/{download_token}"
|
|
|
+
|
|
|
+ for attempt in range(max_retries):
|
|
|
+ async with httpx.AsyncClient(timeout=120) as client:
|
|
|
+ resp = await client.get(url, headers=_auth_headers(), follow_redirects=True)
|
|
|
+ resp.raise_for_status()
|
|
|
+ content = resp.content
|
|
|
+
|
|
|
+ if len(content) > 10:
|
|
|
+ return content
|
|
|
+
|
|
|
+ if attempt < max_retries - 1:
|
|
|
+ wait = 2 ** attempt # 1, 2 秒
|
|
|
+ logger.info(f"文件尚未就绪 ({len(content)} bytes),{wait}s 后重试...")
|
|
|
+ await asyncio.sleep(wait)
|
|
|
+
|
|
|
+ return content
|
|
|
+
|
|
|
+
|
|
|
+def _extract_download_token(file_url: str) -> str:
|
|
|
+ """从 file_url 中提取 download_token。"""
|
|
|
+ if "/datasets/downloads/" in file_url:
|
|
|
+ return file_url.split("/datasets/downloads/")[-1].strip("/")
|
|
|
+ return file_url.rstrip("/").split("/")[-1]
|
|
|
+
|
|
|
+
|
|
|
async def import_project_dataset(
|
|
|
project_id: str,
|
|
|
project_name: str = "",
|
|
|
@@ -225,144 +268,58 @@ async def import_project_dataset(
|
|
|
"""导出并下载项目数据集,保存到本地并写入数据库。
|
|
|
|
|
|
流程:
|
|
|
- 1. 查询项目详情获取 task_type
|
|
|
- 2. POST /api/v1/open/projects/{project_id}/datasets/download → 获取 file_url
|
|
|
- 3. GET /api/v1/open/datasets/downloads/{download_token} → 下载文件
|
|
|
- 4. 保存到 uploads 目录
|
|
|
- 5. 写入 DatasetRecord 数据库
|
|
|
+ 1. 查询项目详情获取 project_type
|
|
|
+ 2. 根据项目类型选择合适的导出格式,依次尝试
|
|
|
+ 3. 下载文件并转换为 JSONL
|
|
|
+ 4. 写入数据库
|
|
|
"""
|
|
|
- # 1. 查询项目详情,获取 task_type 和 project_type
|
|
|
+ # 1. 查询项目类型,决定可用格式
|
|
|
try:
|
|
|
- project_detail = await get_project_detail(project_id)
|
|
|
- task_type = project_detail.get("task_type", "")
|
|
|
- project_type = project_detail.get("project_type", "text")
|
|
|
- logger.info(f"Project {project_id}: task_type={task_type}, project_type={project_type}")
|
|
|
- except Exception as e:
|
|
|
- logger.warning(f"Failed to get project detail: {e}, using default formats")
|
|
|
- task_type = ""
|
|
|
+ detail = await get_project_detail(project_id)
|
|
|
+ project_type = detail.get("project_type", "text")
|
|
|
+ except Exception:
|
|
|
project_type = "text"
|
|
|
|
|
|
- # 2. 根据项目类型选择导出格式
|
|
|
- # 文本项目: alpaca / sharegpt
|
|
|
- # 图片项目: json / csv / coco / yolo / pascal_voc
|
|
|
+ # 2. 构建格式尝试列表(用户指定的格式优先)
|
|
|
if project_type == "text":
|
|
|
- formats_to_try = ["alpaca", "sharegpt"]
|
|
|
+ formats_to_try = list(_TEXT_FORMATS)
|
|
|
else:
|
|
|
- formats_to_try = ["json", "csv", "coco"]
|
|
|
+ formats_to_try = list(_IMAGE_FORMATS)
|
|
|
|
|
|
- # 如果用户指定了格式,优先使用
|
|
|
if format and format not in formats_to_try:
|
|
|
formats_to_try.insert(0, format)
|
|
|
|
|
|
+ # 3. 依次尝试各格式:请求导出 → 下载文件
|
|
|
file_content = b""
|
|
|
- file_name = ""
|
|
|
total_exported = 0
|
|
|
used_format = ""
|
|
|
- last_error = ""
|
|
|
-
|
|
|
- for try_format in formats_to_try:
|
|
|
- # 请求导出
|
|
|
- export_data = await _request(
|
|
|
- "POST",
|
|
|
- f"/api/v1/open/projects/{project_id}/datasets/download",
|
|
|
- raise_on_error=False,
|
|
|
- json={"format": try_format, "completed_only": True},
|
|
|
- )
|
|
|
|
|
|
- # 检查是否返回错误
|
|
|
- if export_data.get("_error"):
|
|
|
- status_code = export_data.get("_status_code", 0)
|
|
|
- error_msg = export_data.get("message", export_data)
|
|
|
- last_error = f"HTTP {status_code}: {error_msg}"
|
|
|
- logger.warning(f"Format '{try_format}' failed: {last_error}, trying next...")
|
|
|
+ for fmt in formats_to_try:
|
|
|
+ export_data = await _request_export(project_id, fmt)
|
|
|
+ if not export_data:
|
|
|
continue
|
|
|
|
|
|
- logger.info(f"Annotation export response (format={try_format}): {export_data}")
|
|
|
-
|
|
|
- file_url = export_data.get("file_url", "")
|
|
|
- file_name = export_data.get("file_name", f"{project_id}_{try_format}.json")
|
|
|
total_exported = export_data.get("total_exported", 0)
|
|
|
- export_status = export_data.get("status", "completed")
|
|
|
-
|
|
|
- # 检查导出状态
|
|
|
- if export_status != "completed":
|
|
|
- logger.warning(f"Export status={export_status} for format={try_format}, trying next...")
|
|
|
- last_error = f"导出状态: {export_status}"
|
|
|
- continue
|
|
|
-
|
|
|
- if not file_url:
|
|
|
- logger.warning(f"No file_url for format={try_format}, trying next...")
|
|
|
- last_error = "未返回下载链接"
|
|
|
- continue
|
|
|
-
|
|
|
- # 从 file_url 中提取 download_token
|
|
|
- if "/datasets/downloads/" in file_url:
|
|
|
- download_token = file_url.split("/datasets/downloads/")[-1].strip("/")
|
|
|
- else:
|
|
|
- download_token = file_url.rstrip("/").split("/")[-1]
|
|
|
-
|
|
|
- # 下载文件,带轮询(标注平台生成文件可能需要时间)
|
|
|
- await get_token()
|
|
|
- base_url = _get_base_url()
|
|
|
- download_url = f"{base_url}/api/v1/open/datasets/downloads/{download_token}"
|
|
|
-
|
|
|
- file_content = b""
|
|
|
- max_retries = 4
|
|
|
- for attempt in range(max_retries):
|
|
|
- async with httpx.AsyncClient(timeout=120) as client:
|
|
|
- resp = await client.get(
|
|
|
- download_url,
|
|
|
- headers=_auth_headers(),
|
|
|
- follow_redirects=False,
|
|
|
- )
|
|
|
- redirect_count = 0
|
|
|
- while resp.is_redirect and redirect_count < 5:
|
|
|
- redirect_url = resp.next_request.url
|
|
|
- logger.info(f"Download redirect to: {redirect_url}")
|
|
|
- resp = await client.get(
|
|
|
- str(redirect_url),
|
|
|
- headers=_auth_headers(),
|
|
|
- follow_redirects=False,
|
|
|
- )
|
|
|
- redirect_count += 1
|
|
|
- resp.raise_for_status()
|
|
|
- file_content = resp.content
|
|
|
-
|
|
|
- if len(file_content) > 10:
|
|
|
- break
|
|
|
-
|
|
|
- if attempt < max_retries - 1:
|
|
|
- import asyncio
|
|
|
- wait = 2 ** attempt # 1, 2, 4 秒
|
|
|
- logger.info(
|
|
|
- f"Download attempt {attempt + 1}/{max_retries} (format={try_format}): "
|
|
|
- f"file too small ({len(file_content)} bytes), retrying in {wait}s..."
|
|
|
- )
|
|
|
- await asyncio.sleep(wait)
|
|
|
-
|
|
|
- logger.info(
|
|
|
- f"Downloaded (format={try_format}): {len(file_content)} bytes, "
|
|
|
- f"content_type={resp.headers.get('content-type', 'unknown')}"
|
|
|
- )
|
|
|
+ download_token = _extract_download_token(export_data["file_url"])
|
|
|
+ file_content = await _download_file(download_token)
|
|
|
|
|
|
if len(file_content) > 10:
|
|
|
- used_format = try_format
|
|
|
+ used_format = fmt
|
|
|
+ logger.info(
|
|
|
+ f"标注平台导出成功: format={fmt}, {len(file_content)} bytes, "
|
|
|
+ f"total_exported={total_exported}"
|
|
|
+ )
|
|
|
break
|
|
|
|
|
|
- logger.warning(
|
|
|
- f"Format '{try_format}' returned empty file ({len(file_content)} bytes), "
|
|
|
- f"trying next format..."
|
|
|
- )
|
|
|
- last_error = f"格式 {try_format} 导出文件为空"
|
|
|
+ logger.info(f"格式 '{fmt}' 导出文件为空 ({len(file_content)} bytes),尝试下一格式")
|
|
|
|
|
|
if len(file_content) <= 10:
|
|
|
- logger.warning(f"Annotation file content: {file_content!r}")
|
|
|
raise RuntimeError(
|
|
|
- f"标注平台导出文件为空(task_type={task_type}, 尝试了格式: {formats_to_try}),"
|
|
|
- f"total_exported={total_exported}。最后错误: {last_error}"
|
|
|
+ f"标注平台所有导出格式均返回空文件(project_type={project_type},"
|
|
|
+ f"尝试格式: {formats_to_try}),请检查标注平台该项目的数据是否支持导出"
|
|
|
)
|
|
|
|
|
|
- # 保存到 uploads 目录
|
|
|
+ # 4. 保存到 uploads 目录
|
|
|
upload_dir = settings.uploads_dir
|
|
|
upload_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
@@ -373,13 +330,11 @@ async def import_project_dataset(
|
|
|
|
|
|
file_path.write_bytes(file_content)
|
|
|
|
|
|
- # 统一转为 JSONL 格式
|
|
|
+ # 5. 转为 JSONL 格式
|
|
|
jsonl_path = _convert_to_jsonl(file_path)
|
|
|
record_count = _count_records(jsonl_path, "jsonl")
|
|
|
|
|
|
- logger.info(f"Annotation file converted: {jsonl_path.name}, record_count={record_count}, format={used_format}")
|
|
|
-
|
|
|
- # 6. 写入数据库(格式统一为 jsonl)
|
|
|
+ # 6. 写入数据库
|
|
|
record_id = str(uuid.uuid4())
|
|
|
record = DatasetRecord(
|
|
|
id=record_id,
|
|
|
@@ -393,7 +348,7 @@ async def import_project_dataset(
|
|
|
session.add(record)
|
|
|
await session.commit()
|
|
|
|
|
|
- logger.info(f"Imported dataset from annotation platform: {project_id} -> {jsonl_path.name} ({record_count} records)")
|
|
|
+ logger.info(f"Imported annotation dataset: {project_name} ({record_count} records, format={used_format})")
|
|
|
|
|
|
return {
|
|
|
"project_id": project_id,
|
|
|
@@ -407,8 +362,6 @@ async def import_project_dataset(
|
|
|
|
|
|
def _convert_to_jsonl(file_path: Path) -> Path:
|
|
|
"""将 JSON/JSONL 文件统一转为 JSONL 格式。"""
|
|
|
- import json as _json
|
|
|
-
|
|
|
jsonl_path = file_path.with_suffix(".jsonl")
|
|
|
with open(file_path, "r", encoding="utf-8") as f:
|
|
|
content = f.read().strip()
|
|
|
@@ -418,46 +371,21 @@ def _convert_to_jsonl(file_path: Path) -> Path:
|
|
|
return jsonl_path
|
|
|
|
|
|
try:
|
|
|
- # 尝试作为 JSON 解析
|
|
|
- data = _json.loads(content)
|
|
|
- if isinstance(data, list):
|
|
|
- # JSON 数组
|
|
|
- items = data
|
|
|
- elif isinstance(data, dict):
|
|
|
- # JSON 对象:查找嵌套的数组字段
|
|
|
- items = None
|
|
|
- for key in ("data", "items", "results", "records", "annotations", "samples"):
|
|
|
- if key in data and isinstance(data[key], list):
|
|
|
- items = data[key]
|
|
|
- break
|
|
|
- if items is None:
|
|
|
- # 单个对象,包装为数组
|
|
|
- items = [data]
|
|
|
- else:
|
|
|
- items = None
|
|
|
-
|
|
|
+ data = json.loads(content)
|
|
|
+ items = _extract_items(data)
|
|
|
if items is not None:
|
|
|
with open(jsonl_path, "w", encoding="utf-8") as out:
|
|
|
for item in items:
|
|
|
- out.write(_json.dumps(item, ensure_ascii=False) + "\n")
|
|
|
- # 只有当原始文件与新文件不同时才删除(避免删除刚写入的文件)
|
|
|
+ out.write(json.dumps(item, ensure_ascii=False) + "\n")
|
|
|
if jsonl_path != file_path and file_path.exists():
|
|
|
file_path.unlink()
|
|
|
return jsonl_path
|
|
|
- except _json.JSONDecodeError:
|
|
|
+ except json.JSONDecodeError:
|
|
|
pass
|
|
|
|
|
|
- # 不是标准 JSON,可能是 JSONL,逐行验证
|
|
|
- lines = content.split("\n")
|
|
|
- valid_lines = []
|
|
|
- for line in lines:
|
|
|
- line = line.strip()
|
|
|
- if line:
|
|
|
- try:
|
|
|
- _json.loads(line)
|
|
|
- valid_lines.append(line)
|
|
|
- except _json.JSONDecodeError:
|
|
|
- continue # 跳过无效行
|
|
|
+ # JSONL 逐行验证
|
|
|
+ valid_lines = [line.strip() for line in content.split("\n") if line.strip()]
|
|
|
+ valid_lines = [line for line in valid_lines if _is_valid_json(line)]
|
|
|
|
|
|
with open(jsonl_path, "w", encoding="utf-8") as out:
|
|
|
out.write("\n".join(valid_lines) + ("\n" if valid_lines else ""))
|
|
|
@@ -466,24 +394,28 @@ def _convert_to_jsonl(file_path: Path) -> Path:
|
|
|
return jsonl_path
|
|
|
|
|
|
|
|
|
-def _detect_format(filename: str) -> str:
|
|
|
- """根据文件名推断格式。"""
|
|
|
- name = filename.lower()
|
|
|
- if name.endswith(".jsonl"):
|
|
|
- return "jsonl"
|
|
|
- if name.endswith(".csv"):
|
|
|
- return "csv"
|
|
|
- if name.endswith(".parquet"):
|
|
|
- return "parquet"
|
|
|
- if name.endswith(".json"):
|
|
|
- return "json"
|
|
|
- return "json"
|
|
|
+def _extract_items(data) -> list | None:
|
|
|
+ """从 JSON 数据中提取记录列表。"""
|
|
|
+ if isinstance(data, list):
|
|
|
+ return data
|
|
|
+ if isinstance(data, dict):
|
|
|
+ for key in ("data", "items", "results", "records", "annotations", "samples"):
|
|
|
+ if key in data and isinstance(data[key], list):
|
|
|
+ return data[key]
|
|
|
+ return [data]
|
|
|
+ return None
|
|
|
+
|
|
|
+
|
|
|
+def _is_valid_json(s: str) -> bool:
|
|
|
+ try:
|
|
|
+ json.loads(s)
|
|
|
+ return True
|
|
|
+ except json.JSONDecodeError:
|
|
|
+ return False
|
|
|
|
|
|
|
|
|
def _count_records(file_path: Path, fmt: str) -> int:
|
|
|
"""计算文件中的记录数。"""
|
|
|
- import json
|
|
|
-
|
|
|
if not file_path.exists():
|
|
|
return 0
|
|
|
|
|
|
@@ -494,9 +426,7 @@ def _count_records(file_path: Path, fmt: str) -> int:
|
|
|
elif fmt == "json":
|
|
|
with open(file_path, "r", encoding="utf-8") as f:
|
|
|
data = json.load(f)
|
|
|
- if isinstance(data, list):
|
|
|
- return len(data)
|
|
|
- return 1
|
|
|
+ return len(data) if isinstance(data, list) else 1
|
|
|
elif fmt == "csv":
|
|
|
import csv
|
|
|
with open(file_path, "r", encoding="utf-8") as f:
|