| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432 |
- """
- Unit tests for database initialization and connection management.
- Tests database table creation and connection handling.
- Requirements: 9.1, 9.2
- """
import os
import sqlite3
import sys
import tempfile
from contextlib import closing, contextmanager

import pytest

# Import database functions: put the directory above this test file on
# sys.path so the project's ``database`` module can be imported.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from database import init_database, get_db_connection, get_db, db_config  # noqa: E402
@pytest.fixture
def temp_db():
    """
    Run a test against a throwaway SQLite database file.

    A fresh temporary ``.db`` file is created and ``db_config.db_path`` is
    pointed at it for the duration of the test, so every test starts from an
    empty database file.

    Yields:
        str: Path of the temporary database file.

    Teardown restores the previous ``db_config.db_path`` and then deletes the
    temporary file, ignoring ``OSError`` if it cannot be removed.
    """
    descriptor, db_file = tempfile.mkstemp(suffix='.db')
    # Only the path is needed; release the OS-level handle mkstemp opened.
    os.close(descriptor)

    previous_db_path = db_config.db_path
    db_config.db_path = db_file

    yield db_file

    db_config.db_path = previous_db_path
    try:
        os.unlink(db_file)
    except OSError:
        # Best effort: a file that cannot be removed is left behind rather
        # than failing the test during teardown.
        pass
class TestDatabaseInitialization:
    """Test suite for database initialization functionality."""

    # Tables init_database() is expected to create. Shared by the creation and
    # idempotency tests so the two lists cannot drift apart.
    EXPECTED_TABLES = ('projects', 'tasks', 'annotations', 'users', 'export_jobs')

    @staticmethod
    def _table_names(conn):
        """Return the set of table names present in the database behind *conn*."""
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
        return {row[0] for row in cursor.fetchall()}

    @staticmethod
    def _columns(conn, table):
        """
        Return ``{column_name: table_info_row}`` for *table*.

        Each row is ``(cid, name, type, notnull, dflt_value, pk)`` as reported
        by ``PRAGMA table_info``. PRAGMA arguments cannot be bound as SQL
        parameters, so *table* is interpolated; callers pass fixed literals.
        """
        cursor = conn.cursor()
        cursor.execute(f"PRAGMA table_info({table})")
        return {row[1]: row for row in cursor.fetchall()}

    @staticmethod
    def _foreign_keys(conn, table):
        """
        Return ``{local_column: referenced_table}`` for *table*'s foreign keys.

        ``PRAGMA foreign_key_list`` rows are
        ``(id, seq, table, from, to, on_update, on_delete, match)``.
        """
        cursor = conn.cursor()
        cursor.execute(f"PRAGMA foreign_key_list({table})")
        return {row[3]: row[2] for row in cursor.fetchall()}

    def test_init_database_creates_tables(self, temp_db):
        """
        Test that init_database creates all required tables.
        Validates: Requirements 9.1
        """
        init_database()

        # closing() releases the connection even if an assertion fails;
        # sqlite3's own context manager only commits/rolls back.
        with closing(sqlite3.connect(temp_db)) as conn:
            tables = self._table_names(conn)

        for table in self.EXPECTED_TABLES:
            assert table in tables, f"{table} table should exist"

    def test_projects_table_schema(self, temp_db):
        """
        Test that projects table has the required columns and that ``id`` is
        its primary key.
        Validates: Requirements 9.1
        """
        init_database()

        with closing(sqlite3.connect(temp_db)) as conn:
            columns = self._columns(conn, 'projects')

        # Verify required columns exist (column types are not checked).
        for name in ('id', 'name', 'description', 'config', 'created_at'):
            assert name in columns, f"projects should have {name} column"

        # table_info's ``pk`` field (index 5) is non-zero only for primary-key
        # columns. This pins the key to ``id`` itself rather than merely
        # finding the text 'PRIMARY KEY' somewhere in the CREATE statement.
        pk_columns = [name for name, info in columns.items() if info[5]]
        assert pk_columns == ['id'], "id should be primary key"

    def test_tasks_table_schema(self, temp_db):
        """
        Test that tasks table has correct schema.
        Validates: Requirements 9.1
        """
        init_database()

        with closing(sqlite3.connect(temp_db)) as conn:
            columns = self._columns(conn, 'tasks')
            foreign_keys = self._foreign_keys(conn, 'tasks')

        for name in ('id', 'project_id', 'name', 'data', 'status',
                     'assigned_to', 'created_at'):
            assert name in columns, f"tasks should have {name} column"

        assert foreign_keys, "tasks should have foreign key constraint"
        # Look the constraint up by column instead of assuming it is the first
        # row of foreign_key_list, which breaks once another foreign key
        # (e.g. on assigned_to) is listed ahead of it.
        assert foreign_keys.get('project_id') == 'projects', \
            "foreign key should reference projects"

    def test_annotations_table_schema(self, temp_db):
        """
        Test that annotations table has correct schema.
        Validates: Requirements 9.1
        """
        init_database()

        with closing(sqlite3.connect(temp_db)) as conn:
            columns = self._columns(conn, 'annotations')
            foreign_keys = self._foreign_keys(conn, 'annotations')

        for name in ('id', 'task_id', 'user_id', 'result',
                     'created_at', 'updated_at'):
            assert name in columns, f"annotations should have {name} column"

        assert foreign_keys, "annotations should have foreign key constraint"
        # Keyed by local column for the same reason as in the tasks test
        # (user_id may carry its own foreign key).
        assert foreign_keys.get('task_id') == 'tasks', \
            "foreign key should reference tasks"

    def test_export_jobs_table_schema(self, temp_db):
        """
        Test that export_jobs table has correct schema.
        Validates: Requirements 8.1
        """
        init_database()

        with closing(sqlite3.connect(temp_db)) as conn:
            columns = self._columns(conn, 'export_jobs')

        for name in ('id', 'project_id', 'format', 'status', 'file_path'):
            assert name in columns, f"export_jobs should have {name} column"

    def test_foreign_key_constraints_enabled(self, temp_db):
        """
        Test that connections handed out by the application enforce foreign
        key constraints.
        Validates: Requirements 9.1
        """
        init_database()

        # foreign_keys is a per-connection SQLite setting. Turning it on by
        # hand on a private sqlite3 connection proves nothing about the app,
        # so inspect a connection obtained from get_db_connection().
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("PRAGMA foreign_keys")
            result = cursor.fetchone()

        assert result[0] == 1, "foreign key constraints should be enabled"

    def test_init_database_idempotent(self, temp_db):
        """
        Test that calling init_database multiple times is safe: it neither
        fails nor discards existing data.
        Validates: Requirements 9.1
        """
        init_database()

        # Seed a row so that a re-initialization which drops/recreates tables
        # would be detected.
        with closing(sqlite3.connect(temp_db)) as conn:
            conn.execute("""
                INSERT INTO projects (id, name, description, config)
                VALUES (?, ?, ?, ?)
            """, ('idem-1', 'Test Project', 'Description', '{}'))
            conn.commit()

        # Initialize again (three calls in total).
        init_database()
        init_database()

        with closing(sqlite3.connect(temp_db)) as conn:
            tables = self._table_names(conn)
            row = conn.execute(
                "SELECT id FROM projects WHERE id=?", ('idem-1',)
            ).fetchone()

        for table in self.EXPECTED_TABLES:
            assert table in tables, f"{table} table should exist"
        assert row is not None, "re-initialization should keep existing data"
class TestConnectionManagement:
    """Test suite for database connection management."""

    def test_get_db_connection_context_manager(self, temp_db):
        """
        Test that get_db_connection works as context manager.
        Validates: Requirements 9.2
        """
        init_database()

        with get_db_connection() as conn:
            assert conn is not None, "connection should not be None"

            # Verify we can execute queries
            cursor = conn.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
            tables = cursor.fetchall()
            assert len(tables) > 0, "should be able to query tables"

    def test_get_db_connection_commits_on_success(self, temp_db):
        """
        Test that get_db_connection commits changes on success.
        Validates: Requirements 9.2
        """
        init_database()

        # Insert data using context manager
        with get_db_connection() as conn:
            cursor = conn.cursor()
            cursor.execute("""
                INSERT INTO projects (id, name, description, config)
                VALUES (?, ?, ?, ?)
            """, ('test-1', 'Test Project', 'Description', '{}'))

        # Read back through an independent connection, which sees only
        # committed data. closing() guarantees it is released even if the
        # assertion fails (sqlite3's own context manager does not close).
        with closing(sqlite3.connect(temp_db)) as check_conn:
            cursor = check_conn.cursor()
            cursor.execute("SELECT * FROM projects WHERE id='test-1'")
            result = cursor.fetchone()

        assert result is not None, "data should be committed"

    def test_get_db_connection_rolls_back_on_error(self, temp_db):
        """
        Test that get_db_connection rolls back on error and lets the error
        propagate to the caller.
        Validates: Requirements 9.2
        """
        init_database()

        # pytest.raises (unlike try/except/pass) also fails the test if the
        # context manager silently swallows the exception.
        with pytest.raises(ValueError, match="Test error"):
            with get_db_connection() as conn:
                cursor = conn.cursor()
                cursor.execute("""
                    INSERT INTO projects (id, name, description, config)
                    VALUES (?, ?, ?, ?)
                """, ('test-2', 'Test Project', 'Description', '{}'))
                # Raise an error before the context manager can commit
                raise ValueError("Test error")

        # Verify data was rolled back
        with closing(sqlite3.connect(temp_db)) as check_conn:
            cursor = check_conn.cursor()
            cursor.execute("SELECT * FROM projects WHERE id='test-2'")
            result = cursor.fetchone()

        assert result is None, "data should be rolled back on error"

    def test_get_db_function(self, temp_db):
        """
        Test that get_db function returns a valid connection.
        Validates: Requirements 9.2
        """
        init_database()

        conn = get_db()

        try:
            assert conn is not None, "connection should not be None"

            # Verify we can write and read back data
            cursor = conn.cursor()
            cursor.execute("""
                INSERT INTO projects (id, name, description, config)
                VALUES (?, ?, ?, ?)
            """, ('test-4', 'Test Project', 'Description', '{}'))
            conn.commit()

            cursor.execute("SELECT * FROM projects WHERE id=?", ('test-4',))
            row = cursor.fetchone()
            assert row is not None, "should be able to query data"

        finally:
            # Caller is responsible for closing connections from get_db()
            conn.close()
class TestDatabaseIntegrity:
    """
    Test suite for database integrity and constraints.

    Each test stores a parent row plus a dependent row, deletes the parent,
    and checks that the dependent row disappeared along with it.
    """

    @staticmethod
    def _add_project(cursor, project_id, name):
        """Insert a project row with a fixed description and empty config."""
        cursor.execute("""
            INSERT INTO projects (id, name, description, config)
            VALUES (?, ?, ?, ?)
        """, (project_id, name, 'Description', '{}'))

    @staticmethod
    def _add_task(cursor, task_id, project_id, name):
        """Insert a 'pending' task with empty data under *project_id*."""
        cursor.execute("""
            INSERT INTO tasks (id, project_id, name, data, status)
            VALUES (?, ?, ?, ?, ?)
        """, (task_id, project_id, name, '{}', 'pending'))

    @staticmethod
    def _add_annotation(cursor, annotation_id, task_id, user_id):
        """Insert an annotation with an empty result for *task_id*."""
        cursor.execute("""
            INSERT INTO annotations (id, task_id, user_id, result)
            VALUES (?, ?, ?, ?)
        """, (annotation_id, task_id, user_id, '{}'))

    def test_cascade_delete_tasks_on_project_delete(self, temp_db):
        """
        Deleting a project must also delete the project's tasks.
        Validates: Requirements 9.1
        """
        init_database()

        # Arrange: one project owning one task.
        with get_db_connection() as conn:
            cur = conn.cursor()
            self._add_project(cur, 'proj-1', 'Project 1')
            self._add_task(cur, 'task-1', 'proj-1', 'Task 1')

        # Sanity check: the task really was stored.
        with get_db_connection() as conn:
            cur = conn.cursor()
            cur.execute("SELECT * FROM tasks WHERE id=?", ('task-1',))
            assert cur.fetchone() is not None, "task should exist"

        # Act: remove the parent project.
        with get_db_connection() as conn:
            conn.cursor().execute("DELETE FROM projects WHERE id=?", ('proj-1',))

        # Assert: the child task went with it.
        with get_db_connection() as conn:
            cur = conn.cursor()
            cur.execute("SELECT * FROM tasks WHERE id=?", ('task-1',))
            assert cur.fetchone() is None, "task should be deleted with project"

    def test_cascade_delete_annotations_on_task_delete(self, temp_db):
        """
        Deleting a task must also delete the task's annotations.
        Validates: Requirements 9.1
        """
        init_database()

        # Arrange: project -> task -> annotation chain.
        with get_db_connection() as conn:
            cur = conn.cursor()
            self._add_project(cur, 'proj-2', 'Project 2')
            self._add_task(cur, 'task-2', 'proj-2', 'Task 2')
            self._add_annotation(cur, 'ann-1', 'task-2', 'user-1')

        # Sanity check: the annotation really was stored.
        with get_db_connection() as conn:
            cur = conn.cursor()
            cur.execute("SELECT * FROM annotations WHERE id=?", ('ann-1',))
            assert cur.fetchone() is not None, "annotation should exist"

        # Act: remove the parent task.
        with get_db_connection() as conn:
            conn.cursor().execute("DELETE FROM tasks WHERE id=?", ('task-2',))

        # Assert: the child annotation went with it.
        with get_db_connection() as conn:
            cur = conn.cursor()
            cur.execute("SELECT * FROM annotations WHERE id=?", ('ann-1',))
            assert cur.fetchone() is None, "annotation should be deleted with task"
|