"""
Database connection and initialization module.

Supports both SQLite and MySQL databases.
"""

import os
import logging
from contextlib import contextmanager
from typing import Generator, Any, Optional

from config import settings

logger = logging.getLogger(__name__)


class DatabaseConfig:
    """Database configuration holder.

    Reads connection settings from ``config.settings`` once, at construction.
    ``DATABASE_TYPE == 'mysql'`` selects MySQL and populates the ``host``,
    ``port``, ``user``, ``password`` and ``database`` attributes; any other
    value selects SQLite and populates ``db_path`` only.
    """

    def __init__(self):
        self.db_type = getattr(settings, 'DATABASE_TYPE', 'sqlite')

        if self.db_type == 'mysql':
            self.host = getattr(settings, 'MYSQL_HOST', 'localhost')
            self.port = getattr(settings, 'MYSQL_PORT', 3306)
            self.user = getattr(settings, 'MYSQL_USER', 'root')
            self.password = getattr(settings, 'MYSQL_PASSWORD', '')
            self.database = getattr(settings, 'MYSQL_DATABASE', 'annotation_platform')
        else:
            self.db_path = getattr(settings, 'DATABASE_PATH', 'annotation_platform.db')


# Module-level configuration, resolved once at import time.
db_config = DatabaseConfig()


def _get_mysql_connection():
    """Open a new MySQL connection using ``db_config``.

    Rows are returned as dicts (``DictCursor``) and autocommit is disabled, so
    callers must commit explicitly.
    """
    import pymysql  # imported lazily so SQLite-only deployments don't need it

    conn = pymysql.connect(
        host=db_config.host,
        port=db_config.port,
        user=db_config.user,
        password=db_config.password,
        database=db_config.database,
        charset='utf8mb4',
        cursorclass=pymysql.cursors.DictCursor,
        autocommit=False,
    )
    return conn


def _get_sqlite_connection():
    """Open a new SQLite connection to ``db_config.db_path``.

    Rows are returned as ``sqlite3.Row`` and foreign-key enforcement is turned
    on for this connection (SQLite leaves it off by default).
    """
    import sqlite3

    conn = sqlite3.connect(db_config.db_path)
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA foreign_keys = ON")
    return conn


class RowWrapper:
    """Wrapper to provide consistent row access for both SQLite and MySQL.

    Both backends yield mapping-like rows (``sqlite3.Row`` on SQLite, ``dict``
    from pymysql's ``DictCursor`` on MySQL), so access by column name behaves
    the same. Note that integer indexing (``row[0]``) only works on SQLite,
    because a MySQL row is a plain dict.
    """

    def __init__(self, row, db_type: str):
        self._row = row
        self._db_type = db_type

    def __getitem__(self, key):
        """Return the column value for *key* (a column name on both backends)."""
        return self._row[key]

    def keys(self) -> list:
        """Return the row's column names as a list, on both backends."""
        # sqlite3.Row.keys() already returns a list while dict.keys() returns
        # a view; normalise so callers see the same type regardless of backend.
        return list(self._row.keys())


class CursorWrapper:
    """Wrapper to provide consistent cursor interface for both databases."""

    def __init__(self, cursor, db_type: str):
        self._cursor = cursor
        self._db_type = db_type

    def execute(self, sql: str, params: Optional[tuple] = None):
        """Execute SQL written with qmark (``?``) placeholders on either backend.

        For MySQL the ``?`` placeholders are rewritten to pymysql's ``%s``
        style. When parameters are supplied, pymysql interpolates the query with
        Python ``%`` formatting, so literal percent signs (e.g. in
        ``LIKE 'a%'``) are first escaped to ``%%``; without parameters pymysql
        performs no interpolation and the SQL is left untouched.

        The rewrite is purely textual: a literal ``?`` inside a quoted SQL
        string would also be converted on MySQL, so pass such values as
        parameters instead.

        Args:
            sql: Statement using ``?`` placeholders.
            params: Positional parameters; a falsy value executes without them.

        Returns:
            ``self``, so calls can be chained (``cur.execute(...).fetchall()``).
        """
        if self._db_type == 'mysql':
            if params:
                # Must happen before '?' -> '%s', otherwise the new '%s'
                # markers would be escaped as well.
                sql = sql.replace('%', '%%')
            sql = sql.replace('?', '%s')

        if params:
            self._cursor.execute(sql, params)
        else:
            self._cursor.execute(sql)
        return self

    def fetchone(self) -> Optional[RowWrapper]:
        """Return the next row wrapped in ``RowWrapper``, or ``None`` when exhausted."""
        row = self._cursor.fetchone()
        if row is None:
            return None
        return RowWrapper(row, self._db_type)

    def fetchall(self) -> list:
        """Return all remaining rows, each wrapped in ``RowWrapper``."""
        rows = self._cursor.fetchall()
        return [RowWrapper(row, self._db_type) for row in rows]

    @property
    def lastrowid(self):
        """Pass-through of the underlying cursor's ``lastrowid``."""
        return self._cursor.lastrowid

    @property
    def rowcount(self):
        """Pass-through of the underlying cursor's ``rowcount``."""
        return self._cursor.rowcount


class ConnectionWrapper:
    """Wrapper to provide consistent connection interface."""

    def __init__(self, conn, db_type: str):
        self._conn = conn
        self._db_type = db_type

    def cursor(self) -> CursorWrapper:
        """Return a new ``CursorWrapper`` over a fresh underlying cursor."""
        return CursorWrapper(self._conn.cursor(), self._db_type)

    def commit(self):
        """Commit the current transaction."""
        self._conn.commit()

    def rollback(self):
        """Roll back the current transaction."""
        self._conn.rollback()

    def close(self):
        """Close the underlying connection."""
        self._conn.close()

    def execute(self, sql: str, params: Optional[tuple] = None):
        """Shortcut: create a cursor, execute *sql* with *params*, return the cursor."""
        cursor = self.cursor()
        cursor.execute(sql, params)
        return cursor


@contextmanager
def get_db_connection() -> Generator[ConnectionWrapper, None, None]:
    """
    Context manager for database connections.

    Commits when the ``with`` body completes normally, rolls back when it
    raises an ``Exception`` (the exception is then re-raised), and always
    closes the connection.
    """
    if db_config.db_type == 'mysql':
        conn = _get_mysql_connection()
    else:
        conn = _get_sqlite_connection()

    wrapped = ConnectionWrapper(conn, db_config.db_type)
    try:
        yield wrapped
        wrapped.commit()
    except Exception:
        try:
            wrapped.rollback()
        except Exception:
            # A failed rollback (e.g. the connection already dropped) must not
            # replace the caller's original exception; log it and move on.
            logger.exception("Rollback failed; re-raising the original error")
        raise
    finally:
        wrapped.close()


def init_database() -> None:
    """
    Initialize database and create tables if they don't exist.

    Dispatches to the MySQL or SQLite initializer based on ``db_config.db_type``.
    """
    if db_config.db_type == 'mysql':
        _init_mysql_database()
    else:
        _init_sqlite_database()


def _init_mysql_database() -> None:
    """Initialize MySQL database tables.

    First connects without selecting a database to create it if missing, then
    creates the ``users``, ``projects``, ``tasks`` and ``annotations`` tables
    (with their indexes) if they do not already exist.
    """
    import pymysql

    # First, create database if not exists
    conn = pymysql.connect(
        host=db_config.host,
        port=db_config.port,
        user=db_config.user,
        password=db_config.password,
        charset='utf8mb4',
    )
    try:
        # Identifiers can't be bound as parameters; escape embedded backticks
        # so the configured name can't break out of the quoted identifier.
        db_name = str(db_config.database).replace('`', '``')
        with conn.cursor() as cursor:
            cursor.execute(
                f"CREATE DATABASE IF NOT EXISTS `{db_name}` "
                "CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
            )
        conn.commit()
    finally:
        conn.close()

    # Now create tables
    with get_db_connection() as conn:
        cursor = conn.cursor()

        # Create users table
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS users (
                id VARCHAR(36) PRIMARY KEY,
                username VARCHAR(255) NOT NULL UNIQUE,
                email VARCHAR(255) NOT NULL UNIQUE,
                password_hash VARCHAR(255) NOT NULL,
                role VARCHAR(50) NOT NULL DEFAULT 'annotator',
                oauth_provider VARCHAR(50),
                oauth_id VARCHAR(255),
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                INDEX idx_users_username (username),
                INDEX idx_users_email (email),
                INDEX idx_users_oauth (oauth_provider, oauth_id)
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
        """)

        # Create projects table
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS projects (
                id VARCHAR(36) PRIMARY KEY,
                name VARCHAR(255) NOT NULL,
                description TEXT,
                config TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
        """)

        # Create tasks table
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS tasks (
                id VARCHAR(36) PRIMARY KEY,
                project_id VARCHAR(36) NOT NULL,
                name VARCHAR(255) NOT NULL,
                data LONGTEXT NOT NULL,
                status VARCHAR(50) DEFAULT 'pending',
                assigned_to VARCHAR(36),
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
                INDEX idx_tasks_project (project_id),
                INDEX idx_tasks_status (status)
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
        """)

        # Create annotations table
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS annotations (
                id VARCHAR(36) PRIMARY KEY,
                task_id VARCHAR(36) NOT NULL,
                user_id VARCHAR(36) NOT NULL,
                result LONGTEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                FOREIGN KEY (task_id) REFERENCES tasks(id) ON DELETE CASCADE,
                INDEX idx_annotations_task (task_id),
                INDEX idx_annotations_user (user_id)
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
        """)

    logger.info("MySQL 数据库初始化完成")


def _init_sqlite_database() -> None:
    """Initialize SQLite database tables.

    Creates the same tables and indexes as the MySQL schema. Unlike MySQL,
    the ``updated_at`` columns here have no ``ON UPDATE`` clause, so they are
    not refreshed automatically when a row is updated.
    """
    with get_db_connection() as conn:
        cursor = conn.cursor()

        # Create users table
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS users (
                id TEXT PRIMARY KEY,
                username TEXT NOT NULL UNIQUE,
                email TEXT NOT NULL UNIQUE,
                password_hash TEXT NOT NULL,
                role TEXT NOT NULL DEFAULT 'annotator',
                oauth_provider TEXT,
                oauth_id TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_oauth ON users(oauth_provider, oauth_id)")

        # Create projects table
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS projects (
                id TEXT PRIMARY KEY,
                name TEXT NOT NULL,
                description TEXT,
                config TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            )
        """)

        # Create tasks table
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS tasks (
                id TEXT PRIMARY KEY,
                project_id TEXT NOT NULL,
                name TEXT NOT NULL,
                data TEXT NOT NULL,
                status TEXT DEFAULT 'pending',
                assigned_to TEXT,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
            )
        """)
        # Mirror the MySQL indexes. SQLite does not index foreign-key child
        # columns automatically, so without these ON DELETE CASCADE scans.
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_tasks_project ON tasks(project_id)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_tasks_status ON tasks(status)")

        # Create annotations table
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS annotations (
                id TEXT PRIMARY KEY,
                task_id TEXT NOT NULL,
                user_id TEXT NOT NULL,
                result TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (task_id) REFERENCES tasks(id) ON DELETE CASCADE
            )
        """)
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_annotations_task ON annotations(task_id)")
        cursor.execute("CREATE INDEX IF NOT EXISTS idx_annotations_user ON annotations(user_id)")

    logger.info("SQLite 数据库初始化完成")


def get_db():
    """
    Get a database connection (legacy support).

    Note: Caller is responsible for committing and closing the connection;
    prefer ``get_db_connection()``, which does both automatically.
    """
    if db_config.db_type == 'mysql':
        conn = _get_mysql_connection()
    else:
        conn = _get_sqlite_connection()
    return ConnectionWrapper(conn, db_config.db_type)