test_annotation_api.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. """
  2. Unit tests for Annotation API endpoints.
  3. Tests CRUD operations for annotations.
  4. """
  5. import pytest
  6. import os
  7. import json
  8. from fastapi.testclient import TestClient
  9. # Use a test database
  10. TEST_DB_PATH = "test_annotation_annotation_platform.db"
  11. @pytest.fixture(scope="function", autouse=True)
  12. def setup_test_db():
  13. """Setup test database before each test and cleanup after."""
  14. # Set test database path
  15. original_db_path = os.environ.get("DATABASE_PATH")
  16. os.environ["DATABASE_PATH"] = TEST_DB_PATH
  17. # Remove existing test database
  18. if os.path.exists(TEST_DB_PATH):
  19. os.remove(TEST_DB_PATH)
  20. # Import after setting env var
  21. from database import init_database
  22. init_database()
  23. yield
  24. # Cleanup
  25. if os.path.exists(TEST_DB_PATH):
  26. os.remove(TEST_DB_PATH)
  27. # Restore original path
  28. if original_db_path:
  29. os.environ["DATABASE_PATH"] = original_db_path
  30. elif "DATABASE_PATH" in os.environ:
  31. del os.environ["DATABASE_PATH"]
  32. @pytest.fixture(scope="function")
  33. def test_client():
  34. """Create a test client."""
  35. from main import app
  36. return TestClient(app)
  37. @pytest.fixture(scope="function")
  38. def sample_project(test_client):
  39. """Create a sample project for testing."""
  40. project_data = {
  41. "name": "Test Project",
  42. "description": "Test Description",
  43. "config": "<View><Image name='img' value='$image'/></View>"
  44. }
  45. response = test_client.post("/api/projects", json=project_data)
  46. return response.json()
  47. @pytest.fixture(scope="function")
  48. def sample_task(test_client, sample_project):
  49. """Create a sample task for testing."""
  50. task_data = {
  51. "project_id": sample_project["id"],
  52. "name": "Test Task",
  53. "data": {"image_url": "https://example.com/image.jpg"},
  54. "assigned_to": "user_001"
  55. }
  56. response = test_client.post("/api/tasks", json=task_data)
  57. return response.json()
  58. def test_list_annotations_empty(test_client):
  59. """Test listing annotations when database is empty."""
  60. response = test_client.get("/api/annotations")
  61. assert response.status_code == 200
  62. assert response.json() == []
  63. def test_create_annotation(test_client, sample_task):
  64. """Test creating a new annotation."""
  65. annotation_data = {
  66. "task_id": sample_task["id"],
  67. "user_id": "user_001",
  68. "result": {
  69. "annotations": [
  70. {
  71. "id": "ann_1",
  72. "type": "rectanglelabels",
  73. "value": {
  74. "x": 10,
  75. "y": 20,
  76. "width": 100,
  77. "height": 50,
  78. "rectanglelabels": ["Cat"]
  79. }
  80. }
  81. ]
  82. }
  83. }
  84. response = test_client.post("/api/annotations", json=annotation_data)
  85. assert response.status_code == 201
  86. data = response.json()
  87. assert data["task_id"] == annotation_data["task_id"]
  88. assert data["user_id"] == annotation_data["user_id"]
  89. assert data["result"] == annotation_data["result"]
  90. assert "id" in data
  91. assert data["id"].startswith("ann_")
  92. assert "created_at" in data
  93. assert "updated_at" in data
  94. def test_create_annotation_invalid_task(test_client):
  95. """Test creating an annotation with invalid task_id fails."""
  96. annotation_data = {
  97. "task_id": "nonexistent_task",
  98. "user_id": "user_001",
  99. "result": {"annotations": []}
  100. }
  101. response = test_client.post("/api/annotations", json=annotation_data)
  102. assert response.status_code == 404
  103. assert "not found" in response.json()["detail"].lower()
  104. def test_get_annotation(test_client, sample_task):
  105. """Test getting an annotation by ID."""
  106. # Create an annotation first
  107. annotation_data = {
  108. "task_id": sample_task["id"],
  109. "user_id": "user_001",
  110. "result": {"annotations": [{"id": "ann_1", "value": "test"}]}
  111. }
  112. create_response = test_client.post("/api/annotations", json=annotation_data)
  113. annotation_id = create_response.json()["id"]
  114. # Get the annotation
  115. response = test_client.get(f"/api/annotations/{annotation_id}")
  116. assert response.status_code == 200
  117. data = response.json()
  118. assert data["id"] == annotation_id
  119. assert data["task_id"] == annotation_data["task_id"]
  120. assert data["user_id"] == annotation_data["user_id"]
  121. def test_get_annotation_not_found(test_client):
  122. """Test getting a non-existent annotation returns 404."""
  123. response = test_client.get("/api/annotations/nonexistent_id")
  124. assert response.status_code == 404
  125. assert "not found" in response.json()["detail"].lower()
  126. def test_update_annotation(test_client, sample_task):
  127. """Test updating an annotation."""
  128. # Create an annotation first
  129. annotation_data = {
  130. "task_id": sample_task["id"],
  131. "user_id": "user_001",
  132. "result": {"annotations": [{"id": "ann_1", "value": "original"}]}
  133. }
  134. create_response = test_client.post("/api/annotations", json=annotation_data)
  135. annotation_id = create_response.json()["id"]
  136. # Update the annotation
  137. update_data = {
  138. "result": {"annotations": [{"id": "ann_1", "value": "updated"}]}
  139. }
  140. response = test_client.put(f"/api/annotations/{annotation_id}", json=update_data)
  141. assert response.status_code == 200
  142. data = response.json()
  143. assert data["result"] == update_data["result"]
  144. assert data["task_id"] == annotation_data["task_id"] # Task ID unchanged
  145. assert data["user_id"] == annotation_data["user_id"] # User ID unchanged
  146. def test_update_annotation_not_found(test_client):
  147. """Test updating a non-existent annotation returns 404."""
  148. update_data = {"result": {"annotations": []}}
  149. response = test_client.put("/api/annotations/nonexistent_id", json=update_data)
  150. assert response.status_code == 404
  151. def test_list_annotations_after_creation(test_client, sample_task):
  152. """Test listing annotations after creating some."""
  153. # Create multiple annotations
  154. for i in range(3):
  155. annotation_data = {
  156. "task_id": sample_task["id"],
  157. "user_id": f"user_{i:03d}",
  158. "result": {"annotations": [{"id": f"ann_{i}", "value": f"test_{i}"}]}
  159. }
  160. test_client.post("/api/annotations", json=annotation_data)
  161. # List annotations
  162. response = test_client.get("/api/annotations")
  163. assert response.status_code == 200
  164. data = response.json()
  165. assert len(data) == 3
  166. assert all("id" in annotation for annotation in data)
  167. assert all("created_at" in annotation for annotation in data)
  168. assert all("updated_at" in annotation for annotation in data)
  169. def test_list_annotations_filter_by_task(test_client, sample_project):
  170. """Test filtering annotations by task_id."""
  171. # Create two tasks
  172. task1_data = {
  173. "project_id": sample_project["id"],
  174. "name": "Task 1",
  175. "data": {"image_url": "https://example.com/image1.jpg"}
  176. }
  177. task1_response = test_client.post("/api/tasks", json=task1_data)
  178. task1 = task1_response.json()
  179. task2_data = {
  180. "project_id": sample_project["id"],
  181. "name": "Task 2",
  182. "data": {"image_url": "https://example.com/image2.jpg"}
  183. }
  184. task2_response = test_client.post("/api/tasks", json=task2_data)
  185. task2 = task2_response.json()
  186. # Create annotations for both tasks
  187. ann1_data = {
  188. "task_id": task1["id"],
  189. "user_id": "user_001",
  190. "result": {"annotations": [{"id": "ann_1"}]}
  191. }
  192. test_client.post("/api/annotations", json=ann1_data)
  193. ann2_data = {
  194. "task_id": task2["id"],
  195. "user_id": "user_001",
  196. "result": {"annotations": [{"id": "ann_2"}]}
  197. }
  198. test_client.post("/api/annotations", json=ann2_data)
  199. # Filter by first task
  200. response = test_client.get(f"/api/annotations?task_id={task1['id']}")
  201. assert response.status_code == 200
  202. data = response.json()
  203. assert len(data) == 1
  204. assert data[0]["task_id"] == task1["id"]
  205. def test_list_annotations_filter_by_user(test_client, sample_task):
  206. """Test filtering annotations by user_id."""
  207. # Create annotations for different users
  208. ann1_data = {
  209. "task_id": sample_task["id"],
  210. "user_id": "user_001",
  211. "result": {"annotations": [{"id": "ann_1"}]}
  212. }
  213. test_client.post("/api/annotations", json=ann1_data)
  214. ann2_data = {
  215. "task_id": sample_task["id"],
  216. "user_id": "user_002",
  217. "result": {"annotations": [{"id": "ann_2"}]}
  218. }
  219. test_client.post("/api/annotations", json=ann2_data)
  220. # Filter by first user
  221. response = test_client.get("/api/annotations?user_id=user_001")
  222. assert response.status_code == 200
  223. data = response.json()
  224. assert len(data) == 1
  225. assert data[0]["user_id"] == "user_001"
  226. def test_get_task_annotations(test_client, sample_task):
  227. """Test getting all annotations for a specific task."""
  228. # Create annotations for the task
  229. for i in range(2):
  230. annotation_data = {
  231. "task_id": sample_task["id"],
  232. "user_id": f"user_{i:03d}",
  233. "result": {"annotations": [{"id": f"ann_{i}"}]}
  234. }
  235. test_client.post("/api/annotations", json=annotation_data)
  236. # Get task annotations using the alternative endpoint
  237. response = test_client.get(f"/api/annotations/tasks/{sample_task['id']}/annotations")
  238. assert response.status_code == 200
  239. data = response.json()
  240. assert len(data) == 2
  241. assert all(annotation["task_id"] == sample_task["id"] for annotation in data)
  242. def test_get_task_annotations_not_found(test_client):
  243. """Test getting annotations for non-existent task returns 404."""
  244. response = test_client.get("/api/annotations/tasks/nonexistent_id/annotations")
  245. assert response.status_code == 404
  246. def test_annotation_json_serialization(test_client, sample_task):
  247. """Test that complex JSON data is properly serialized and deserialized."""
  248. complex_result = {
  249. "annotations": [
  250. {
  251. "id": "ann_1",
  252. "type": "rectanglelabels",
  253. "value": {
  254. "x": 10.5,
  255. "y": 20.3,
  256. "width": 100,
  257. "height": 50,
  258. "rectanglelabels": ["Cat", "Animal"]
  259. }
  260. },
  261. {
  262. "id": "ann_2",
  263. "type": "choices",
  264. "value": {
  265. "choices": ["Option A"]
  266. }
  267. }
  268. ],
  269. "metadata": {
  270. "duration": 120,
  271. "quality": "high"
  272. }
  273. }
  274. annotation_data = {
  275. "task_id": sample_task["id"],
  276. "user_id": "user_001",
  277. "result": complex_result
  278. }
  279. # Create annotation
  280. create_response = test_client.post("/api/annotations", json=annotation_data)
  281. assert create_response.status_code == 201
  282. annotation_id = create_response.json()["id"]
  283. # Get annotation and verify data integrity
  284. get_response = test_client.get(f"/api/annotations/{annotation_id}")
  285. assert get_response.status_code == 200
  286. data = get_response.json()
  287. assert data["result"] == complex_result
  288. assert data["result"]["annotations"][0]["value"]["x"] == 10.5
  289. assert data["result"]["metadata"]["duration"] == 120