test_cache.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414
  1. import asyncio
  2. import pytest
  3. from aiocache import Cache
  4. from gpustack.logging import setup_logging
  5. from gpustack.server.cache import (
  6. build_cache_key,
  7. delete_cache_by_key,
  8. locked_cached,
  9. class_key,
  10. cache as global_cache,
  11. )
  12. setup_logging()
  13. def make_cache():
  14. return Cache(Cache.MEMORY)
  15. # ---------------------------------------------------------------------------
  16. # build_cache_key
  17. # ---------------------------------------------------------------------------
  18. class TestBuildCacheKey:
  19. def test_positional_and_keyword_produce_same_key(self):
  20. async def my_func(name: str):
  21. pass
  22. assert build_cache_key(my_func, "foo") == build_cache_key(my_func, name="foo")
  23. def test_different_args_produce_different_keys(self):
  24. async def my_func(name: str):
  25. pass
  26. assert build_cache_key(my_func, "foo") != build_cache_key(my_func, "bar")
  27. def test_multiple_params_all_equivalent_forms(self):
  28. async def my_func(a: int, b: str):
  29. pass
  30. key_all_pos = build_cache_key(my_func, 1, "x")
  31. key_all_kw = build_cache_key(my_func, a=1, b="x")
  32. key_mixed = build_cache_key(my_func, 1, b="x")
  33. assert key_all_pos == key_all_kw == key_mixed
  34. def test_key_includes_function_qualname(self):
  35. async def my_func(x: int):
  36. pass
  37. assert "my_func" in build_cache_key(my_func, 42)
  38. def test_default_args_treated_as_explicit(self):
  39. async def my_func(x: int, y: int = 10):
  40. pass
  41. assert build_cache_key(my_func, 1) == build_cache_key(my_func, 1, 10)
  42. def test_unbound_method_strips_self(self):
  43. """Unbound method (with self in sig) called without self arg produces
  44. same key as bound method called with same arg."""
  45. class MyService:
  46. async def fetch(self, name: str):
  47. pass
  48. svc = MyService()
  49. key_unbound = build_cache_key(MyService.fetch, "foo")
  50. key_bound = build_cache_key(svc.fetch, "foo")
  51. assert key_unbound == key_bound
  52. def test_kwarg_ordering_is_stable(self):
  53. """Keys are stable regardless of the order keyword arguments are passed,
  54. because bound.arguments follows declaration order not caller order."""
  55. async def my_func(_a: int, _b: str, _c: float):
  56. pass
  57. key1 = build_cache_key(my_func, _a=1, _b="x", _c=3.0)
  58. key2 = build_cache_key(my_func, _c=3.0, _a=1, _b="x")
  59. key3 = build_cache_key(my_func, _b="x", _c=3.0, _a=1)
  60. assert key1 == key2 == key3
  61. def test_fallback_for_signature_mismatch(self):
  62. """When args don't match the function signature (e.g. manual key construction
  63. with extra args), fall back to old-style string concatenation without crashing.
  64. This covers the pre-existing ModelUsageService.update() call pattern."""
  65. async def my_func(_fields: dict):
  66. pass
  67. # Passing 3 args to a 1-param function triggers the fallback
  68. key = build_cache_key(my_func, 1, 2, 3)
  69. assert "my_func" in key
  70. def test_fallback_kwargs_are_sorted(self):
  71. """Fallback path (signature mismatch) sorts kwargs for stable keys."""
  72. async def my_func(_fields: dict):
  73. pass
  74. # Wrong kwarg names trigger the fallback; order should not matter
  75. key1 = build_cache_key(my_func, z=3, a=1, m=2)
  76. key2 = build_cache_key(my_func, a=1, m=2, z=3)
  77. assert key1 == key2
  78. # ---------------------------------------------------------------------------
  79. # locked_cached decorator
  80. # ---------------------------------------------------------------------------
  81. class TestLockedCached:
  82. @pytest.mark.asyncio
  83. async def test_result_is_cached_on_second_call(self):
  84. call_count = 0
  85. test_cache = make_cache()
  86. class MyService:
  87. @locked_cached(cache=test_cache)
  88. async def fetch(self, name: str):
  89. nonlocal call_count
  90. call_count += 1
  91. return f"result-{name}"
  92. svc = MyService()
  93. r1 = await svc.fetch("foo")
  94. r2 = await svc.fetch("foo")
  95. assert r1 == r2 == "result-foo"
  96. assert call_count == 1
  97. @pytest.mark.asyncio
  98. async def test_different_args_have_separate_cache_entries(self):
  99. call_count = 0
  100. test_cache = make_cache()
  101. class MyService:
  102. @locked_cached(cache=test_cache)
  103. async def fetch(self, name: str):
  104. nonlocal call_count
  105. call_count += 1
  106. return f"result-{name}"
  107. svc = MyService()
  108. await svc.fetch("foo")
  109. await svc.fetch("bar")
  110. assert call_count == 2
  111. @pytest.mark.asyncio
  112. async def test_none_result_is_not_cached(self):
  113. call_count = 0
  114. test_cache = make_cache()
  115. class MyService:
  116. @locked_cached(cache=test_cache)
  117. async def fetch(self, name: str):
  118. nonlocal call_count
  119. call_count += 1
  120. return None
  121. svc = MyService()
  122. await svc.fetch("foo")
  123. await svc.fetch("foo")
  124. assert call_count == 2
  125. @pytest.mark.asyncio
  126. async def test_positional_and_keyword_hit_same_cache_entry(self):
  127. """Regression: before the inspect.signature fix, keyword-arg calls generated
  128. a different cache key than positional-arg calls, so cache was never reused."""
  129. call_count = 0
  130. test_cache = make_cache()
  131. class MyService:
  132. @locked_cached(cache=test_cache)
  133. async def fetch(self, name: str):
  134. nonlocal call_count
  135. call_count += 1
  136. return f"result-{name}"
  137. svc = MyService()
  138. r1 = await svc.fetch("foo")
  139. r2 = await svc.fetch(name="foo")
  140. assert r1 == r2
  141. assert call_count == 1
  142. @pytest.mark.asyncio
  143. async def test_cache_key_matches_delete_cache_by_key(self):
  144. """Regression: deleting via positional arg must invalidate an entry that was
  145. populated via keyword arg (the bug in token.py before fix-5168)."""
  146. call_count = 0
  147. test_cache = make_cache()
  148. class MyService:
  149. @locked_cached(cache=test_cache)
  150. async def fetch(self, name: str):
  151. nonlocal call_count
  152. call_count += 1
  153. return f"result-{name}"
  154. svc = MyService()
  155. await svc.fetch(name="foo")
  156. assert call_count == 1
  157. # Simulate what services.py update()/delete() does: positional arg, bound method
  158. key = build_cache_key(svc.fetch, "foo")
  159. await test_cache.delete(key)
  160. await svc.fetch(name="foo")
  161. assert call_count == 2
  162. @pytest.mark.asyncio
  163. async def test_delete_cache_by_key_invalidates_entry(self):
  164. """delete_cache_by_key correctly evicts a cached result (uses global cache)."""
  165. call_count = 0
  166. class MyService:
  167. @locked_cached()
  168. async def fetch(self, name: str):
  169. nonlocal call_count
  170. call_count += 1
  171. return f"result-{name}"
  172. svc = MyService()
  173. await svc.fetch("foo")
  174. assert call_count == 1
  175. await delete_cache_by_key(svc.fetch, "foo")
  176. await svc.fetch("foo")
  177. assert call_count == 2
  178. @pytest.mark.asyncio
  179. async def test_delete_cache_by_key_positional_invalidates_keyword_call(self):
  180. """delete_cache_by_key with positional args invalidates entry cached via
  181. keyword args (the actual bug scenario in services.py update/delete)."""
  182. call_count = 0
  183. class MyService:
  184. @locked_cached()
  185. async def fetch(self, name: str):
  186. nonlocal call_count
  187. call_count += 1
  188. return f"result-{name}"
  189. svc = MyService()
  190. await svc.fetch(name="foo")
  191. assert call_count == 1
  192. await delete_cache_by_key(svc.fetch, "foo")
  193. await svc.fetch(name="foo")
  194. assert call_count == 2
  195. @pytest.mark.asyncio
  196. async def test_concurrent_calls_execute_function_once(self):
  197. call_count = 0
  198. test_cache = make_cache()
  199. class MyService:
  200. @locked_cached(cache=test_cache)
  201. async def fetch(self, name: str):
  202. nonlocal call_count
  203. call_count += 1
  204. await asyncio.sleep(0.05)
  205. return f"result-{name}"
  206. svc = MyService()
  207. results = await asyncio.gather(*[svc.fetch("foo") for _ in range(5)])
  208. assert all(r == "result-foo" for r in results)
  209. assert call_count == 1
  210. @pytest.mark.asyncio
  211. async def test_custom_static_key(self):
  212. call_count = 0
  213. test_cache = make_cache()
  214. class MyService:
  215. @locked_cached(cache=test_cache, key="fixed-key")
  216. async def fetch(self, name: str):
  217. nonlocal call_count
  218. call_count += 1
  219. return f"result-{name}"
  220. svc = MyService()
  221. await svc.fetch("foo")
  222. await svc.fetch("bar") # different arg, same fixed key → cache hit
  223. assert call_count == 1
  224. @pytest.mark.asyncio
  225. async def test_custom_callable_key(self):
  226. call_count = 0
  227. test_cache = make_cache()
  228. def my_key(f, *args, **kwargs):
  229. return f"custom:{args[1]}" # args[0] is self
  230. class MyService:
  231. @locked_cached(cache=test_cache, key=my_key)
  232. async def fetch(self, name: str):
  233. nonlocal call_count
  234. call_count += 1
  235. return f"result-{name}"
  236. svc = MyService()
  237. r1 = await svc.fetch("foo")
  238. r2 = await svc.fetch("foo")
  239. assert r1 == r2
  240. assert call_count == 1
  241. # ---------------------------------------------------------------------------
  242. # class_key helper
  243. # ---------------------------------------------------------------------------
  244. class TestClassKey:
  245. def test_key_format_is_classname_dot_suffix(self):
  246. async def dummy():
  247. pass
  248. kb = class_key("all_cached")
  249. class MyModel:
  250. pass
  251. assert kb(dummy, MyModel) == "MyModel.all_cached"
  252. def test_different_classes_produce_different_keys(self):
  253. async def dummy():
  254. pass
  255. kb = class_key("all_cached")
  256. class A:
  257. pass
  258. class B:
  259. pass
  260. assert kb(dummy, A) != kb(dummy, B)
  261. # ---------------------------------------------------------------------------
  262. # delete_cache_by_key
  263. # ---------------------------------------------------------------------------
  264. class TestDeleteCacheByKey:
  265. @pytest.mark.asyncio
  266. async def test_delete_by_explicit_key(self):
  267. await global_cache.set("my-key", "my-value")
  268. assert await global_cache.get("my-key") == "my-value"
  269. await delete_cache_by_key(_key="my-key")
  270. assert await global_cache.get("my-key") is None
  271. @pytest.mark.asyncio
  272. async def test_delete_nonexistent_key_is_safe(self):
  273. await delete_cache_by_key(_key="nonexistent-key")
  274. @pytest.mark.asyncio
  275. async def test_raises_if_neither_func_nor_key(self):
  276. with pytest.raises(ValueError):
  277. await delete_cache_by_key()
  278. @pytest.mark.asyncio
  279. async def test_delete_by_func_and_args(self):
  280. call_count = 0
  281. class MyService:
  282. @locked_cached()
  283. async def lookup(self, item_id: int):
  284. nonlocal call_count
  285. call_count += 1
  286. return f"item-{item_id}"
  287. svc = MyService()
  288. await svc.lookup(42)
  289. assert call_count == 1
  290. await delete_cache_by_key(svc.lookup, 42)
  291. await svc.lookup(42)
  292. assert call_count == 2
  293. @pytest.mark.asyncio
  294. async def test_delete_only_removes_matching_key(self):
  295. call_count = {"a": 0, "b": 0}
  296. class MyService:
  297. @locked_cached()
  298. async def lookup(self, name: str):
  299. call_count[name] += 1
  300. return f"result-{name}"
  301. svc = MyService()
  302. await svc.lookup("a")
  303. await svc.lookup("b")
  304. await delete_cache_by_key(svc.lookup, "a")
  305. await svc.lookup("a")
  306. await svc.lookup("b")
  307. assert call_count["a"] == 2
  308. assert call_count["b"] == 1