# test_ms_api.py (3.6 KB)
  1. #!/usr/bin/env python3
  2. """测试 MsDataset.load() 完整流程(含图片复制)。"""
  3. import json
  4. import sys
  5. import os
  6. import shutil
  7. from pathlib import Path
  8. dataset_id = sys.argv[1] if len(sys.argv) > 1 else "tany0699/carBrands50"
  9. namespace, ds_name = dataset_id.split("/", 1)
  10. print(f"数据集: {dataset_id}\n")
  11. # 清理旧数据
  12. ds_dir = Path("/tmp/ms_test_download")
  13. if ds_dir.exists():
  14. shutil.rmtree(ds_dir)
  15. ds_dir.mkdir(parents=True)
  16. images_dir = ds_dir / "images"
  17. print("=== 用 MsDataset.load() 下载 ===")
  18. from modelscope.msdatasets import MsDataset
  19. from PIL import Image
  20. ds = None
  21. for split in ("train", "validation", "test"):
  22. try:
  23. if namespace:
  24. ds = MsDataset.load(ds_name, namespace=namespace, split=split)
  25. else:
  26. ds = MsDataset.load(dataset_id, split=split)
  27. if ds:
  28. count = len(ds) if hasattr(ds, "__len__") else "?"
  29. print(f"split='{split}' 成功, 共 {count} 条")
  30. break
  31. except Exception as e:
  32. print(f"split='{split}' 失败: {e}")
  33. if not ds:
  34. print("所有 split 都失败")
  35. sys.exit(1)
  36. # 处理前 5 条数据
  37. print("\n=== 处理前 5 条数据(复制图片) ===")
  38. records = []
  39. img_counter = 0
  40. count = 0
  41. for row in ds:
  42. if count >= 5:
  43. break
  44. record = {}
  45. for k, v in row.items():
  46. # 检查是否是 PIL.Image 对象
  47. if isinstance(v, Image.Image):
  48. images_dir.mkdir(parents=True, exist_ok=True)
  49. img_name = f"{img_counter:06d}.jpg"
  50. img_path = images_dir / img_name
  51. if v.mode in ("RGBA", "P", "LA"):
  52. v = v.convert("RGB")
  53. v.save(str(img_path), format="JPEG", quality=90)
  54. record[k] = f"images/{img_name}"
  55. print(f"Record {count}: {k} -> PIL.Image saved as {img_name}")
  56. img_counter += 1
  57. # 检查是否是图片文件路径
  58. elif isinstance(v, str) and v.lower().endswith(('.jpg', '.jpeg', '.png', '.gif', '.bmp')):
  59. if os.path.isabs(v) and os.path.exists(v):
  60. images_dir.mkdir(parents=True, exist_ok=True)
  61. ext = os.path.splitext(v)[1]
  62. img_name = f"{img_counter:06d}{ext}"
  63. dest_path = images_dir / img_name
  64. try:
  65. shutil.copy2(v, dest_path)
  66. record[k] = f"images/{img_name}"
  67. size = os.path.getsize(dest_path)
  68. print(f"Record {count}: {k} -> copied {os.path.basename(v)} ({size} bytes) as {img_name}")
  69. img_counter += 1
  70. except Exception as e:
  71. print(f"Record {count}: {k} -> failed to copy: {e}")
  72. record[k] = v
  73. else:
  74. record[k] = v
  75. print(f"Record {count}: {k} -> {v} (relative path)")
  76. else:
  77. record[k] = v
  78. records.append(record)
  79. count += 1
  80. # 写入 JSONL
  81. print("\n=== 写入 JSONL ===")
  82. jsonl_path = ds_dir / "data.jsonl"
  83. with open(jsonl_path, "w", encoding="utf-8") as f:
  84. for item in records:
  85. f.write(json.dumps(item, ensure_ascii=False) + "\n")
  86. print(f"写入 {len(records)} 条记录到 {jsonl_path}")
  87. # 显示结果
  88. print("\n=== JSONL 内容 ===")
  89. with open(jsonl_path, "r", encoding="utf-8") as f:
  90. for i, line in enumerate(f):
  91. print(f"{i}: {line.strip()}")
  92. # 显示 images 目录
  93. print(f"\n=== images 目录 ===")
  94. if images_dir.exists():
  95. for img_file in sorted(images_dir.iterdir()):
  96. size = img_file.stat().st_size
  97. print(f" {img_file.name} ({size} bytes)")
  98. else:
  99. print(" (空)")
  100. print(f"\n=== 测试完成! ===")
  101. print(f"数据集目录: {ds_dir}")