test_attrs.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. import pytest
  2. from gpustack.schemas import ModelInstance
  3. from gpustack.schemas.models import (
  4. DistributedServers,
  5. ModelInstanceSubordinateWorker,
  6. ModelInstanceStateEnum,
  7. )
  8. from gpustack.utils.attrs import get_attr, set_attr
  9. @pytest.mark.parametrize(
  10. "o, path, expected",
  11. [
  12. # Dict access
  13. (
  14. {"a": {"b": {"c": 42}}},
  15. "a.b.c",
  16. 42,
  17. ),
  18. # Dict access with list index
  19. (
  20. {"a": [{"b": {"c": 42}}]},
  21. "a.0.b.c",
  22. 42,
  23. ),
  24. # None access
  25. (
  26. None,
  27. "a.b.c",
  28. None,
  29. ),
  30. # Dict access with on-existent path
  31. (
  32. {"a": {"b": {"c": 42}}},
  33. "a.b.d",
  34. None,
  35. ),
  36. # List access
  37. (
  38. [1, 2, 3],
  39. "0",
  40. 1,
  41. ),
  42. # List of dicts access
  43. (
  44. [{"a": 1}, {"b": 2}],
  45. "0.a",
  46. 1,
  47. ),
  48. # Complex object access
  49. (
  50. ModelInstance(
  51. distributed_servers=DistributedServers(
  52. subordinate_workers=[
  53. ModelInstanceSubordinateWorker(
  54. worker_ip="192.168.50.3",
  55. ),
  56. ],
  57. ),
  58. ),
  59. "distributed_servers.subordinate_workers.0.worker_ip",
  60. "192.168.50.3",
  61. ),
  62. # Complex object access with non-existent path
  63. (
  64. ModelInstance(
  65. distributed_servers=DistributedServers(
  66. subordinate_workers=[
  67. ModelInstanceSubordinateWorker(
  68. worker_ip="192.168.50.3",
  69. ),
  70. ],
  71. ),
  72. ),
  73. "distributed_servers.subordinate_workers.0.name",
  74. None,
  75. ),
  76. ],
  77. )
  78. @pytest.mark.unit
  79. def test_get_attr(o, path, expected):
  80. actual = get_attr(o, path)
  81. assert (
  82. actual == expected
  83. ), f"Expected {expected} but got {actual} for path '{path}' in object {o}"
  84. @pytest.mark.parametrize(
  85. "o, path, value, expected",
  86. [
  87. # Dict access
  88. (
  89. {"a": {"b": {"c": 42}}},
  90. "a.b.c",
  91. 100,
  92. {"a": {"b": {"c": 100}}},
  93. ),
  94. # Dict access with list index
  95. (
  96. {"a": [{"b": {"c": 42}}]},
  97. "a.0.b.c",
  98. 100,
  99. {"a": [{"b": {"c": 100}}]},
  100. ),
  101. # None access
  102. (
  103. None,
  104. "a.b.c",
  105. 100,
  106. None,
  107. ),
  108. # Dict access with non-existent path: insert new item
  109. (
  110. {"a": {"b": {"c": 42}}},
  111. "a.b.d",
  112. 100,
  113. {"a": {"b": {"c": 42, "d": 100}}},
  114. ),
  115. # Dict access with non-existent path: nothing to do
  116. (
  117. {"a": {"b": {"c": 42}}},
  118. "a.d.c",
  119. 100,
  120. {"a": {"b": {"c": 42}}},
  121. ),
  122. # List access
  123. (
  124. [1, 2, 3],
  125. "0",
  126. 100,
  127. [100, 2, 3],
  128. ),
  129. # List of dicts access
  130. (
  131. [{"a": 1}, {"b": 2}],
  132. "0.a",
  133. 100,
  134. [{"a": 100}, {"b": 2}],
  135. ),
  136. # Complex object access
  137. (
  138. ModelInstance(
  139. distributed_servers=DistributedServers(
  140. subordinate_workers=[
  141. ModelInstanceSubordinateWorker(
  142. worker_ip="192.168.50.3",
  143. state=ModelInstanceStateEnum.RUNNING,
  144. ),
  145. ModelInstanceSubordinateWorker(
  146. worker_ip="192.168.50.5",
  147. state=ModelInstanceStateEnum.ERROR,
  148. ),
  149. ],
  150. ),
  151. ),
  152. "distributed_servers.subordinate_workers.0.worker_ip",
  153. "192.168.50.4",
  154. ModelInstance(
  155. distributed_servers=DistributedServers(
  156. subordinate_workers=[
  157. ModelInstanceSubordinateWorker(
  158. worker_ip="192.168.50.4",
  159. state=ModelInstanceStateEnum.RUNNING,
  160. ),
  161. ModelInstanceSubordinateWorker(
  162. worker_ip="192.168.50.5",
  163. state=ModelInstanceStateEnum.ERROR,
  164. ),
  165. ],
  166. ),
  167. ),
  168. ),
  169. # Complex object access: replace an item
  170. (
  171. ModelInstance(
  172. distributed_servers=DistributedServers(
  173. subordinate_workers=[
  174. ModelInstanceSubordinateWorker(
  175. worker_ip="192.168.50.3",
  176. state=ModelInstanceStateEnum.RUNNING,
  177. ),
  178. ModelInstanceSubordinateWorker(
  179. worker_ip="192.168.50.5",
  180. state=ModelInstanceStateEnum.ERROR,
  181. ),
  182. ],
  183. ),
  184. ),
  185. "distributed_servers.subordinate_workers.-1",
  186. ModelInstanceSubordinateWorker(
  187. worker_ip="192.168.50.4",
  188. ),
  189. ModelInstance(
  190. distributed_servers=DistributedServers(
  191. subordinate_workers=[
  192. ModelInstanceSubordinateWorker(
  193. worker_ip="192.168.50.3",
  194. state=ModelInstanceStateEnum.RUNNING,
  195. ),
  196. ModelInstanceSubordinateWorker(
  197. worker_ip="192.168.50.4",
  198. ),
  199. ],
  200. ),
  201. ),
  202. ),
  203. # Complex object access with non-existent path: insert new item
  204. (
  205. ModelInstance(
  206. distributed_servers=DistributedServers(
  207. subordinate_workers=[
  208. ModelInstanceSubordinateWorker(
  209. worker_ip="192.168.50.3",
  210. state=ModelInstanceStateEnum.RUNNING,
  211. ),
  212. ModelInstanceSubordinateWorker(
  213. worker_ip="192.168.50.5",
  214. state=ModelInstanceStateEnum.ERROR,
  215. ),
  216. ],
  217. ),
  218. ),
  219. "distributed_servers.subordinate_workers.2",
  220. ModelInstanceSubordinateWorker(
  221. worker_ip="192.168.50.4",
  222. ),
  223. ModelInstance(
  224. distributed_servers=DistributedServers(
  225. subordinate_workers=[
  226. ModelInstanceSubordinateWorker(
  227. worker_ip="192.168.50.3",
  228. state=ModelInstanceStateEnum.RUNNING,
  229. ),
  230. ModelInstanceSubordinateWorker(
  231. worker_ip="192.168.50.5",
  232. state=ModelInstanceStateEnum.ERROR,
  233. ),
  234. ModelInstanceSubordinateWorker(
  235. worker_ip="192.168.50.4",
  236. ),
  237. ],
  238. ),
  239. ),
  240. ),
  241. # Complex object access with non-existent path: nothing to do
  242. (
  243. ModelInstance(
  244. distributed_servers=DistributedServers(
  245. subordinate_workers=[
  246. ModelInstanceSubordinateWorker(
  247. worker_ip="192.168.50.3",
  248. state=ModelInstanceStateEnum.RUNNING,
  249. ),
  250. ModelInstanceSubordinateWorker(
  251. worker_ip="192.168.50.5",
  252. state=ModelInstanceStateEnum.ERROR,
  253. ),
  254. ],
  255. ),
  256. ),
  257. "distributed_servers.subordinate_workers.0.name",
  258. "test",
  259. ModelInstance(
  260. distributed_servers=DistributedServers(
  261. subordinate_workers=[
  262. ModelInstanceSubordinateWorker(
  263. worker_ip="192.168.50.3",
  264. state=ModelInstanceStateEnum.RUNNING,
  265. ),
  266. ModelInstanceSubordinateWorker(
  267. worker_ip="192.168.50.5",
  268. state=ModelInstanceStateEnum.ERROR,
  269. ),
  270. ],
  271. ),
  272. ),
  273. ),
  274. ],
  275. )
  276. @pytest.mark.unit
  277. def test_set_attr(o, path, value, expected):
  278. set_attr(o, path, value)
  279. actual = o
  280. assert (
  281. actual == expected
  282. ), f"Expected {expected} but got {actual} for path '{path}' in object {o}"