# test_database.py
"""
Unit tests for database initialization and connection management.

Tests database table creation and connection handling.
Requirements: 9.1, 9.2
"""
import os
import sqlite3
import sys
import tempfile
from contextlib import closing, contextmanager

import pytest

# Make the parent directory importable so the database module resolves,
# then import the database functions under test.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from database import init_database, get_db_connection, get_db, DB_PATH
  15. @pytest.fixture
  16. def temp_db():
  17. """
  18. Fixture that creates a temporary database for testing.
  19. Ensures each test runs with a clean database.
  20. """
  21. # Create a temporary database file
  22. fd, temp_db_path = tempfile.mkstemp(suffix='.db')
  23. os.close(fd)
  24. # Set the database path for testing
  25. original_db_path = os.environ.get('DATABASE_PATH')
  26. os.environ['DATABASE_PATH'] = temp_db_path
  27. # Force reload of database module to pick up new path
  28. import database
  29. database.DB_PATH = temp_db_path
  30. yield temp_db_path
  31. # Cleanup: restore original path and remove temp file
  32. if original_db_path:
  33. os.environ['DATABASE_PATH'] = original_db_path
  34. else:
  35. os.environ.pop('DATABASE_PATH', None)
  36. database.DB_PATH = original_db_path or "annotation_platform.db"
  37. try:
  38. os.unlink(temp_db_path)
  39. except OSError:
  40. pass
  41. class TestDatabaseInitialization:
  42. """Test suite for database initialization functionality."""
  43. def test_init_database_creates_tables(self, temp_db):
  44. """
  45. Test that init_database creates all required tables.
  46. Validates: Requirements 9.1
  47. """
  48. # Initialize the database
  49. init_database()
  50. # Connect to the database and check tables exist
  51. conn = sqlite3.connect(temp_db)
  52. cursor = conn.cursor()
  53. # Query for all tables
  54. cursor.execute("""
  55. SELECT name FROM sqlite_master
  56. WHERE type='table'
  57. ORDER BY name
  58. """)
  59. tables = [row[0] for row in cursor.fetchall()]
  60. # Verify all required tables exist
  61. assert 'projects' in tables, "projects table should exist"
  62. assert 'tasks' in tables, "tasks table should exist"
  63. assert 'annotations' in tables, "annotations table should exist"
  64. conn.close()
  65. def test_projects_table_schema(self, temp_db):
  66. """
  67. Test that projects table has correct schema.
  68. Validates: Requirements 9.1
  69. """
  70. init_database()
  71. conn = sqlite3.connect(temp_db)
  72. cursor = conn.cursor()
  73. # Get table schema
  74. cursor.execute("PRAGMA table_info(projects)")
  75. columns = {row[1]: row[2] for row in cursor.fetchall()}
  76. # Verify required columns exist with correct types
  77. assert 'id' in columns, "projects should have id column"
  78. assert 'name' in columns, "projects should have name column"
  79. assert 'description' in columns, "projects should have description column"
  80. assert 'config' in columns, "projects should have config column"
  81. assert 'created_at' in columns, "projects should have created_at column"
  82. # Verify id is primary key
  83. cursor.execute("""
  84. SELECT sql FROM sqlite_master
  85. WHERE type='table' AND name='projects'
  86. """)
  87. schema = cursor.fetchone()[0]
  88. assert 'PRIMARY KEY' in schema, "id should be primary key"
  89. conn.close()
  90. def test_tasks_table_schema(self, temp_db):
  91. """
  92. Test that tasks table has correct schema.
  93. Validates: Requirements 9.1
  94. """
  95. init_database()
  96. conn = sqlite3.connect(temp_db)
  97. cursor = conn.cursor()
  98. # Get table schema
  99. cursor.execute("PRAGMA table_info(tasks)")
  100. columns = {row[1]: row[2] for row in cursor.fetchall()}
  101. # Verify required columns exist
  102. assert 'id' in columns, "tasks should have id column"
  103. assert 'project_id' in columns, "tasks should have project_id column"
  104. assert 'name' in columns, "tasks should have name column"
  105. assert 'data' in columns, "tasks should have data column"
  106. assert 'status' in columns, "tasks should have status column"
  107. assert 'assigned_to' in columns, "tasks should have assigned_to column"
  108. assert 'created_at' in columns, "tasks should have created_at column"
  109. # Verify foreign key constraint
  110. cursor.execute("PRAGMA foreign_key_list(tasks)")
  111. foreign_keys = cursor.fetchall()
  112. assert len(foreign_keys) > 0, "tasks should have foreign key constraint"
  113. assert foreign_keys[0][2] == 'projects', "foreign key should reference projects"
  114. conn.close()
  115. def test_annotations_table_schema(self, temp_db):
  116. """
  117. Test that annotations table has correct schema.
  118. Validates: Requirements 9.1
  119. """
  120. init_database()
  121. conn = sqlite3.connect(temp_db)
  122. cursor = conn.cursor()
  123. # Get table schema
  124. cursor.execute("PRAGMA table_info(annotations)")
  125. columns = {row[1]: row[2] for row in cursor.fetchall()}
  126. # Verify required columns exist
  127. assert 'id' in columns, "annotations should have id column"
  128. assert 'task_id' in columns, "annotations should have task_id column"
  129. assert 'user_id' in columns, "annotations should have user_id column"
  130. assert 'result' in columns, "annotations should have result column"
  131. assert 'created_at' in columns, "annotations should have created_at column"
  132. assert 'updated_at' in columns, "annotations should have updated_at column"
  133. # Verify foreign key constraint
  134. cursor.execute("PRAGMA foreign_key_list(annotations)")
  135. foreign_keys = cursor.fetchall()
  136. assert len(foreign_keys) > 0, "annotations should have foreign key constraint"
  137. assert foreign_keys[0][2] == 'tasks', "foreign key should reference tasks"
  138. conn.close()
  139. def test_foreign_key_constraints_enabled(self, temp_db):
  140. """
  141. Test that foreign key constraints are enabled.
  142. Validates: Requirements 9.1
  143. """
  144. init_database()
  145. conn = sqlite3.connect(temp_db)
  146. cursor = conn.cursor()
  147. # Check if foreign keys are enabled
  148. cursor.execute("PRAGMA foreign_keys")
  149. result = cursor.fetchone()
  150. # Note: Foreign keys need to be enabled per connection
  151. # The init_database function enables them, but we need to enable for this connection too
  152. cursor.execute("PRAGMA foreign_keys = ON")
  153. cursor.execute("PRAGMA foreign_keys")
  154. result = cursor.fetchone()
  155. assert result[0] == 1, "foreign key constraints should be enabled"
  156. conn.close()
  157. def test_init_database_idempotent(self, temp_db):
  158. """
  159. Test that calling init_database multiple times is safe.
  160. Validates: Requirements 9.1
  161. """
  162. # Initialize database multiple times
  163. init_database()
  164. init_database()
  165. init_database()
  166. # Verify tables still exist and are not duplicated
  167. conn = sqlite3.connect(temp_db)
  168. cursor = conn.cursor()
  169. cursor.execute("""
  170. SELECT name FROM sqlite_master
  171. WHERE type='table'
  172. ORDER BY name
  173. """)
  174. tables = [row[0] for row in cursor.fetchall()]
  175. # Should have exactly 3 tables
  176. assert len(tables) == 3, "should have exactly 3 tables"
  177. assert 'projects' in tables
  178. assert 'tasks' in tables
  179. assert 'annotations' in tables
  180. conn.close()
  181. class TestConnectionManagement:
  182. """Test suite for database connection management."""
  183. def test_get_db_connection_context_manager(self, temp_db):
  184. """
  185. Test that get_db_connection works as context manager.
  186. Validates: Requirements 9.2
  187. """
  188. init_database()
  189. # Use context manager
  190. with get_db_connection() as conn:
  191. assert conn is not None, "connection should not be None"
  192. assert isinstance(conn, sqlite3.Connection), "should return Connection object"
  193. # Verify we can execute queries
  194. cursor = conn.cursor()
  195. cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
  196. tables = cursor.fetchall()
  197. assert len(tables) > 0, "should be able to query tables"
  198. def test_get_db_connection_commits_on_success(self, temp_db):
  199. """
  200. Test that get_db_connection commits changes on success.
  201. Validates: Requirements 9.2
  202. """
  203. init_database()
  204. # Insert data using context manager
  205. with get_db_connection() as conn:
  206. cursor = conn.cursor()
  207. cursor.execute("""
  208. INSERT INTO projects (id, name, description, config)
  209. VALUES ('test-1', 'Test Project', 'Description', '{}')
  210. """)
  211. # Verify data was committed
  212. conn = sqlite3.connect(temp_db)
  213. cursor = conn.cursor()
  214. cursor.execute("SELECT * FROM projects WHERE id='test-1'")
  215. result = cursor.fetchone()
  216. assert result is not None, "data should be committed"
  217. conn.close()
  218. def test_get_db_connection_rolls_back_on_error(self, temp_db):
  219. """
  220. Test that get_db_connection rolls back on error.
  221. Validates: Requirements 9.2
  222. """
  223. init_database()
  224. # Try to insert data and raise an error
  225. try:
  226. with get_db_connection() as conn:
  227. cursor = conn.cursor()
  228. cursor.execute("""
  229. INSERT INTO projects (id, name, description, config)
  230. VALUES ('test-2', 'Test Project', 'Description', '{}')
  231. """)
  232. # Raise an error before commit
  233. raise ValueError("Test error")
  234. except ValueError:
  235. pass
  236. # Verify data was rolled back
  237. conn = sqlite3.connect(temp_db)
  238. cursor = conn.cursor()
  239. cursor.execute("SELECT * FROM projects WHERE id='test-2'")
  240. result = cursor.fetchone()
  241. assert result is None, "data should be rolled back on error"
  242. conn.close()
  243. def test_get_db_connection_closes_connection(self, temp_db):
  244. """
  245. Test that get_db_connection closes connection after use.
  246. Validates: Requirements 9.2
  247. """
  248. init_database()
  249. conn_ref = None
  250. with get_db_connection() as conn:
  251. conn_ref = conn
  252. assert conn is not None
  253. # Try to use connection after context manager exits
  254. # This should fail because connection is closed
  255. with pytest.raises(sqlite3.ProgrammingError):
  256. cursor = conn_ref.cursor()
  257. cursor.execute("SELECT 1")
  258. def test_get_db_connection_enables_row_factory(self, temp_db):
  259. """
  260. Test that get_db_connection enables row factory for column access.
  261. Validates: Requirements 9.2
  262. """
  263. init_database()
  264. # Insert test data
  265. with get_db_connection() as conn:
  266. cursor = conn.cursor()
  267. cursor.execute("""
  268. INSERT INTO projects (id, name, description, config)
  269. VALUES ('test-3', 'Test Project', 'Description', '{}')
  270. """)
  271. # Query and verify row factory works
  272. with get_db_connection() as conn:
  273. cursor = conn.cursor()
  274. cursor.execute("SELECT * FROM projects WHERE id='test-3'")
  275. row = cursor.fetchone()
  276. # Should be able to access columns by name
  277. assert row['id'] == 'test-3', "should access column by name"
  278. assert row['name'] == 'Test Project', "should access column by name"
  279. def test_get_db_function(self, temp_db):
  280. """
  281. Test that get_db function returns a valid connection.
  282. Validates: Requirements 9.2
  283. """
  284. init_database()
  285. # Get connection
  286. conn = get_db()
  287. try:
  288. assert conn is not None, "connection should not be None"
  289. assert isinstance(conn, sqlite3.Connection), "should return Connection object"
  290. # Verify row factory is enabled
  291. cursor = conn.cursor()
  292. cursor.execute("""
  293. INSERT INTO projects (id, name, description, config)
  294. VALUES ('test-4', 'Test Project', 'Description', '{}')
  295. """)
  296. conn.commit()
  297. cursor.execute("SELECT * FROM projects WHERE id='test-4'")
  298. row = cursor.fetchone()
  299. assert row['id'] == 'test-4', "row factory should be enabled"
  300. finally:
  301. # Caller is responsible for closing
  302. conn.close()
  303. def test_get_db_enables_foreign_keys(self, temp_db):
  304. """
  305. Test that get_db enables foreign key constraints.
  306. Validates: Requirements 9.2
  307. """
  308. init_database()
  309. conn = get_db()
  310. try:
  311. cursor = conn.cursor()
  312. cursor.execute("PRAGMA foreign_keys")
  313. result = cursor.fetchone()
  314. assert result[0] == 1, "foreign keys should be enabled"
  315. finally:
  316. conn.close()
  317. class TestDatabaseIntegrity:
  318. """Test suite for database integrity and constraints."""
  319. def test_cascade_delete_tasks_on_project_delete(self, temp_db):
  320. """
  321. Test that deleting a project cascades to delete tasks.
  322. Validates: Requirements 9.1
  323. """
  324. init_database()
  325. with get_db_connection() as conn:
  326. cursor = conn.cursor()
  327. # Insert project
  328. cursor.execute("""
  329. INSERT INTO projects (id, name, description, config)
  330. VALUES ('proj-1', 'Project 1', 'Description', '{}')
  331. """)
  332. # Insert task
  333. cursor.execute("""
  334. INSERT INTO tasks (id, project_id, name, data, status)
  335. VALUES ('task-1', 'proj-1', 'Task 1', '{}', 'pending')
  336. """)
  337. # Verify task exists
  338. with get_db_connection() as conn:
  339. cursor = conn.cursor()
  340. cursor.execute("SELECT * FROM tasks WHERE id='task-1'")
  341. assert cursor.fetchone() is not None, "task should exist"
  342. # Delete project
  343. with get_db_connection() as conn:
  344. cursor = conn.cursor()
  345. cursor.execute("DELETE FROM projects WHERE id='proj-1'")
  346. # Verify task was deleted
  347. with get_db_connection() as conn:
  348. cursor = conn.cursor()
  349. cursor.execute("SELECT * FROM tasks WHERE id='task-1'")
  350. assert cursor.fetchone() is None, "task should be deleted with project"
  351. def test_cascade_delete_annotations_on_task_delete(self, temp_db):
  352. """
  353. Test that deleting a task cascades to delete annotations.
  354. Validates: Requirements 9.1
  355. """
  356. init_database()
  357. with get_db_connection() as conn:
  358. cursor = conn.cursor()
  359. # Insert project
  360. cursor.execute("""
  361. INSERT INTO projects (id, name, description, config)
  362. VALUES ('proj-2', 'Project 2', 'Description', '{}')
  363. """)
  364. # Insert task
  365. cursor.execute("""
  366. INSERT INTO tasks (id, project_id, name, data, status)
  367. VALUES ('task-2', 'proj-2', 'Task 2', '{}', 'pending')
  368. """)
  369. # Insert annotation
  370. cursor.execute("""
  371. INSERT INTO annotations (id, task_id, user_id, result)
  372. VALUES ('ann-1', 'task-2', 'user-1', '{}')
  373. """)
  374. # Verify annotation exists
  375. with get_db_connection() as conn:
  376. cursor = conn.cursor()
  377. cursor.execute("SELECT * FROM annotations WHERE id='ann-1'")
  378. assert cursor.fetchone() is not None, "annotation should exist"
  379. # Delete task
  380. with get_db_connection() as conn:
  381. cursor = conn.cursor()
  382. cursor.execute("DELETE FROM tasks WHERE id='task-2'")
  383. # Verify annotation was deleted
  384. with get_db_connection() as conn:
  385. cursor = conn.cursor()
  386. cursor.execute("SELECT * FROM annotations WHERE id='ann-1'")
  387. assert cursor.fetchone() is None, "annotation should be deleted with task"