Pārlūkot izejas kodu

增加图片预览

lxylxy123321 17 stundas atpakaļ
vecāks
revīzija
5a9c471f2f

+ 58 - 9
backend/app/services/dataset_service.py

@@ -381,6 +381,44 @@ async def upload_dataset(file: UploadFile) -> dict[str, Any]:
     }
 
 
+def _resolve_image_path(path_str: str, data_dir: Path) -> Path | None:
+    """Resolve an image path taken from a dataset row to an existing file.
+
+    Resolution order:
+      1. ``path_str`` itself, when it is absolute and points at a file.
+      2. ``data_dir / path_str``.
+      3. The first file found anywhere below ``data_dir`` whose name equals
+         the final component of ``path_str``.
+
+    Args:
+        path_str: Path value as stored in the dataset row.
+        data_dir: Directory that relative paths are resolved against.
+
+    Returns:
+        The path of an existing regular file, or ``None`` if nothing matched.
+    """
+    import glob
+
+    if not path_str:
+        return None
+    p = Path(path_str)
+    # NOTE(review): absolute paths and ".." components are honoured as-is, so
+    # a row can point outside data_dir -- confirm this is acceptable for
+    # uploaded datasets.
+    # Every branch requires a regular file: a directory cannot be previewed.
+    if p.is_absolute() and p.is_file():
+        return p
+    candidate = data_dir / p
+    if candidate.is_file():
+        return candidate
+    # "" (from "." or "/") and ".." are not file names; nothing to search for.
+    if p.name in ("", ".."):
+        return None
+    # Escape the name so characters such as "[" or "*" are matched literally
+    # instead of being interpreted as glob syntax.
+    for child in data_dir.rglob(glob.escape(p.name)):
+        if child.is_file():
+            return child
+    return None
+
+
+def _encode_image_base64(image_path: Path, max_size: int = 200) -> str | None:
+    """Encode an image as a base64 JPEG data URI for front-end preview.
+
+    The image is reduced with Pillow's ``thumbnail`` so that it fits within
+    ``max_size`` x ``max_size`` pixels, then re-encoded as JPEG.
+
+    Args:
+        image_path: Location of the image file to read.
+        max_size: Upper bound, in pixels, for both thumbnail dimensions.
+
+    Returns:
+        A ``data:image/jpeg;base64,...`` string, or ``None`` when the image
+        cannot be loaded or encoded (including when Pillow is unavailable).
+    """
+    import base64
+    import io
+    import logging
+
+    try:
+        from PIL import Image
+
+        # The context manager releases the file handle Image.open() holds.
+        with Image.open(image_path) as img:
+            img.thumbnail((max_size, max_size))
+            # Normalise every mode other than RGB / grayscale (palette,
+            # alpha, 16-bit, CMYK, ...) so the JPEG encoder accepts it.
+            preview = img if img.mode in ("RGB", "L") else img.convert("RGB")
+            buf = io.BytesIO()
+            preview.save(buf, format="JPEG", quality=85)
+    except Exception:
+        # Best effort: a broken image must not fail the whole preview, but
+        # leave a trace so missing thumbnails can be diagnosed.
+        logging.getLogger(__name__).debug(
+            "Could not build image preview for %s", image_path, exc_info=True
+        )
+        return None
+    b64 = base64.b64encode(buf.getvalue()).decode("ascii")
+    return f"data:image/jpeg;base64,{b64}"
+
+
 def _format_value(value) -> str:
     """将复杂值格式化为可读字符串。"""
     if isinstance(value, (dict, list)):
@@ -434,11 +472,11 @@ async def preview_dataset(dataset_id: str, rows: int = 10) -> dict[str, Any]:
         result = await session.execute(select(DatasetRecord).where(DatasetRecord.id == dataset_id))
         record = result.scalar_one_or_none()
         if not record:
-            return {"total_records": 0, "preview_rows": [], "columns": []}
+            return {"total_records": 0, "preview_rows": [], "columns": [], "image_column": None}
 
     file_path = Path(record.file_path)
     if not file_path.exists():
-        return {"total_records": 0, "preview_rows": [], "columns": []}
+        return {"total_records": 0, "preview_rows": [], "columns": [], "image_column": None}
 
     fmt = record.format
     preview_data = _read_records(file_path, fmt, rows)
@@ -449,16 +487,27 @@ async def preview_dataset(dataset_id: str, rows: int = 10) -> dict[str, Any]:
     else:
         columns = list(preview_data[0].keys()) if preview_data else []
 
+    # 检测是否为视觉数据集(有 image_path 列),将图片转为 base64 嵌入预览
+    image_column = "image_path" if "image_path" in columns else None
+    data_dir = file_path.parent
+
+    preview_rows = []
+    for i, row in enumerate(preview_data):
+        data = {}
+        for k, v in row.items():
+            if k == "image_path" and image_column:
+                # 解析图片路径,转为 base64 嵌入
+                img_path = _resolve_image_path(str(v), data_dir)
+                data[k] = _encode_image_base64(img_path) if img_path else None
+            else:
+                data[k] = _format_value(v)
+        preview_rows.append({"row_index": i, "data": data})
+
     return {
         "total_records": record.record_count,
-        "preview_rows": [
-            {
-                "row_index": i,
-                "data": {k: _format_value(v) for k, v in row.items()},
-            }
-            for i, row in enumerate(preview_data)
-        ],
+        "preview_rows": preview_rows,
         "columns": columns,
+        "image_column": image_column,
     }
 
 

+ 1 - 0
frontend/src/api/client.ts

@@ -305,6 +305,7 @@ interface DatasetPreview {
   total_records: number
   preview_rows: { row_index: number; data: Record<string, unknown> }[]
   columns: string[]
+  image_column: string | null
 }
 
 interface DatasetValidation {

+ 25 - 5
frontend/src/pages/Datasets.tsx

@@ -121,7 +121,7 @@ export function Datasets() {
   const [uploading, setUploading] = useState(false)
   const [downloading, setDownloading] = useState(false)
   const [loading, setLoading] = useState(false)
-  const [previewData, setPreviewData] = useState<{ columns: string[]; rows: { row_index: number; data: Record<string, unknown> }[] } | null>(null)
+  const [previewData, setPreviewData] = useState<{ columns: string[]; rows: { row_index: number; data: Record<string, unknown> }[]; image_column: string | null } | null>(null)
   const [previewError, setPreviewError] = useState<string | null>(null)
   const inputRef = useRef<HTMLInputElement>(null)
 
@@ -247,7 +247,7 @@ export function Datasets() {
           setPreviewData(null)
           return
         }
-        setPreviewData({ columns: res.columns, rows: res.preview_rows })
+        setPreviewData({ columns: res.columns, rows: res.preview_rows, image_column: res.image_column || null })
       })
       .catch(err => {
         setPreviewError(`预览失败: ${err.message || '未知错误'}`)
@@ -793,8 +793,28 @@ export function Datasets() {
                     onMouseLeave={e => { e.currentTarget.style.background = 'transparent' }}
                   >
                     {previewData.columns.map(col => {
-                      const cellVal = String(row.data[col] ?? '')
-                      const isMultiline = cellVal.includes('\n') || cellVal.length > 100
+                      const isImage = col === previewData.image_column
+                      const cellVal = row.data[col]
+                      if (isImage && cellVal) {
+                        return (
+                          <td
+                            key={col}
+                            style={{ padding: '6px 12px', verticalAlign: 'middle' }}
+                          >
+                            <img
+                              src={cellVal as string}
+                              alt=""
+                              style={{
+                                width: 80, height: 80, objectFit: 'cover',
+                                borderRadius: 6, border: '1px solid #e2e8f0',
+                                background: '#f8fafc',
+                              }}
+                            />
+                          </td>
+                        )
+                      }
+                      const cellStr = String(cellVal ?? '')
+                      const isMultiline = cellStr.includes('\n') || cellStr.length > 100
                       return (
                         <td
                           key={col}
@@ -808,7 +828,7 @@ export function Datasets() {
                             fontSize: isMultiline ? 12 : 13,
                           }}
                         >
-                          {cellVal}
+                          {cellStr}
                         </td>
                       )
                     })}