|
|
@@ -0,0 +1,476 @@
|
|
|
+"""
|
|
|
+Unit tests for database initialization and connection management.
|
|
|
+Tests database table creation and connection handling.
|
|
|
+
|
|
|
+Requirements: 9.1, 9.2
|
|
|
+"""
|
|
|
+import os
|
|
|
+import sqlite3
|
|
|
+import tempfile
|
|
|
+import pytest
|
|
|
+from contextlib import contextmanager
|
|
|
+
|
|
|
+# Import database functions
|
|
|
+import sys
|
|
|
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
|
|
|
+from database import init_database, get_db_connection, get_db, DB_PATH
|
|
|
+
|
|
|
+
|
|
|
+@pytest.fixture
|
|
|
+def temp_db():
|
|
|
+ """
|
|
|
+ Fixture that creates a temporary database for testing.
|
|
|
+ Ensures each test runs with a clean database.
|
|
|
+ """
|
|
|
+ # Create a temporary database file
|
|
|
+ fd, temp_db_path = tempfile.mkstemp(suffix='.db')
|
|
|
+ os.close(fd)
|
|
|
+
|
|
|
+ # Set the database path for testing
|
|
|
+ original_db_path = os.environ.get('DATABASE_PATH')
|
|
|
+ os.environ['DATABASE_PATH'] = temp_db_path
|
|
|
+
|
|
|
+ # Force reload of database module to pick up new path
|
|
|
+ import database
|
|
|
+ database.DB_PATH = temp_db_path
|
|
|
+
|
|
|
+ yield temp_db_path
|
|
|
+
|
|
|
+ # Cleanup: restore original path and remove temp file
|
|
|
+ if original_db_path:
|
|
|
+ os.environ['DATABASE_PATH'] = original_db_path
|
|
|
+ else:
|
|
|
+ os.environ.pop('DATABASE_PATH', None)
|
|
|
+
|
|
|
+ database.DB_PATH = original_db_path or "annotation_platform.db"
|
|
|
+
|
|
|
+ try:
|
|
|
+ os.unlink(temp_db_path)
|
|
|
+ except OSError:
|
|
|
+ pass
|
|
|
+
|
|
|
+
|
|
|
+class TestDatabaseInitialization:
|
|
|
+ """Test suite for database initialization functionality."""
|
|
|
+
|
|
|
+ def test_init_database_creates_tables(self, temp_db):
|
|
|
+ """
|
|
|
+ Test that init_database creates all required tables.
|
|
|
+ Validates: Requirements 9.1
|
|
|
+ """
|
|
|
+ # Initialize the database
|
|
|
+ init_database()
|
|
|
+
|
|
|
+ # Connect to the database and check tables exist
|
|
|
+ conn = sqlite3.connect(temp_db)
|
|
|
+ cursor = conn.cursor()
|
|
|
+
|
|
|
+ # Query for all tables
|
|
|
+ cursor.execute("""
|
|
|
+ SELECT name FROM sqlite_master
|
|
|
+ WHERE type='table'
|
|
|
+ ORDER BY name
|
|
|
+ """)
|
|
|
+ tables = [row[0] for row in cursor.fetchall()]
|
|
|
+
|
|
|
+ # Verify all required tables exist
|
|
|
+ assert 'projects' in tables, "projects table should exist"
|
|
|
+ assert 'tasks' in tables, "tasks table should exist"
|
|
|
+ assert 'annotations' in tables, "annotations table should exist"
|
|
|
+
|
|
|
+ conn.close()
|
|
|
+
|
|
|
+ def test_projects_table_schema(self, temp_db):
|
|
|
+ """
|
|
|
+ Test that projects table has correct schema.
|
|
|
+ Validates: Requirements 9.1
|
|
|
+ """
|
|
|
+ init_database()
|
|
|
+
|
|
|
+ conn = sqlite3.connect(temp_db)
|
|
|
+ cursor = conn.cursor()
|
|
|
+
|
|
|
+ # Get table schema
|
|
|
+ cursor.execute("PRAGMA table_info(projects)")
|
|
|
+ columns = {row[1]: row[2] for row in cursor.fetchall()}
|
|
|
+
|
|
|
+ # Verify required columns exist with correct types
|
|
|
+ assert 'id' in columns, "projects should have id column"
|
|
|
+ assert 'name' in columns, "projects should have name column"
|
|
|
+ assert 'description' in columns, "projects should have description column"
|
|
|
+ assert 'config' in columns, "projects should have config column"
|
|
|
+ assert 'created_at' in columns, "projects should have created_at column"
|
|
|
+
|
|
|
+ # Verify id is primary key
|
|
|
+ cursor.execute("""
|
|
|
+ SELECT sql FROM sqlite_master
|
|
|
+ WHERE type='table' AND name='projects'
|
|
|
+ """)
|
|
|
+ schema = cursor.fetchone()[0]
|
|
|
+ assert 'PRIMARY KEY' in schema, "id should be primary key"
|
|
|
+
|
|
|
+ conn.close()
|
|
|
+
|
|
|
+ def test_tasks_table_schema(self, temp_db):
|
|
|
+ """
|
|
|
+ Test that tasks table has correct schema.
|
|
|
+ Validates: Requirements 9.1
|
|
|
+ """
|
|
|
+ init_database()
|
|
|
+
|
|
|
+ conn = sqlite3.connect(temp_db)
|
|
|
+ cursor = conn.cursor()
|
|
|
+
|
|
|
+ # Get table schema
|
|
|
+ cursor.execute("PRAGMA table_info(tasks)")
|
|
|
+ columns = {row[1]: row[2] for row in cursor.fetchall()}
|
|
|
+
|
|
|
+ # Verify required columns exist
|
|
|
+ assert 'id' in columns, "tasks should have id column"
|
|
|
+ assert 'project_id' in columns, "tasks should have project_id column"
|
|
|
+ assert 'name' in columns, "tasks should have name column"
|
|
|
+ assert 'data' in columns, "tasks should have data column"
|
|
|
+ assert 'status' in columns, "tasks should have status column"
|
|
|
+ assert 'assigned_to' in columns, "tasks should have assigned_to column"
|
|
|
+ assert 'created_at' in columns, "tasks should have created_at column"
|
|
|
+
|
|
|
+ # Verify foreign key constraint
|
|
|
+ cursor.execute("PRAGMA foreign_key_list(tasks)")
|
|
|
+ foreign_keys = cursor.fetchall()
|
|
|
+ assert len(foreign_keys) > 0, "tasks should have foreign key constraint"
|
|
|
+ assert foreign_keys[0][2] == 'projects', "foreign key should reference projects"
|
|
|
+
|
|
|
+ conn.close()
|
|
|
+
|
|
|
+ def test_annotations_table_schema(self, temp_db):
|
|
|
+ """
|
|
|
+ Test that annotations table has correct schema.
|
|
|
+ Validates: Requirements 9.1
|
|
|
+ """
|
|
|
+ init_database()
|
|
|
+
|
|
|
+ conn = sqlite3.connect(temp_db)
|
|
|
+ cursor = conn.cursor()
|
|
|
+
|
|
|
+ # Get table schema
|
|
|
+ cursor.execute("PRAGMA table_info(annotations)")
|
|
|
+ columns = {row[1]: row[2] for row in cursor.fetchall()}
|
|
|
+
|
|
|
+ # Verify required columns exist
|
|
|
+ assert 'id' in columns, "annotations should have id column"
|
|
|
+ assert 'task_id' in columns, "annotations should have task_id column"
|
|
|
+ assert 'user_id' in columns, "annotations should have user_id column"
|
|
|
+ assert 'result' in columns, "annotations should have result column"
|
|
|
+ assert 'created_at' in columns, "annotations should have created_at column"
|
|
|
+ assert 'updated_at' in columns, "annotations should have updated_at column"
|
|
|
+
|
|
|
+ # Verify foreign key constraint
|
|
|
+ cursor.execute("PRAGMA foreign_key_list(annotations)")
|
|
|
+ foreign_keys = cursor.fetchall()
|
|
|
+ assert len(foreign_keys) > 0, "annotations should have foreign key constraint"
|
|
|
+ assert foreign_keys[0][2] == 'tasks', "foreign key should reference tasks"
|
|
|
+
|
|
|
+ conn.close()
|
|
|
+
|
|
|
+ def test_foreign_key_constraints_enabled(self, temp_db):
|
|
|
+ """
|
|
|
+ Test that foreign key constraints are enabled.
|
|
|
+ Validates: Requirements 9.1
|
|
|
+ """
|
|
|
+ init_database()
|
|
|
+
|
|
|
+ conn = sqlite3.connect(temp_db)
|
|
|
+ cursor = conn.cursor()
|
|
|
+
|
|
|
+ # Check if foreign keys are enabled
|
|
|
+ cursor.execute("PRAGMA foreign_keys")
|
|
|
+ result = cursor.fetchone()
|
|
|
+
|
|
|
+ # Note: Foreign keys need to be enabled per connection
|
|
|
+ # The init_database function enables them, but we need to enable for this connection too
|
|
|
+ cursor.execute("PRAGMA foreign_keys = ON")
|
|
|
+ cursor.execute("PRAGMA foreign_keys")
|
|
|
+ result = cursor.fetchone()
|
|
|
+ assert result[0] == 1, "foreign key constraints should be enabled"
|
|
|
+
|
|
|
+ conn.close()
|
|
|
+
|
|
|
+ def test_init_database_idempotent(self, temp_db):
|
|
|
+ """
|
|
|
+ Test that calling init_database multiple times is safe.
|
|
|
+ Validates: Requirements 9.1
|
|
|
+ """
|
|
|
+ # Initialize database multiple times
|
|
|
+ init_database()
|
|
|
+ init_database()
|
|
|
+ init_database()
|
|
|
+
|
|
|
+ # Verify tables still exist and are not duplicated
|
|
|
+ conn = sqlite3.connect(temp_db)
|
|
|
+ cursor = conn.cursor()
|
|
|
+
|
|
|
+ cursor.execute("""
|
|
|
+ SELECT name FROM sqlite_master
|
|
|
+ WHERE type='table'
|
|
|
+ ORDER BY name
|
|
|
+ """)
|
|
|
+ tables = [row[0] for row in cursor.fetchall()]
|
|
|
+
|
|
|
+ # Should have exactly 3 tables
|
|
|
+ assert len(tables) == 3, "should have exactly 3 tables"
|
|
|
+ assert 'projects' in tables
|
|
|
+ assert 'tasks' in tables
|
|
|
+ assert 'annotations' in tables
|
|
|
+
|
|
|
+ conn.close()
|
|
|
+
|
|
|
+
|
|
|
+class TestConnectionManagement:
|
|
|
+ """Test suite for database connection management."""
|
|
|
+
|
|
|
+ def test_get_db_connection_context_manager(self, temp_db):
|
|
|
+ """
|
|
|
+ Test that get_db_connection works as context manager.
|
|
|
+ Validates: Requirements 9.2
|
|
|
+ """
|
|
|
+ init_database()
|
|
|
+
|
|
|
+ # Use context manager
|
|
|
+ with get_db_connection() as conn:
|
|
|
+ assert conn is not None, "connection should not be None"
|
|
|
+ assert isinstance(conn, sqlite3.Connection), "should return Connection object"
|
|
|
+
|
|
|
+ # Verify we can execute queries
|
|
|
+ cursor = conn.cursor()
|
|
|
+ cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
|
|
|
+ tables = cursor.fetchall()
|
|
|
+ assert len(tables) > 0, "should be able to query tables"
|
|
|
+
|
|
|
+ def test_get_db_connection_commits_on_success(self, temp_db):
|
|
|
+ """
|
|
|
+ Test that get_db_connection commits changes on success.
|
|
|
+ Validates: Requirements 9.2
|
|
|
+ """
|
|
|
+ init_database()
|
|
|
+
|
|
|
+ # Insert data using context manager
|
|
|
+ with get_db_connection() as conn:
|
|
|
+ cursor = conn.cursor()
|
|
|
+ cursor.execute("""
|
|
|
+ INSERT INTO projects (id, name, description, config)
|
|
|
+ VALUES ('test-1', 'Test Project', 'Description', '{}')
|
|
|
+ """)
|
|
|
+
|
|
|
+ # Verify data was committed
|
|
|
+ conn = sqlite3.connect(temp_db)
|
|
|
+ cursor = conn.cursor()
|
|
|
+ cursor.execute("SELECT * FROM projects WHERE id='test-1'")
|
|
|
+ result = cursor.fetchone()
|
|
|
+ assert result is not None, "data should be committed"
|
|
|
+ conn.close()
|
|
|
+
|
|
|
+ def test_get_db_connection_rolls_back_on_error(self, temp_db):
|
|
|
+ """
|
|
|
+ Test that get_db_connection rolls back on error.
|
|
|
+ Validates: Requirements 9.2
|
|
|
+ """
|
|
|
+ init_database()
|
|
|
+
|
|
|
+ # Try to insert data and raise an error
|
|
|
+ try:
|
|
|
+ with get_db_connection() as conn:
|
|
|
+ cursor = conn.cursor()
|
|
|
+ cursor.execute("""
|
|
|
+ INSERT INTO projects (id, name, description, config)
|
|
|
+ VALUES ('test-2', 'Test Project', 'Description', '{}')
|
|
|
+ """)
|
|
|
+ # Raise an error before commit
|
|
|
+ raise ValueError("Test error")
|
|
|
+ except ValueError:
|
|
|
+ pass
|
|
|
+
|
|
|
+ # Verify data was rolled back
|
|
|
+ conn = sqlite3.connect(temp_db)
|
|
|
+ cursor = conn.cursor()
|
|
|
+ cursor.execute("SELECT * FROM projects WHERE id='test-2'")
|
|
|
+ result = cursor.fetchone()
|
|
|
+ assert result is None, "data should be rolled back on error"
|
|
|
+ conn.close()
|
|
|
+
|
|
|
+ def test_get_db_connection_closes_connection(self, temp_db):
|
|
|
+ """
|
|
|
+ Test that get_db_connection closes connection after use.
|
|
|
+ Validates: Requirements 9.2
|
|
|
+ """
|
|
|
+ init_database()
|
|
|
+
|
|
|
+ conn_ref = None
|
|
|
+ with get_db_connection() as conn:
|
|
|
+ conn_ref = conn
|
|
|
+ assert conn is not None
|
|
|
+
|
|
|
+ # Try to use connection after context manager exits
|
|
|
+ # This should fail because connection is closed
|
|
|
+ with pytest.raises(sqlite3.ProgrammingError):
|
|
|
+ cursor = conn_ref.cursor()
|
|
|
+ cursor.execute("SELECT 1")
|
|
|
+
|
|
|
+ def test_get_db_connection_enables_row_factory(self, temp_db):
|
|
|
+ """
|
|
|
+ Test that get_db_connection enables row factory for column access.
|
|
|
+ Validates: Requirements 9.2
|
|
|
+ """
|
|
|
+ init_database()
|
|
|
+
|
|
|
+ # Insert test data
|
|
|
+ with get_db_connection() as conn:
|
|
|
+ cursor = conn.cursor()
|
|
|
+ cursor.execute("""
|
|
|
+ INSERT INTO projects (id, name, description, config)
|
|
|
+ VALUES ('test-3', 'Test Project', 'Description', '{}')
|
|
|
+ """)
|
|
|
+
|
|
|
+ # Query and verify row factory works
|
|
|
+ with get_db_connection() as conn:
|
|
|
+ cursor = conn.cursor()
|
|
|
+ cursor.execute("SELECT * FROM projects WHERE id='test-3'")
|
|
|
+ row = cursor.fetchone()
|
|
|
+
|
|
|
+ # Should be able to access columns by name
|
|
|
+ assert row['id'] == 'test-3', "should access column by name"
|
|
|
+ assert row['name'] == 'Test Project', "should access column by name"
|
|
|
+
|
|
|
+ def test_get_db_function(self, temp_db):
|
|
|
+ """
|
|
|
+ Test that get_db function returns a valid connection.
|
|
|
+ Validates: Requirements 9.2
|
|
|
+ """
|
|
|
+ init_database()
|
|
|
+
|
|
|
+ # Get connection
|
|
|
+ conn = get_db()
|
|
|
+
|
|
|
+ try:
|
|
|
+ assert conn is not None, "connection should not be None"
|
|
|
+ assert isinstance(conn, sqlite3.Connection), "should return Connection object"
|
|
|
+
|
|
|
+ # Verify row factory is enabled
|
|
|
+ cursor = conn.cursor()
|
|
|
+ cursor.execute("""
|
|
|
+ INSERT INTO projects (id, name, description, config)
|
|
|
+ VALUES ('test-4', 'Test Project', 'Description', '{}')
|
|
|
+ """)
|
|
|
+ conn.commit()
|
|
|
+
|
|
|
+ cursor.execute("SELECT * FROM projects WHERE id='test-4'")
|
|
|
+ row = cursor.fetchone()
|
|
|
+ assert row['id'] == 'test-4', "row factory should be enabled"
|
|
|
+
|
|
|
+ finally:
|
|
|
+ # Caller is responsible for closing
|
|
|
+ conn.close()
|
|
|
+
|
|
|
+ def test_get_db_enables_foreign_keys(self, temp_db):
|
|
|
+ """
|
|
|
+ Test that get_db enables foreign key constraints.
|
|
|
+ Validates: Requirements 9.2
|
|
|
+ """
|
|
|
+ init_database()
|
|
|
+
|
|
|
+ conn = get_db()
|
|
|
+
|
|
|
+ try:
|
|
|
+ cursor = conn.cursor()
|
|
|
+ cursor.execute("PRAGMA foreign_keys")
|
|
|
+ result = cursor.fetchone()
|
|
|
+ assert result[0] == 1, "foreign keys should be enabled"
|
|
|
+ finally:
|
|
|
+ conn.close()
|
|
|
+
|
|
|
+
|
|
|
+class TestDatabaseIntegrity:
|
|
|
+ """Test suite for database integrity and constraints."""
|
|
|
+
|
|
|
+ def test_cascade_delete_tasks_on_project_delete(self, temp_db):
|
|
|
+ """
|
|
|
+ Test that deleting a project cascades to delete tasks.
|
|
|
+ Validates: Requirements 9.1
|
|
|
+ """
|
|
|
+ init_database()
|
|
|
+
|
|
|
+ with get_db_connection() as conn:
|
|
|
+ cursor = conn.cursor()
|
|
|
+
|
|
|
+ # Insert project
|
|
|
+ cursor.execute("""
|
|
|
+ INSERT INTO projects (id, name, description, config)
|
|
|
+ VALUES ('proj-1', 'Project 1', 'Description', '{}')
|
|
|
+ """)
|
|
|
+
|
|
|
+ # Insert task
|
|
|
+ cursor.execute("""
|
|
|
+ INSERT INTO tasks (id, project_id, name, data, status)
|
|
|
+ VALUES ('task-1', 'proj-1', 'Task 1', '{}', 'pending')
|
|
|
+ """)
|
|
|
+
|
|
|
+ # Verify task exists
|
|
|
+ with get_db_connection() as conn:
|
|
|
+ cursor = conn.cursor()
|
|
|
+ cursor.execute("SELECT * FROM tasks WHERE id='task-1'")
|
|
|
+ assert cursor.fetchone() is not None, "task should exist"
|
|
|
+
|
|
|
+ # Delete project
|
|
|
+ with get_db_connection() as conn:
|
|
|
+ cursor = conn.cursor()
|
|
|
+ cursor.execute("DELETE FROM projects WHERE id='proj-1'")
|
|
|
+
|
|
|
+ # Verify task was deleted
|
|
|
+ with get_db_connection() as conn:
|
|
|
+ cursor = conn.cursor()
|
|
|
+ cursor.execute("SELECT * FROM tasks WHERE id='task-1'")
|
|
|
+ assert cursor.fetchone() is None, "task should be deleted with project"
|
|
|
+
|
|
|
+ def test_cascade_delete_annotations_on_task_delete(self, temp_db):
|
|
|
+ """
|
|
|
+ Test that deleting a task cascades to delete annotations.
|
|
|
+ Validates: Requirements 9.1
|
|
|
+ """
|
|
|
+ init_database()
|
|
|
+
|
|
|
+ with get_db_connection() as conn:
|
|
|
+ cursor = conn.cursor()
|
|
|
+
|
|
|
+ # Insert project
|
|
|
+ cursor.execute("""
|
|
|
+ INSERT INTO projects (id, name, description, config)
|
|
|
+ VALUES ('proj-2', 'Project 2', 'Description', '{}')
|
|
|
+ """)
|
|
|
+
|
|
|
+ # Insert task
|
|
|
+ cursor.execute("""
|
|
|
+ INSERT INTO tasks (id, project_id, name, data, status)
|
|
|
+ VALUES ('task-2', 'proj-2', 'Task 2', '{}', 'pending')
|
|
|
+ """)
|
|
|
+
|
|
|
+ # Insert annotation
|
|
|
+ cursor.execute("""
|
|
|
+ INSERT INTO annotations (id, task_id, user_id, result)
|
|
|
+ VALUES ('ann-1', 'task-2', 'user-1', '{}')
|
|
|
+ """)
|
|
|
+
|
|
|
+ # Verify annotation exists
|
|
|
+ with get_db_connection() as conn:
|
|
|
+ cursor = conn.cursor()
|
|
|
+ cursor.execute("SELECT * FROM annotations WHERE id='ann-1'")
|
|
|
+ assert cursor.fetchone() is not None, "annotation should exist"
|
|
|
+
|
|
|
+ # Delete task
|
|
|
+ with get_db_connection() as conn:
|
|
|
+ cursor = conn.cursor()
|
|
|
+ cursor.execute("DELETE FROM tasks WHERE id='task-2'")
|
|
|
+
|
|
|
+ # Verify annotation was deleted
|
|
|
+ with get_db_connection() as conn:
|
|
|
+ cursor = conn.cursor()
|
|
|
+ cursor.execute("SELECT * FROM annotations WHERE id='ann-1'")
|
|
|
+ assert cursor.fetchone() is None, "annotation should be deleted with task"
|