test_task_api.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
"""
Unit tests for the Task API endpoints.

Covers task CRUD operations, list filtering by project and status, and the
per-project task listing endpoint.
"""
  5. import pytest
  6. import os
  7. import json
  8. from fastapi.testclient import TestClient
  9. # Use a test database
  10. TEST_DB_PATH = "test_task_annotation_platform.db"
  11. @pytest.fixture(scope="function", autouse=True)
  12. def setup_test_db():
  13. """Setup test database before each test and cleanup after."""
  14. # Set test database path
  15. original_db_path = os.environ.get("DATABASE_PATH")
  16. os.environ["DATABASE_PATH"] = TEST_DB_PATH
  17. # Remove existing test database
  18. if os.path.exists(TEST_DB_PATH):
  19. os.remove(TEST_DB_PATH)
  20. # Import after setting env var
  21. from database import init_database
  22. init_database()
  23. yield
  24. # Cleanup
  25. if os.path.exists(TEST_DB_PATH):
  26. os.remove(TEST_DB_PATH)
  27. # Restore original path
  28. if original_db_path:
  29. os.environ["DATABASE_PATH"] = original_db_path
  30. elif "DATABASE_PATH" in os.environ:
  31. del os.environ["DATABASE_PATH"]
  32. @pytest.fixture(scope="function")
  33. def test_client():
  34. """Create a test client."""
  35. from main import app
  36. return TestClient(app)
  37. @pytest.fixture(scope="function")
  38. def sample_project(test_client):
  39. """Create a sample project for testing."""
  40. project_data = {
  41. "name": "Test Project",
  42. "description": "Test Description",
  43. "config": "<View><Image name='img' value='$image'/></View>"
  44. }
  45. response = test_client.post("/api/projects", json=project_data)
  46. return response.json()
  47. def test_list_tasks_empty(test_client):
  48. """Test listing tasks when database is empty."""
  49. response = test_client.get("/api/tasks")
  50. assert response.status_code == 200
  51. assert response.json() == []
  52. def test_create_task(test_client, sample_project):
  53. """Test creating a new task."""
  54. task_data = {
  55. "project_id": sample_project["id"],
  56. "name": "Test Task",
  57. "data": {"image_url": "https://example.com/image.jpg"},
  58. "assigned_to": "user_001"
  59. }
  60. response = test_client.post("/api/tasks", json=task_data)
  61. assert response.status_code == 201
  62. data = response.json()
  63. assert data["name"] == task_data["name"]
  64. assert data["project_id"] == task_data["project_id"]
  65. assert data["data"] == task_data["data"]
  66. assert data["assigned_to"] == task_data["assigned_to"]
  67. assert data["status"] == "pending"
  68. assert "id" in data
  69. assert data["id"].startswith("task_")
  70. assert data["progress"] == 0.0
  71. assert "created_at" in data
  72. def test_create_task_invalid_project(test_client):
  73. """Test creating a task with invalid project_id fails."""
  74. task_data = {
  75. "project_id": "nonexistent_project",
  76. "name": "Test Task",
  77. "data": {"image_url": "https://example.com/image.jpg"}
  78. }
  79. response = test_client.post("/api/tasks", json=task_data)
  80. assert response.status_code == 404
  81. assert "not found" in response.json()["detail"].lower()
  82. def test_get_task(test_client, sample_project):
  83. """Test getting a task by ID."""
  84. # Create a task first
  85. task_data = {
  86. "project_id": sample_project["id"],
  87. "name": "Test Task",
  88. "data": {"image_url": "https://example.com/image.jpg"}
  89. }
  90. create_response = test_client.post("/api/tasks", json=task_data)
  91. task_id = create_response.json()["id"]
  92. # Get the task
  93. response = test_client.get(f"/api/tasks/{task_id}")
  94. assert response.status_code == 200
  95. data = response.json()
  96. assert data["id"] == task_id
  97. assert data["name"] == task_data["name"]
  98. def test_get_task_not_found(test_client):
  99. """Test getting a non-existent task returns 404."""
  100. response = test_client.get("/api/tasks/nonexistent_id")
  101. assert response.status_code == 404
  102. assert "not found" in response.json()["detail"].lower()
  103. def test_update_task(test_client, sample_project):
  104. """Test updating a task."""
  105. # Create a task first
  106. task_data = {
  107. "project_id": sample_project["id"],
  108. "name": "Original Name",
  109. "data": {"image_url": "https://example.com/image.jpg"}
  110. }
  111. create_response = test_client.post("/api/tasks", json=task_data)
  112. task_id = create_response.json()["id"]
  113. # Update the task
  114. update_data = {
  115. "name": "Updated Name",
  116. "status": "in_progress"
  117. }
  118. response = test_client.put(f"/api/tasks/{task_id}", json=update_data)
  119. assert response.status_code == 200
  120. data = response.json()
  121. assert data["name"] == update_data["name"]
  122. assert data["status"] == update_data["status"]
  123. assert data["data"] == task_data["data"] # Data unchanged
  124. def test_update_task_not_found(test_client):
  125. """Test updating a non-existent task returns 404."""
  126. update_data = {"name": "Updated Name"}
  127. response = test_client.put("/api/tasks/nonexistent_id", json=update_data)
  128. assert response.status_code == 404
  129. def test_delete_task(test_client, sample_project):
  130. """Test deleting a task."""
  131. # Create a task first
  132. task_data = {
  133. "project_id": sample_project["id"],
  134. "name": "Test Task",
  135. "data": {"image_url": "https://example.com/image.jpg"}
  136. }
  137. create_response = test_client.post("/api/tasks", json=task_data)
  138. task_id = create_response.json()["id"]
  139. # Delete the task
  140. response = test_client.delete(f"/api/tasks/{task_id}")
  141. assert response.status_code == 204
  142. # Verify task is deleted
  143. get_response = test_client.get(f"/api/tasks/{task_id}")
  144. assert get_response.status_code == 404
  145. def test_delete_task_not_found(test_client):
  146. """Test deleting a non-existent task returns 404."""
  147. response = test_client.delete("/api/tasks/nonexistent_id")
  148. assert response.status_code == 404
  149. def test_list_tasks_after_creation(test_client, sample_project):
  150. """Test listing tasks after creating some."""
  151. # Create multiple tasks
  152. for i in range(3):
  153. task_data = {
  154. "project_id": sample_project["id"],
  155. "name": f"Task {i}",
  156. "data": {"image_url": f"https://example.com/image{i}.jpg"}
  157. }
  158. test_client.post("/api/tasks", json=task_data)
  159. # List tasks
  160. response = test_client.get("/api/tasks")
  161. assert response.status_code == 200
  162. data = response.json()
  163. assert len(data) == 3
  164. assert all("id" in task for task in data)
  165. assert all("progress" in task for task in data)
  166. def test_list_tasks_filter_by_project(test_client, sample_project):
  167. """Test filtering tasks by project_id."""
  168. # Create another project
  169. project2_data = {
  170. "name": "Project 2",
  171. "description": "Description 2",
  172. "config": "<View></View>"
  173. }
  174. project2_response = test_client.post("/api/projects", json=project2_data)
  175. project2 = project2_response.json()
  176. # Create tasks for both projects
  177. task1_data = {
  178. "project_id": sample_project["id"],
  179. "name": "Task 1",
  180. "data": {"image_url": "https://example.com/image1.jpg"}
  181. }
  182. test_client.post("/api/tasks", json=task1_data)
  183. task2_data = {
  184. "project_id": project2["id"],
  185. "name": "Task 2",
  186. "data": {"image_url": "https://example.com/image2.jpg"}
  187. }
  188. test_client.post("/api/tasks", json=task2_data)
  189. # Filter by first project
  190. response = test_client.get(f"/api/tasks?project_id={sample_project['id']}")
  191. assert response.status_code == 200
  192. data = response.json()
  193. assert len(data) == 1
  194. assert data[0]["project_id"] == sample_project["id"]
  195. def test_list_tasks_filter_by_status(test_client, sample_project):
  196. """Test filtering tasks by status."""
  197. # Create tasks with different statuses
  198. task1_data = {
  199. "project_id": sample_project["id"],
  200. "name": "Task 1",
  201. "data": {"image_url": "https://example.com/image1.jpg"}
  202. }
  203. response1 = test_client.post("/api/tasks", json=task1_data)
  204. task1_id = response1.json()["id"]
  205. task2_data = {
  206. "project_id": sample_project["id"],
  207. "name": "Task 2",
  208. "data": {"image_url": "https://example.com/image2.jpg"}
  209. }
  210. test_client.post("/api/tasks", json=task2_data)
  211. # Update first task status
  212. test_client.put(f"/api/tasks/{task1_id}", json={"status": "in_progress"})
  213. # Filter by status
  214. response = test_client.get("/api/tasks?status=in_progress")
  215. assert response.status_code == 200
  216. data = response.json()
  217. assert len(data) == 1
  218. assert data[0]["status"] == "in_progress"
  219. def test_get_project_tasks(test_client, sample_project):
  220. """Test getting all tasks for a specific project."""
  221. # Create tasks for the project
  222. for i in range(2):
  223. task_data = {
  224. "project_id": sample_project["id"],
  225. "name": f"Task {i}",
  226. "data": {"image_url": f"https://example.com/image{i}.jpg"}
  227. }
  228. test_client.post("/api/tasks", json=task_data)
  229. # Get project tasks using the alternative endpoint
  230. response = test_client.get(f"/api/tasks/projects/{sample_project['id']}/tasks")
  231. assert response.status_code == 200
  232. data = response.json()
  233. assert len(data) == 2
  234. assert all(task["project_id"] == sample_project["id"] for task in data)
  235. def test_get_project_tasks_not_found(test_client):
  236. """Test getting tasks for non-existent project returns 404."""
  237. response = test_client.get("/api/tasks/projects/nonexistent_id/tasks")
  238. assert response.status_code == 404