Selaa lähdekoodia

修复数据集图片预览

lxylxy123321 13 tuntia sitten
vanhempi
sitoutus
97a42e907c
1 muutettua tiedostoa jossa 34 lisäystä ja 9 poistoa
  1. 34 9
      backend/app/services/dataset_service.py

+ 34 - 9
backend/app/services/dataset_service.py

@@ -381,21 +381,40 @@ async def upload_dataset(file: UploadFile) -> dict[str, Any]:
     }
 
 
+def _detect_image_column(columns: list[str]) -> str | None:
+    """检测哪一列是图片路径列。"""
+    candidates = ["image_path", "image", "img_path", "img", "file_path", "filename", "path", "file"]
+    for c in candidates:
+        if c in columns:
+            return c
+    # 模糊匹配:列名包含 image 或 path
+    for c in columns:
+        cl = c.lower()
+        if "image" in cl or ("path" in cl and "label" not in cl):
+            return c
+    return None
+
+
 def _resolve_image_path(path_str: str, data_dir: Path) -> Path | None:
     """解析图片路径,返回绝对路径。"""
     if not path_str:
         return None
     p = Path(path_str)
-    if p.is_absolute() and p.exists():
-        return p
-    # 相对路径:相对于数据目录
+    # 已经是绝对路径
+    if p.is_absolute():
+        return p if p.exists() else None
+    # 相对路径:先尝试相对于数据目录
     candidate = data_dir / p
     if candidate.exists():
         return candidate
-    # 也可能在 data_dir 的子目录中
+    # 也可能直接在 data_dir 下(去掉目录前缀只保留文件名)
+    if data_dir.joinpath(p.name).exists():
+        return data_dir / p.name
+    # 在 data_dir 的子目录中递归查找
     for child in data_dir.rglob(p.name):
         if child.is_file():
             return child
+    logger.debug(f"Image not found: '{path_str}' (searched in {data_dir})")
     return None
 
 
@@ -415,7 +434,8 @@ def _encode_image_base64(image_path: Path, max_size: int = 200) -> str | None:
         img.save(buf, format="JPEG", quality=85)
         b64 = base64.b64encode(buf.getvalue()).decode("ascii")
         return f"data:image/jpeg;base64,{b64}"
-    except Exception:
+    except Exception as e:
+        logger.warning(f"Failed to encode image {image_path}: {e}")
         return None
 
 
@@ -487,18 +507,23 @@ async def preview_dataset(dataset_id: str, rows: int = 10) -> dict[str, Any]:
     else:
         columns = list(preview_data[0].keys()) if preview_data else []
 
-    # 检测是否为视觉数据集(有 image_path 列),将图片转为 base64 嵌入预览
-    image_column = "image_path" if "image_path" in columns else None
+    # 检测是否为视觉数据集(有图片路径列),将图片转为 base64 嵌入预览
+    image_column = _detect_image_column(columns)
     data_dir = file_path.parent
 
     preview_rows = []
     for i, row in enumerate(preview_data):
         data = {}
         for k, v in row.items():
-            if k == "image_path" and image_column:
+            if k == image_column:
                 # 解析图片路径,转为 base64 嵌入
                 img_path = _resolve_image_path(str(v), data_dir)
-                data[k] = _encode_image_base64(img_path) if img_path else None
+                if img_path:
+                    encoded = _encode_image_base64(img_path)
+                    data[k] = encoded if encoded else str(v)
+                else:
+                    # 路径解析失败,保留原始路径文本
+                    data[k] = str(v)
             else:
                 data[k] = _format_value(v)
         preview_rows.append({"row_index": i, "data": data})