test_ms_api.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061
  1. #!/usr/bin/env python3
  2. """测试 MsDataset.load() 完整下载流程。"""
  3. import json
  4. import sys
  5. import os
  6. import tempfile
  7. from pathlib import Path
  8. dataset_id = sys.argv[1] if len(sys.argv) > 1 else "tany0699/carBrands50"
  9. namespace, ds_name = dataset_id.split("/", 1)
  10. print(f"数据集: {dataset_id}\n")
  11. # 测试 MsDataset.load()
  12. print("=== 用 MsDataset.load() 下载 ===")
  13. from modelscope.msdatasets import MsDataset
  14. from PIL import Image
  15. ds = None
  16. for split in ("train", "validation", "test"):
  17. try:
  18. if namespace:
  19. ds = MsDataset.load(ds_name, namespace=namespace, split=split)
  20. else:
  21. ds = MsDataset.load(dataset_id, split=split)
  22. if ds:
  23. count = len(ds) if hasattr(ds, "__len__") else "?"
  24. print(f"split='{split}' 成功, 共 {count} 条")
  25. break
  26. except Exception as e:
  27. print(f"split='{split}' 失败: {e}")
  28. if not ds:
  29. print("所有 split 都失败")
  30. sys.exit(1)
  31. # 查看前 2 条数据
  32. print("\n=== 前 2 条数据 ===")
  33. count = 0
  34. for row in ds:
  35. if count >= 2:
  36. break
  37. print(f"\n--- Record {count} ---")
  38. for k, v in row.items():
  39. vtype = type(v).__name__
  40. if isinstance(v, Image.Image):
  41. # 模拟保存
  42. tmp = tempfile.NamedTemporaryFile(suffix=".jpg", delete=False)
  43. if v.mode in ("RGBA", "P", "LA"):
  44. v = v.convert("RGB")
  45. v.save(tmp.name, format="JPEG", quality=90)
  46. size = os.path.getsize(tmp.name)
  47. os.unlink(tmp.name)
  48. print(f" {k}: PIL.Image ({v.size[0]}x{v.size[1]}, mode={v.mode}) -> saved as {size} bytes")
  49. elif isinstance(v, str) and len(v) > 100:
  50. print(f" {k}: str (len={len(v)})")
  51. else:
  52. print(f" {k}: {vtype} = {v}")
  53. count += 1
  54. print(f"\n=== 测试通过! ===")