|
|
@@ -150,6 +150,53 @@ async def upload_dataset(file: UploadFile) -> dict[str, Any]:
|
|
|
}
|
|
|
|
|
|
|
|
|
+def _format_value(value) -> str:
|
|
|
+ """将复杂值格式化为可读字符串。"""
|
|
|
+ if isinstance(value, (dict, list)):
|
|
|
+ return json.dumps(value, ensure_ascii=False, indent=2)
|
|
|
+ return str(value)
|
|
|
+
|
|
|
+
|
|
|
+def _is_sharegpt_format(records: list[dict]) -> bool:
|
|
|
+ """检测是否为 ShareGPT 格式。"""
|
|
|
+ if not records:
|
|
|
+ return False
|
|
|
+ first = records[0]
|
|
|
+ if "conversations" in first and isinstance(first["conversations"], list):
|
|
|
+ if len(first["conversations"]) > 0 and isinstance(first["conversations"][0], dict):
|
|
|
+ conv = first["conversations"][0]
|
|
|
+ return "from" in conv and "value" in conv
|
|
|
+ return False
|
|
|
+
|
|
|
+
|
|
|
+def _flatten_sharegpt(records: list[dict]) -> tuple[list[dict], list[str]]:
|
|
|
+ """将 ShareGPT 格式展平为 input/output 列。"""
|
|
|
+ flat_rows = []
|
|
|
+ for row in records:
|
|
|
+ conversations = row.get("conversations", [])
|
|
|
+ # 每轮 user+assistant 对话作为一行
|
|
|
+ for i in range(0, len(conversations) - 1, 2):
|
|
|
+ user_turn = conversations[i]
|
|
|
+ assistant_turn = conversations[i + 1] if i + 1 < len(conversations) else None
|
|
|
+
|
|
|
+ if user_turn.get("from") in ("human", "user"):
|
|
|
+ input_text = str(user_turn.get("value", ""))
|
|
|
+ output_text = str(assistant_turn.get("value", "")) if assistant_turn else ""
|
|
|
+ else:
|
|
|
+ input_text = str(assistant_turn.get("value", "")) if assistant_turn else ""
|
|
|
+ output_text = str(user_turn.get("value", ""))
|
|
|
+
|
|
|
+ # 截断过长文本
|
|
|
+ if len(input_text) > 500:
|
|
|
+ input_text = input_text[:500] + "..."
|
|
|
+ if len(output_text) > 500:
|
|
|
+ output_text = output_text[:500] + "..."
|
|
|
+
|
|
|
+ flat_rows.append({"input": input_text, "output": output_text})
|
|
|
+
|
|
|
+ return flat_rows, ["input", "output"]
|
|
|
+
|
|
|
+
|
|
|
async def preview_dataset(dataset_id: str, rows: int = 10) -> dict[str, Any]:
|
|
|
"""预览数据集前 N 行。"""
|
|
|
async with async_session() as session:
|
|
|
@@ -166,7 +213,12 @@ async def preview_dataset(dataset_id: str, rows: int = 10) -> dict[str, Any]:
|
|
|
|
|
|
fmt = record.format
|
|
|
preview_data = _read_records(file_path, fmt, rows)
|
|
|
- columns = list(preview_data[0].keys()) if preview_data else []
|
|
|
+
|
|
|
+ # 检测是否为 ShareGPT 格式,如果是则展平为 input/output 列
|
|
|
+ if _is_sharegpt_format(preview_data):
|
|
|
+ preview_data, columns = _flatten_sharegpt(preview_data)
|
|
|
+ else:
|
|
|
+ columns = list(preview_data[0].keys()) if preview_data else []
|
|
|
|
|
|
return {
|
|
|
"total_records": record.record_count,
|
|
|
@@ -287,28 +339,6 @@ def _count_records(file_path: Path, fmt: str) -> int:
|
|
|
return 0
|
|
|
|
|
|
|
|
|
-def _format_value(value) -> str:
|
|
|
- """将复杂值格式化为可读字符串,特别处理 ShareGPT 格式的 conversations 数组。"""
|
|
|
- if isinstance(value, list) and len(value) > 0 and isinstance(value[0], dict):
|
|
|
- # 检测 ShareGPT 格式:[{"from": "human", "value": "..."}, {"from": "gpt", "value": "..."}]
|
|
|
- first = value[0]
|
|
|
- if "from" in first and "value" in first:
|
|
|
- parts = []
|
|
|
- for turn in value:
|
|
|
- role = turn.get("from", "unknown")
|
|
|
- text = str(turn.get("value", ""))
|
|
|
- # 截断过长文本
|
|
|
- if len(text) > 200:
|
|
|
- text = text[:200] + "..."
|
|
|
- parts.append(f"[{role}] {text}")
|
|
|
- return "\n---\n".join(parts)
|
|
|
- # 其他对象数组:显示为 JSON
|
|
|
- return json.dumps(value, ensure_ascii=False, indent=2)
|
|
|
- if isinstance(value, (dict, list)):
|
|
|
- return json.dumps(value, ensure_ascii=False, indent=2)
|
|
|
- return str(value)
|
|
|
-
|
|
|
-
|
|
|
def _read_records(file_path: Path, fmt: str, n: int) -> list[dict]:
|
|
|
if fmt == "jsonl":
|
|
|
records = []
|