lxylxy123321 1 неделя назад
Родитель
Сommit
af0c2dfdb8
1 измененных файлов с 53 добавлено и 23 удалено
  1. 53 23
      backend/app/services/dataset_service.py

+ 53 - 23
backend/app/services/dataset_service.py

@@ -150,6 +150,53 @@ async def upload_dataset(file: UploadFile) -> dict[str, Any]:
     }
 
 
+def _format_value(value) -> str:
+    """将复杂值格式化为可读字符串。"""
+    if isinstance(value, (dict, list)):
+        return json.dumps(value, ensure_ascii=False, indent=2)
+    return str(value)
+
+
+def _is_sharegpt_format(records: list[dict]) -> bool:
+    """检测是否为 ShareGPT 格式。"""
+    if not records:
+        return False
+    first = records[0]
+    if "conversations" in first and isinstance(first["conversations"], list):
+        if len(first["conversations"]) > 0 and isinstance(first["conversations"][0], dict):
+            conv = first["conversations"][0]
+            return "from" in conv and "value" in conv
+    return False
+
+
+def _flatten_sharegpt(records: list[dict]) -> tuple[list[dict], list[str]]:
+    """将 ShareGPT 格式展平为 input/output 列。"""
+    flat_rows = []
+    for row in records:
+        conversations = row.get("conversations", [])
+        # 每轮 user+assistant 对话作为一行
+        for i in range(0, len(conversations) - 1, 2):
+            user_turn = conversations[i]
+            assistant_turn = conversations[i + 1] if i + 1 < len(conversations) else None
+
+            if user_turn.get("from") in ("human", "user"):
+                input_text = str(user_turn.get("value", ""))
+                output_text = str(assistant_turn.get("value", "")) if assistant_turn else ""
+            else:
+                input_text = str(assistant_turn.get("value", "")) if assistant_turn else ""
+                output_text = str(user_turn.get("value", ""))
+
+            # 截断过长文本
+            if len(input_text) > 500:
+                input_text = input_text[:500] + "..."
+            if len(output_text) > 500:
+                output_text = output_text[:500] + "..."
+
+            flat_rows.append({"input": input_text, "output": output_text})
+
+    return flat_rows, ["input", "output"]
+
+
 async def preview_dataset(dataset_id: str, rows: int = 10) -> dict[str, Any]:
     """预览数据集前 N 行。"""
     async with async_session() as session:
@@ -166,7 +213,12 @@ async def preview_dataset(dataset_id: str, rows: int = 10) -> dict[str, Any]:
 
     fmt = record.format
     preview_data = _read_records(file_path, fmt, rows)
-    columns = list(preview_data[0].keys()) if preview_data else []
+
+    # 检测是否为 ShareGPT 格式,如果是则展平为 input/output 列
+    if _is_sharegpt_format(preview_data):
+        preview_data, columns = _flatten_sharegpt(preview_data)
+    else:
+        columns = list(preview_data[0].keys()) if preview_data else []
 
     return {
         "total_records": record.record_count,
@@ -287,28 +339,6 @@ def _count_records(file_path: Path, fmt: str) -> int:
     return 0
 
 
-def _format_value(value) -> str:
-    """将复杂值格式化为可读字符串,特别处理 ShareGPT 格式的 conversations 数组。"""
-    if isinstance(value, list) and len(value) > 0 and isinstance(value[0], dict):
-        # 检测 ShareGPT 格式:[{"from": "human", "value": "..."}, {"from": "gpt", "value": "..."}]
-        first = value[0]
-        if "from" in first and "value" in first:
-            parts = []
-            for turn in value:
-                role = turn.get("from", "unknown")
-                text = str(turn.get("value", ""))
-                # 截断过长文本
-                if len(text) > 200:
-                    text = text[:200] + "..."
-                parts.append(f"[{role}] {text}")
-            return "\n---\n".join(parts)
-        # 其他对象数组:显示为 JSON
-        return json.dumps(value, ensure_ascii=False, indent=2)
-    if isinstance(value, (dict, list)):
-        return json.dumps(value, ensure_ascii=False, indent=2)
-    return str(value)
-
-
 def _read_records(file_path: Path, fmt: str, n: int) -> list[dict]:
     if fmt == "jsonl":
         records = []