"""
Database connection and initialization module.

MySQL only - SQLite support removed. The wrappers below keep a small
sqlite3-style surface (``?`` placeholders, ``row['col']`` / ``row[0]``
access) so older call sites keep working on top of pymysql.
"""
import logging
from contextlib import contextmanager
from typing import Any, Generator, Iterable, Iterator, Optional, Sequence

from config import settings
import pymysql

logger = logging.getLogger(__name__)


class DatabaseConfig:
    """Database configuration holder (MySQL only).

    Each attribute is taken from ``config.settings`` when that attribute
    exists there, otherwise from the fallback literal used below.
    """

    def __init__(self):
        self.host = getattr(settings, 'MYSQL_HOST', 'localhost')
        self.port = getattr(settings, 'MYSQL_PORT', 3306)
        self.user = getattr(settings, 'MYSQL_USER', 'root')
        self.password = getattr(settings, 'MYSQL_PASSWORD', '')
        self.database = getattr(settings, 'MYSQL_DATABASE', 'annotation_platform')


# Module-level instance: settings are read once, at import time.
db_config = DatabaseConfig()


def _get_mysql_connection():
    """Open a new pymysql connection to the configured database.

    Rows come back as dicts (``DictCursor``) and autocommit is disabled, so
    the caller must commit or roll back explicitly.
    """
    return pymysql.connect(
        host=db_config.host,
        port=db_config.port,
        user=db_config.user,
        password=db_config.password,
        database=db_config.database,
        charset='utf8mb4',
        cursorclass=pymysql.cursors.DictCursor,
        autocommit=False,
    )


# Characters that open a quoted span in MySQL: string literals ('...', "...")
# and quoted identifiers (`...`).
_QUOTE_CHARS = frozenset("'\"`")


def _qmark_to_format(sql: str) -> str:
    """Rewrite sqlite-style ``?`` placeholders to pymysql's ``%s``.

    Only ``?`` characters outside quoted spans are rewritten. Text inside
    '...', "..." and `...` is copied verbatim, so a literal question mark in a
    string constant or identifier is not turned into a placeholder. Inside
    '...' and "..." a backslash escapes the next character (MySQL's default;
    this assumes the NO_BACKSLASH_ESCAPES SQL mode is off). A doubled quote
    ('') needs no special case: it closes and immediately reopens the literal.
    Every character other than a converted ``?`` is left unchanged, so SQL
    already written with ``%s`` passes through as-is.
    """
    if '?' not in sql:
        return sql
    out = []
    quote = None      # quote char of the span we are inside, or None
    escaped = False   # previous char was a backslash inside a string literal
    for ch in sql:
        if quote is not None:
            if escaped:
                escaped = False
            elif ch == '\\' and quote != '`':
                escaped = True
            elif ch == quote:
                quote = None
        elif ch in _QUOTE_CHARS:
            quote = ch
        elif ch == '?':
            out.append('%s')
            continue
        out.append(ch)
    return ''.join(out)


class RowWrapper:
    """Wrapper to provide consistent, ``sqlite3.Row``-like row access.

    Wraps the dict produced by pymysql's ``DictCursor``:

    * ``row['col']`` reads a column by name (``KeyError`` if absent);
    * ``row[0]`` reads a column by zero-based position, as ``sqlite3.Row``
      allows (``IndexError`` if out of range), e.g. ``fetchone()[0]`` after
      ``SELECT COUNT(*)``;
    * iterating yields the column values in column order and ``len(row)``
      is the column count, again like ``sqlite3.Row`` (so ``x in row``
      tests values, not column names);
    * ``keys()`` returns the column names, which also makes ``dict(row)``
      work.
    """

    def __init__(self, row):
        self._row = row

    def __getitem__(self, key):
        if isinstance(key, int):
            # pymysql's DictCursor builds each row dict in column order, and
            # column names are strings, so an int key is always positional.
            return list(self._row.values())[key]
        return self._row[key]

    def __iter__(self) -> Iterator[Any]:
        return iter(self._row.values())

    def __len__(self) -> int:
        return len(self._row)

    def keys(self):
        return self._row.keys()

    def __repr__(self) -> str:
        return f"{type(self).__name__}({self._row!r})"


class CursorWrapper:
    """Wrapper giving a pymysql cursor a sqlite3-style interface.

    SQL may use ``?`` placeholders (converted to ``%s``). Fetched rows are
    returned as :class:`RowWrapper`.
    """

    def __init__(self, cursor):
        self._cursor = cursor

    def execute(self, sql: str, params: Optional[Sequence[Any]] = None):
        """Execute *sql*, binding *params* when given; returns ``self``.

        As with pymysql, when *params* is non-empty the SQL string is
        %-formatted, so a literal ``%`` in it must be written ``%%``.
        """
        sql = _qmark_to_format(sql)
        if params:
            self._cursor.execute(sql, params)
        else:
            # No args -> pymysql skips %-formatting, so literal '%' is safe.
            self._cursor.execute(sql)
        return self

    def executemany(self, sql: str, seq_of_params: Iterable[Sequence[Any]]):
        """Execute *sql* once per parameter sequence; returns ``self``."""
        self._cursor.executemany(_qmark_to_format(sql), seq_of_params)
        return self

    def fetchone(self) -> Optional[RowWrapper]:
        """Return the next row, or ``None`` when the result set is exhausted."""
        row = self._cursor.fetchone()
        return None if row is None else RowWrapper(row)

    def fetchall(self) -> list:
        """Return all remaining rows as a list of :class:`RowWrapper`."""
        return [RowWrapper(row) for row in self._cursor.fetchall()]

    def close(self):
        self._cursor.close()

    @property
    def description(self):
        return self._cursor.description

    @property
    def lastrowid(self):
        return self._cursor.lastrowid

    @property
    def rowcount(self):
        return self._cursor.rowcount


class ConnectionWrapper:
    """Wrapper providing a sqlite3-style interface over a pymysql connection."""

    def __init__(self, conn):
        self._conn = conn

    def cursor(self) -> CursorWrapper:
        return CursorWrapper(self._conn.cursor())

    def commit(self):
        self._conn.commit()

    def rollback(self):
        self._conn.rollback()

    def close(self):
        # pymysql's close() raises "Already closed" when called twice; make it
        # idempotent (like sqlite3's) so cleanup code such as the finally in
        # get_db_connection() cannot fail or mask an earlier error.
        if self._conn.open:
            self._conn.close()

    def execute(self, sql: str, params: Optional[Sequence[Any]] = None):
        """Shortcut: create a cursor, execute *sql* on it and return it."""
        cursor = self.cursor()
        cursor.execute(sql, params)
        return cursor

    def executemany(self, sql: str, seq_of_params: Iterable[Sequence[Any]]):
        """Shortcut: create a cursor, run ``executemany`` on it and return it."""
        cursor = self.cursor()
        cursor.executemany(sql, seq_of_params)
        return cursor


@contextmanager
def get_db_connection() -> Generator[ConnectionWrapper, None, None]:
    """
    Context manager for database connections.

    Commits when the ``with`` body finishes normally, rolls back when it
    raises an ``Exception`` (then re-raises it), and always closes the
    connection.
    """
    wrapped = ConnectionWrapper(_get_mysql_connection())
    try:
        yield wrapped
        wrapped.commit()
    except Exception:
        try:
            wrapped.rollback()
        except Exception:
            # E.g. the connection dropped mid-transaction: log the rollback
            # failure but surface the error that actually caused it.
            logger.warning("Rollback failed; re-raising the original error", exc_info=True)
        raise
    finally:
        wrapped.close()


# CREATE TABLE statements run by init_database(), in dependency order
# (tasks and export_jobs reference projects; annotations references tasks).
# export_jobs.completed_at is declared NULL DEFAULT NULL explicitly: with
# explicit_defaults_for_timestamp=OFF (the MySQL 5.7 default) a bare, non-first
# TIMESTAMP column becomes NOT NULL with a zero-date default, which strict
# SQL mode with NO_ZERO_DATE rejects.
_TABLE_DDL = (
    """
    CREATE TABLE IF NOT EXISTS users (
        id VARCHAR(36) PRIMARY KEY,
        username VARCHAR(255) NOT NULL UNIQUE,
        email VARCHAR(255) NOT NULL UNIQUE,
        password_hash VARCHAR(255) NOT NULL,
        role VARCHAR(50) NOT NULL DEFAULT 'annotator',
        oauth_provider VARCHAR(50),
        oauth_id VARCHAR(255),
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
        INDEX idx_users_username (username),
        INDEX idx_users_email (email),
        INDEX idx_users_oauth (oauth_provider, oauth_id)
    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
    """,
    """
    CREATE TABLE IF NOT EXISTS projects (
        id VARCHAR(36) PRIMARY KEY,
        name VARCHAR(255) NOT NULL,
        description TEXT,
        config TEXT NOT NULL,
        status VARCHAR(20) DEFAULT 'draft',
        source VARCHAR(20) DEFAULT 'internal',
        task_type VARCHAR(50),
        external_id VARCHAR(100),
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
        INDEX idx_projects_status (status),
        INDEX idx_projects_source (source),
        INDEX idx_projects_external_id (external_id)
    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
    """,
    """
    CREATE TABLE IF NOT EXISTS tasks (
        id VARCHAR(36) PRIMARY KEY,
        project_id VARCHAR(36) NOT NULL,
        name VARCHAR(255) NOT NULL,
        data LONGTEXT NOT NULL,
        status VARCHAR(50) DEFAULT 'pending',
        assigned_to VARCHAR(36),
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
        INDEX idx_tasks_project (project_id),
        INDEX idx_tasks_status (status)
    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
    """,
    """
    CREATE TABLE IF NOT EXISTS annotations (
        id VARCHAR(36) PRIMARY KEY,
        task_id VARCHAR(36) NOT NULL,
        user_id VARCHAR(36) NOT NULL,
        result LONGTEXT NOT NULL,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
        FOREIGN KEY (task_id) REFERENCES tasks(id) ON DELETE CASCADE,
        INDEX idx_annotations_task (task_id),
        INDEX idx_annotations_user (user_id)
    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
    """,
    """
    CREATE TABLE IF NOT EXISTS export_jobs (
        id VARCHAR(36) PRIMARY KEY,
        project_id VARCHAR(36) NOT NULL,
        format VARCHAR(50) NOT NULL,
        status VARCHAR(50) DEFAULT 'pending',
        status_filter VARCHAR(50) DEFAULT 'all',
        include_metadata BOOLEAN DEFAULT TRUE,
        file_path TEXT,
        error_message TEXT,
        created_by VARCHAR(36),
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
        completed_at TIMESTAMP NULL DEFAULT NULL,
        total_tasks INT DEFAULT 0,
        exported_tasks INT DEFAULT 0,
        FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
        INDEX idx_export_jobs_project (project_id),
        INDEX idx_export_jobs_status (status)
    ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
    """,
)


def _create_database_if_missing() -> None:
    """Create the configured database (utf8mb4) if it does not exist yet."""
    # Identifiers cannot be bound as query parameters: quote the name with
    # backticks and double any embedded backtick (MySQL identifier quoting).
    quoted_name = str(db_config.database).replace('`', '``')
    # Connect without selecting a database, since it may not exist yet.
    conn = pymysql.connect(
        host=db_config.host,
        port=db_config.port,
        user=db_config.user,
        password=db_config.password,
        charset='utf8mb4',
    )
    try:
        with conn.cursor() as cursor:
            cursor.execute(
                f"CREATE DATABASE IF NOT EXISTS `{quoted_name}` "
                "CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
            )
        conn.commit()
    finally:
        conn.close()


def init_database() -> None:
    """Initialize MySQL database and create tables if they don't exist."""
    _create_database_if_missing()
    with get_db_connection() as conn:
        cursor = conn.cursor()
        for ddl in _TABLE_DDL:
            cursor.execute(ddl)
    logger.info("MySQL 数据库初始化完成")


def get_db():
    """
    Get a database connection (legacy support).

    Note: Caller is responsible for closing the connection (and for
    committing; autocommit is off).
    """
    return ConnectionWrapper(_get_mysql_connection())