test_ms_api.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. #!/usr/bin/env python3
  2. """测试 MsDataset.load() 能否正确下载图片数据集。"""
  3. import sys
  4. import json
  5. dataset_id = sys.argv[1] if len(sys.argv) > 1 else "tany0699/carBrands50"
  6. namespace, ds_name = dataset_id.split("/", 1) if "/" in dataset_id else ("", dataset_id)
  7. print(f"测试数据集: {dataset_id}")
  8. print(f"namespace: {namespace}, name: {ds_name}\n")
  9. print("=== 用 MsDataset.load() 下载 ===")
  10. try:
  11. from modelscope.msdatasets import MsDataset
  12. ds = None
  13. for split in ("train", "validation", "test"):
  14. try:
  15. if namespace:
  16. ds = MsDataset.load(ds_name, namespace=namespace, split=split)
  17. else:
  18. ds = MsDataset.load(dataset_id, split=split)
  19. if ds:
  20. print(f"加载 split='{split}' 成功, 共 {len(ds) if hasattr(ds, '__len__') else '?'} 条")
  21. break
  22. except Exception as e:
  23. print(f"split='{split}' 失败: {e}")
  24. if not ds:
  25. try:
  26. if namespace:
  27. ds = MsDataset.load(ds_name, namespace=namespace)
  28. else:
  29. ds = MsDataset.load(dataset_id)
  30. print(f"不带 split 加载成功, 类型: {type(ds)}")
  31. except Exception as e:
  32. print(f"不带 split 也失败: {e}")
  33. sys.exit(1)
  34. if not hasattr(ds, "__iter__"):
  35. print(f"数据集不可迭代, 类型: {type(ds)}")
  36. sys.exit(1)
  37. # 查看前 3 条数据
  38. print("\n=== 前 3 条数据 ===")
  39. count = 0
  40. for row in ds:
  41. if count >= 3:
  42. break
  43. print(f"\n--- Record {count} ---")
  44. for k, v in row.items():
  45. vtype = type(v).__name__
  46. if vtype == "Image":
  47. print(f" {k}: PIL.Image (size={v.size}, mode={v.mode})")
  48. elif isinstance(v, str) and len(v) > 100:
  49. print(f" {k}: str (len={len(v)}) '{v[:100]}...'")
  50. else:
  51. print(f" {k}: {vtype} = {v}")
  52. count += 1
  53. print(f"\n=== 完成 ===")
  54. except Exception as e:
  55. print(f"失败: {e}")
  56. import traceback
  57. traceback.print_exc()