ソースを参照

修复标注平台报错

lxylxy123321 1 時間 前
コミット
9a34a66925
1 ファイル変更、67 行追加、22 行削除
  1. backend/app/services/annotation_platform_service.py(67 行追加、22 行削除)

+ 67 - 22
backend/app/services/annotation_platform_service.py

@@ -149,8 +149,11 @@ def _auth_headers() -> dict[str, str]:
     }
 
 
-async def _request(method: str, path: str, **kwargs) -> dict[str, Any]:
-    """统一的业务请求方法,自动携带 Token。"""
+async def _request(method: str, path: str, raise_on_error: bool = True, **kwargs) -> dict[str, Any]:
+    """统一的业务请求方法,自动携带 Token。
+
+    raise_on_error=False 时,400 等错误不抛异常,返回原始响应体。
+    """
     await get_token()
     base_url = _get_base_url()
 
@@ -161,6 +164,14 @@ async def _request(method: str, path: str, **kwargs) -> dict[str, Any]:
             headers=_auth_headers(),
             **kwargs,
         )
+
+        if not raise_on_error and resp.status_code >= 400:
+            try:
+                body = resp.json()
+            except Exception:
+                body = {"status_code": resp.status_code, "text": resp.text}
+            return {"_error": True, "_status_code": resp.status_code, **body}
+
         resp.raise_for_status()
         body = resp.json()
 
@@ -214,50 +225,83 @@ async def import_project_dataset(
     """导出并下载项目数据集,保存到本地并写入数据库。
 
     流程:
-    1. POST /api/v1/open/projects/{project_id}/datasets/download → 获取 file_url
-    2. GET /api/v1/open/datasets/downloads/{download_token} → 下载文件
-    3. 保存到 uploads 目录
-    4. 写入 DatasetRecord 数据库
+    1. 查询项目详情获取 task_type
+    2. POST /api/v1/open/projects/{project_id}/datasets/download → 获取 file_url
+    3. GET /api/v1/open/datasets/downloads/{download_token} → 下载文件
+    4. 保存到 uploads 目录
+    5. 写入 DatasetRecord 数据库
     """
-    # 尝试多种格式导出,某些项目可能不兼容 alpaca 格式
-    formats_to_try = [format]
-    if format != "json":
-        formats_to_try.append("json")
-    if format not in ("raw", "original"):
-        formats_to_try.append("raw")
+    # 1. 查询项目详情,获取 task_type 和 project_type
+    try:
+        project_detail = await get_project_detail(project_id)
+        task_type = project_detail.get("task_type", "")
+        project_type = project_detail.get("project_type", "text")
+        logger.info(f"Project {project_id}: task_type={task_type}, project_type={project_type}")
+    except Exception as e:
+        logger.warning(f"Failed to get project detail: {e}, using default formats")
+        task_type = ""
+        project_type = "text"
+
+    # 2. 根据项目类型选择导出格式
+    # 文本项目: alpaca / sharegpt
+    # 图片项目: json / csv / coco / yolo / pascal_voc
+    if project_type == "text":
+        formats_to_try = ["alpaca", "sharegpt"]
+    else:
+        formats_to_try = ["json", "csv", "coco"]
+
+    # 如果用户指定了格式,优先使用
+    if format and format not in formats_to_try:
+        formats_to_try.insert(0, format)
 
     file_content = b""
     file_name = ""
     total_exported = 0
     used_format = ""
+    last_error = ""
 
     for try_format in formats_to_try:
-        # 1. 请求导出
+        # 请求导出
         export_data = await _request(
             "POST",
             f"/api/v1/open/projects/{project_id}/datasets/download",
+            raise_on_error=False,
             json={"format": try_format, "completed_only": True},
         )
 
-        logger.info(
-            f"Annotation export response (format={try_format}): {export_data}"
-        )
+        # 检查是否返回错误
+        if export_data.get("_error"):
+            status_code = export_data.get("_status_code", 0)
+            error_msg = export_data.get("message", export_data)
+            last_error = f"HTTP {status_code}: {error_msg}"
+            logger.warning(f"Format '{try_format}' failed: {last_error}, trying next...")
+            continue
+
+        logger.info(f"Annotation export response (format={try_format}): {export_data}")
 
         file_url = export_data.get("file_url", "")
         file_name = export_data.get("file_name", f"{project_id}_{try_format}.json")
         total_exported = export_data.get("total_exported", 0)
+        export_status = export_data.get("status", "completed")
+
+        # 检查导出状态
+        if export_status != "completed":
+            logger.warning(f"Export status={export_status} for format={try_format}, trying next...")
+            last_error = f"导出状态: {export_status}"
+            continue
 
         if not file_url:
             logger.warning(f"No file_url for format={try_format}, trying next...")
+            last_error = "未返回下载链接"
             continue
 
-        # 2. 从 file_url 中提取 download_token
+        # 从 file_url 中提取 download_token
         if "/datasets/downloads/" in file_url:
             download_token = file_url.split("/datasets/downloads/")[-1].strip("/")
         else:
             download_token = file_url.rstrip("/").split("/")[-1]
 
-        # 3. 下载文件,带轮询(标注平台生成文件可能需要时间)
+        # 下载文件,带轮询(标注平台生成文件可能需要时间)
         await get_token()
         base_url = _get_base_url()
         download_url = f"{base_url}/api/v1/open/datasets/downloads/{download_token}"
@@ -309,15 +353,16 @@ async def import_project_dataset(
             f"Format '{try_format}' returned empty file ({len(file_content)} bytes), "
             f"trying next format..."
         )
+        last_error = f"格式 {try_format} 导出文件为空"
 
     if len(file_content) <= 10:
         logger.warning(f"Annotation file content: {file_content!r}")
         raise RuntimeError(
-            f"标注平台导出文件为空(尝试了格式: {formats_to_try}),"
-            f"total_exported={total_exported},请检查标注平台项目状态"
+            f"标注平台导出文件为空(task_type={task_type}, 尝试了格式: {formats_to_try}),"
+            f"total_exported={total_exported}。最后错误: {last_error}"
         )
 
-    # 4. 保存到 uploads 目录
+    # 保存到 uploads 目录
     upload_dir = settings.uploads_dir
     upload_dir.mkdir(parents=True, exist_ok=True)
 
@@ -328,7 +373,7 @@ async def import_project_dataset(
 
     file_path.write_bytes(file_content)
 
-    # 5. 统一转为 JSONL 格式
+    # 统一转为 JSONL 格式
     jsonl_path = _convert_to_jsonl(file_path)
     record_count = _count_records(jsonl_path, "jsonl")