# test_database.py
"""
Unit tests for database initialization and connection management.

Covers table creation and schema, foreign-key / cascade-delete behaviour,
and commit/rollback handling of database connections.
Requirements: 8.1, 9.1, 9.2
"""
import os
import sqlite3
import sys
import tempfile
from contextlib import closing, contextmanager, suppress

import pytest

# Import database functions. The directory above this file is put on sys.path
# so that the project-level ``database`` module resolves when pytest runs.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from database import init_database, get_db_connection, get_db, db_config
  15. @pytest.fixture
  16. def temp_db():
  17. """
  18. Fixture that creates a temporary database for testing.
  19. Ensures each test runs with a clean database.
  20. """
  21. # Create a temporary database file
  22. fd, temp_db_path = tempfile.mkstemp(suffix='.db')
  23. os.close(fd)
  24. # Save original path
  25. original_db_path = db_config.db_path
  26. # Set the database path for testing
  27. db_config.db_path = temp_db_path
  28. yield temp_db_path
  29. # Cleanup: restore original path and remove temp file
  30. db_config.db_path = original_db_path
  31. try:
  32. os.unlink(temp_db_path)
  33. except OSError:
  34. pass
  35. class TestDatabaseInitialization:
  36. """Test suite for database initialization functionality."""
  37. def test_init_database_creates_tables(self, temp_db):
  38. """
  39. Test that init_database creates all required tables.
  40. Validates: Requirements 9.1
  41. """
  42. # Initialize the database
  43. init_database()
  44. # Connect to the database and check tables exist
  45. conn = sqlite3.connect(temp_db)
  46. cursor = conn.cursor()
  47. # Query for all tables
  48. cursor.execute("""
  49. SELECT name FROM sqlite_master
  50. WHERE type='table'
  51. ORDER BY name
  52. """)
  53. tables = [row[0] for row in cursor.fetchall()]
  54. # Verify all required tables exist
  55. assert 'projects' in tables, "projects table should exist"
  56. assert 'tasks' in tables, "tasks table should exist"
  57. assert 'annotations' in tables, "annotations table should exist"
  58. assert 'users' in tables, "users table should exist"
  59. conn.close()
  60. def test_projects_table_schema(self, temp_db):
  61. """
  62. Test that projects table has correct schema.
  63. Validates: Requirements 9.1
  64. """
  65. init_database()
  66. conn = sqlite3.connect(temp_db)
  67. cursor = conn.cursor()
  68. # Get table schema
  69. cursor.execute("PRAGMA table_info(projects)")
  70. columns = {row[1]: row[2] for row in cursor.fetchall()}
  71. # Verify required columns exist with correct types
  72. assert 'id' in columns, "projects should have id column"
  73. assert 'name' in columns, "projects should have name column"
  74. assert 'description' in columns, "projects should have description column"
  75. assert 'config' in columns, "projects should have config column"
  76. assert 'created_at' in columns, "projects should have created_at column"
  77. # Verify id is primary key
  78. cursor.execute("""
  79. SELECT sql FROM sqlite_master
  80. WHERE type='table' AND name='projects'
  81. """)
  82. schema = cursor.fetchone()[0]
  83. assert 'PRIMARY KEY' in schema, "id should be primary key"
  84. conn.close()
  85. def test_tasks_table_schema(self, temp_db):
  86. """
  87. Test that tasks table has correct schema.
  88. Validates: Requirements 9.1
  89. """
  90. init_database()
  91. conn = sqlite3.connect(temp_db)
  92. cursor = conn.cursor()
  93. # Get table schema
  94. cursor.execute("PRAGMA table_info(tasks)")
  95. columns = {row[1]: row[2] for row in cursor.fetchall()}
  96. # Verify required columns exist
  97. assert 'id' in columns, "tasks should have id column"
  98. assert 'project_id' in columns, "tasks should have project_id column"
  99. assert 'name' in columns, "tasks should have name column"
  100. assert 'data' in columns, "tasks should have data column"
  101. assert 'status' in columns, "tasks should have status column"
  102. assert 'assigned_to' in columns, "tasks should have assigned_to column"
  103. assert 'created_at' in columns, "tasks should have created_at column"
  104. # Verify foreign key constraint
  105. cursor.execute("PRAGMA foreign_key_list(tasks)")
  106. foreign_keys = cursor.fetchall()
  107. assert len(foreign_keys) > 0, "tasks should have foreign key constraint"
  108. assert foreign_keys[0][2] == 'projects', "foreign key should reference projects"
  109. conn.close()
  110. def test_annotations_table_schema(self, temp_db):
  111. """
  112. Test that annotations table has correct schema.
  113. Validates: Requirements 9.1
  114. """
  115. init_database()
  116. conn = sqlite3.connect(temp_db)
  117. cursor = conn.cursor()
  118. # Get table schema
  119. cursor.execute("PRAGMA table_info(annotations)")
  120. columns = {row[1]: row[2] for row in cursor.fetchall()}
  121. # Verify required columns exist
  122. assert 'id' in columns, "annotations should have id column"
  123. assert 'task_id' in columns, "annotations should have task_id column"
  124. assert 'user_id' in columns, "annotations should have user_id column"
  125. assert 'result' in columns, "annotations should have result column"
  126. assert 'created_at' in columns, "annotations should have created_at column"
  127. assert 'updated_at' in columns, "annotations should have updated_at column"
  128. # Verify foreign key constraint
  129. cursor.execute("PRAGMA foreign_key_list(annotations)")
  130. foreign_keys = cursor.fetchall()
  131. assert len(foreign_keys) > 0, "annotations should have foreign key constraint"
  132. assert foreign_keys[0][2] == 'tasks', "foreign key should reference tasks"
  133. conn.close()
  134. def test_export_jobs_table_schema(self, temp_db):
  135. """
  136. Test that export_jobs table has correct schema.
  137. Validates: Requirements 8.1
  138. """
  139. init_database()
  140. conn = sqlite3.connect(temp_db)
  141. cursor = conn.cursor()
  142. # Get table schema
  143. cursor.execute("PRAGMA table_info(export_jobs)")
  144. columns = {row[1]: row[2] for row in cursor.fetchall()}
  145. # Verify required columns exist
  146. assert 'id' in columns, "export_jobs should have id column"
  147. assert 'project_id' in columns, "export_jobs should have project_id column"
  148. assert 'format' in columns, "export_jobs should have format column"
  149. assert 'status' in columns, "export_jobs should have status column"
  150. assert 'file_path' in columns, "export_jobs should have file_path column"
  151. conn.close()
  152. def test_foreign_key_constraints_enabled(self, temp_db):
  153. """
  154. Test that foreign key constraints are enabled.
  155. Validates: Requirements 9.1
  156. """
  157. init_database()
  158. conn = sqlite3.connect(temp_db)
  159. cursor = conn.cursor()
  160. # Check if foreign keys are enabled
  161. cursor.execute("PRAGMA foreign_keys")
  162. result = cursor.fetchone()
  163. # Note: Foreign keys need to be enabled per connection
  164. # The init_database function enables them, but we need to enable for this connection too
  165. cursor.execute("PRAGMA foreign_keys = ON")
  166. cursor.execute("PRAGMA foreign_keys")
  167. result = cursor.fetchone()
  168. assert result[0] == 1, "foreign key constraints should be enabled"
  169. conn.close()
  170. def test_init_database_idempotent(self, temp_db):
  171. """
  172. Test that calling init_database multiple times is safe.
  173. Validates: Requirements 9.1
  174. """
  175. # Initialize database multiple times
  176. init_database()
  177. init_database()
  178. init_database()
  179. # Verify tables still exist and are not duplicated
  180. conn = sqlite3.connect(temp_db)
  181. cursor = conn.cursor()
  182. cursor.execute("""
  183. SELECT name FROM sqlite_master
  184. WHERE type='table'
  185. ORDER BY name
  186. """)
  187. tables = [row[0] for row in cursor.fetchall()]
  188. # Should have the expected tables
  189. assert 'projects' in tables
  190. assert 'tasks' in tables
  191. assert 'annotations' in tables
  192. assert 'users' in tables
  193. assert 'export_jobs' in tables
  194. conn.close()
  195. class TestConnectionManagement:
  196. """Test suite for database connection management."""
  197. def test_get_db_connection_context_manager(self, temp_db):
  198. """
  199. Test that get_db_connection works as context manager.
  200. Validates: Requirements 9.2
  201. """
  202. init_database()
  203. # Use context manager
  204. with get_db_connection() as conn:
  205. assert conn is not None, "connection should not be None"
  206. # Verify we can execute queries
  207. cursor = conn.cursor()
  208. cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
  209. tables = cursor.fetchall()
  210. assert len(tables) > 0, "should be able to query tables"
  211. def test_get_db_connection_commits_on_success(self, temp_db):
  212. """
  213. Test that get_db_connection commits changes on success.
  214. Validates: Requirements 9.2
  215. """
  216. init_database()
  217. # Insert data using context manager
  218. with get_db_connection() as conn:
  219. cursor = conn.cursor()
  220. cursor.execute("""
  221. INSERT INTO projects (id, name, description, config)
  222. VALUES (?, ?, ?, ?)
  223. """, ('test-1', 'Test Project', 'Description', '{}'))
  224. # Verify data was committed
  225. conn = sqlite3.connect(temp_db)
  226. cursor = conn.cursor()
  227. cursor.execute("SELECT * FROM projects WHERE id='test-1'")
  228. result = cursor.fetchone()
  229. assert result is not None, "data should be committed"
  230. conn.close()
  231. def test_get_db_connection_rolls_back_on_error(self, temp_db):
  232. """
  233. Test that get_db_connection rolls back on error.
  234. Validates: Requirements 9.2
  235. """
  236. init_database()
  237. # Try to insert data and raise an error
  238. try:
  239. with get_db_connection() as conn:
  240. cursor = conn.cursor()
  241. cursor.execute("""
  242. INSERT INTO projects (id, name, description, config)
  243. VALUES (?, ?, ?, ?)
  244. """, ('test-2', 'Test Project', 'Description', '{}'))
  245. # Raise an error before commit
  246. raise ValueError("Test error")
  247. except ValueError:
  248. pass
  249. # Verify data was rolled back
  250. conn = sqlite3.connect(temp_db)
  251. cursor = conn.cursor()
  252. cursor.execute("SELECT * FROM projects WHERE id='test-2'")
  253. result = cursor.fetchone()
  254. assert result is None, "data should be rolled back on error"
  255. conn.close()
  256. def test_get_db_function(self, temp_db):
  257. """
  258. Test that get_db function returns a valid connection.
  259. Validates: Requirements 9.2
  260. """
  261. init_database()
  262. # Get connection
  263. conn = get_db()
  264. try:
  265. assert conn is not None, "connection should not be None"
  266. # Verify we can execute queries
  267. cursor = conn.cursor()
  268. cursor.execute("""
  269. INSERT INTO projects (id, name, description, config)
  270. VALUES (?, ?, ?, ?)
  271. """, ('test-4', 'Test Project', 'Description', '{}'))
  272. conn.commit()
  273. cursor.execute("SELECT * FROM projects WHERE id=?", ('test-4',))
  274. row = cursor.fetchone()
  275. assert row is not None, "should be able to query data"
  276. finally:
  277. # Caller is responsible for closing
  278. conn.close()
  279. class TestDatabaseIntegrity:
  280. """Test suite for database integrity and constraints."""
  281. def test_cascade_delete_tasks_on_project_delete(self, temp_db):
  282. """
  283. Test that deleting a project cascades to delete tasks.
  284. Validates: Requirements 9.1
  285. """
  286. init_database()
  287. with get_db_connection() as conn:
  288. cursor = conn.cursor()
  289. # Insert project
  290. cursor.execute("""
  291. INSERT INTO projects (id, name, description, config)
  292. VALUES (?, ?, ?, ?)
  293. """, ('proj-1', 'Project 1', 'Description', '{}'))
  294. # Insert task
  295. cursor.execute("""
  296. INSERT INTO tasks (id, project_id, name, data, status)
  297. VALUES (?, ?, ?, ?, ?)
  298. """, ('task-1', 'proj-1', 'Task 1', '{}', 'pending'))
  299. # Verify task exists
  300. with get_db_connection() as conn:
  301. cursor = conn.cursor()
  302. cursor.execute("SELECT * FROM tasks WHERE id=?", ('task-1',))
  303. assert cursor.fetchone() is not None, "task should exist"
  304. # Delete project
  305. with get_db_connection() as conn:
  306. cursor = conn.cursor()
  307. cursor.execute("DELETE FROM projects WHERE id=?", ('proj-1',))
  308. # Verify task was deleted
  309. with get_db_connection() as conn:
  310. cursor = conn.cursor()
  311. cursor.execute("SELECT * FROM tasks WHERE id=?", ('task-1',))
  312. assert cursor.fetchone() is None, "task should be deleted with project"
  313. def test_cascade_delete_annotations_on_task_delete(self, temp_db):
  314. """
  315. Test that deleting a task cascades to delete annotations.
  316. Validates: Requirements 9.1
  317. """
  318. init_database()
  319. with get_db_connection() as conn:
  320. cursor = conn.cursor()
  321. # Insert project
  322. cursor.execute("""
  323. INSERT INTO projects (id, name, description, config)
  324. VALUES (?, ?, ?, ?)
  325. """, ('proj-2', 'Project 2', 'Description', '{}'))
  326. # Insert task
  327. cursor.execute("""
  328. INSERT INTO tasks (id, project_id, name, data, status)
  329. VALUES (?, ?, ?, ?, ?)
  330. """, ('task-2', 'proj-2', 'Task 2', '{}', 'pending'))
  331. # Insert annotation
  332. cursor.execute("""
  333. INSERT INTO annotations (id, task_id, user_id, result)
  334. VALUES (?, ?, ?, ?)
  335. """, ('ann-1', 'task-2', 'user-1', '{}'))
  336. # Verify annotation exists
  337. with get_db_connection() as conn:
  338. cursor = conn.cursor()
  339. cursor.execute("SELECT * FROM annotations WHERE id=?", ('ann-1',))
  340. assert cursor.fetchone() is not None, "annotation should exist"
  341. # Delete task
  342. with get_db_connection() as conn:
  343. cursor = conn.cursor()
  344. cursor.execute("DELETE FROM tasks WHERE id=?", ('task-2',))
  345. # Verify annotation was deleted
  346. with get_db_connection() as conn:
  347. cursor = conn.cursor()
  348. cursor.execute("SELECT * FROM annotations WHERE id=?", ('ann-1',))
  349. assert cursor.fetchone() is None, "annotation should be deleted with task"