| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334 |
- """
- Database connection and initialization module.
- Supports both SQLite and MySQL databases.
- """
- import os
- import logging
- from contextlib import contextmanager
- from typing import Generator, Any, Optional
- from config import settings
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class DatabaseConfig:
    """Snapshot of the database settings read from ``config.settings``.

    ``db_type`` selects the backend. When it is ``'mysql'`` the MySQL
    connection attributes (``host``, ``port``, ``user``, ``password``,
    ``database``) are populated; for any other value only ``db_path`` (the
    SQLite database file) is set.
    """

    def __init__(self):
        # Normalise the backend name so values such as "MySQL" or " mysql "
        # (common when the setting comes from an environment variable) select
        # MySQL instead of silently falling back to SQLite.
        self.db_type = str(getattr(settings, 'DATABASE_TYPE', 'sqlite')).strip().lower()

        if self.db_type == 'mysql':
            self.host = getattr(settings, 'MYSQL_HOST', 'localhost')
            # Settings loaded from the environment arrive as strings, but
            # pymysql.connect() requires an integer port.
            self.port = int(getattr(settings, 'MYSQL_PORT', 3306))
            self.user = getattr(settings, 'MYSQL_USER', 'root')
            self.password = getattr(settings, 'MYSQL_PASSWORD', '')
            self.database = getattr(settings, 'MYSQL_DATABASE', 'annotation_platform')
        else:
            self.db_path = getattr(settings, 'DATABASE_PATH', 'annotation_platform.db')
# Settings are read once, at import time. The connection helpers in this
# module all read this shared instance, so later changes to ``settings`` are
# not picked up without re-creating it.
db_config = DatabaseConfig()
def _get_mysql_connection():
    """Open a new pymysql connection to the configured MySQL database.

    Rows come back as dicts (``DictCursor``) and autocommit is disabled, so
    the caller is responsible for committing or rolling back.
    """
    import pymysql

    connect_kwargs = dict(
        host=db_config.host,
        port=db_config.port,
        user=db_config.user,
        password=db_config.password,
        database=db_config.database,
        charset='utf8mb4',
        cursorclass=pymysql.cursors.DictCursor,
        autocommit=False,
    )
    return pymysql.connect(**connect_kwargs)
- def _get_sqlite_connection():
- """Get SQLite connection."""
- import sqlite3
- conn = sqlite3.connect(db_config.db_path)
- conn.row_factory = sqlite3.Row
- conn.execute("PRAGMA foreign_keys = ON")
- return conn
class RowWrapper:
    """Read-only view giving the same row access for SQLite and MySQL rows.

    SQLite rows are ``sqlite3.Row`` objects, which can be indexed by column
    name or by position. MySQL rows come from pymysql's ``DictCursor`` as
    plain dicts, which only support lookup by name. This wrapper accepts both
    forms of key for both backends.
    """

    def __init__(self, row, db_type: str):
        self._row = row
        self._db_type = db_type

    def __getitem__(self, key):
        """Return a column value by name (``str``) or position (``int``).

        Raises:
            KeyError / IndexError: if the column does not exist (the exact
                type is whatever the underlying row raises).
        """
        if self._db_type == 'mysql' and isinstance(key, int):
            # Dict rows have no positional access. DictCursor builds each dict
            # from the result fields in order, and dicts keep insertion order,
            # so position N is the N-th value. Out-of-range positions raise
            # IndexError, which also makes ``list(row)`` terminate correctly.
            return list(self._row.values())[key]
        return self._row[key]

    def keys(self):
        """Return the column names of this row."""
        return self._row.keys()

    def get(self, key, default=None):
        """Return ``self[key]``, or ``default`` if the column does not exist."""
        try:
            return self[key]
        except (KeyError, IndexError):
            # sqlite3.Row raises IndexError for unknown names; dicts raise KeyError.
            return default

    def __repr__(self):
        values = {name: self[name] for name in self.keys()}
        return f"{type(self).__name__}({values!r}, db_type={self._db_type!r})"
class CursorWrapper:
    """Backend-neutral cursor that accepts ``?`` placeholders and wraps rows.

    Callers always write SQLite-style ``?`` placeholders. For MySQL the SQL
    is rewritten into pymysql's ``%s`` paramstyle before execution. Fetched
    rows are returned as RowWrapper objects.
    """

    def __init__(self, cursor, db_type: str):
        self._cursor = cursor
        self._db_type = db_type

    @staticmethod
    def _to_mysql_sql(sql: str, escape_percent: bool) -> str:
        """Rewrite ``?`` placeholders to ``%s`` for pymysql.

        ``?`` characters inside quoted strings or quoted identifiers ('...',
        "...", `...`) are left untouched. When ``escape_percent`` is true
        (parameters will be passed, so pymysql %-formats the whole query),
        every literal ``%`` is doubled so that e.g. ``LIKE 'a%'`` survives the
        formatting step.
        """
        out = []
        quote = None      # quote character currently open, or None
        escaped = False   # previous char was a backslash inside a string
        for ch in sql:
            if ch == '%' and escape_percent:
                out.append('%%')
            elif ch == '?' and quote is None:
                out.append('%s')
            else:
                out.append(ch)

            # Update the quoting state for the next character.
            if escaped:
                escaped = False
            elif quote is not None:
                if ch == '\\' and quote != '`':
                    escaped = True
                elif ch == quote:
                    # A doubled quote ('') simply closes and immediately
                    # reopens the string, which is handled naturally.
                    quote = None
            elif ch in ("'", '"', '`'):
                quote = ch
        return ''.join(out)

    def execute(self, sql: str, params: Optional[tuple] = None) -> 'CursorWrapper':
        """Execute ``sql`` (using ``?`` placeholders) with optional ``params``.

        Returns ``self`` so calls can be chained, e.g.
        ``cursor.execute(...).fetchone()``.
        """
        if self._db_type == 'mysql':
            # pymysql only %-formats the query when arguments are supplied,
            # so literal '%' must be escaped only in that case.
            sql = self._to_mysql_sql(sql, escape_percent=bool(params))

        if params:
            self._cursor.execute(sql, params)
        else:
            self._cursor.execute(sql)
        return self

    def fetchone(self) -> Optional['RowWrapper']:
        """Return the next row wrapped in RowWrapper, or None when exhausted."""
        row = self._cursor.fetchone()
        if row is None:
            return None
        return RowWrapper(row, self._db_type)

    def fetchall(self) -> list:
        """Return all remaining rows, each wrapped in RowWrapper."""
        rows = self._cursor.fetchall()
        return [RowWrapper(row, self._db_type) for row in rows]

    @property
    def lastrowid(self):
        """Row id of the last inserted row, as reported by the driver."""
        return self._cursor.lastrowid

    @property
    def rowcount(self):
        """Number of rows affected by the last statement, as reported by the driver."""
        return self._cursor.rowcount
class ConnectionWrapper:
    """Backend-neutral connection that hands out CursorWrapper objects.

    It can also be used as a context manager. On a clean exit the transaction
    is committed; on an exception it is rolled back. In both cases the
    underlying connection is closed.
    """

    def __init__(self, conn, db_type: str):
        self._conn = conn
        self._db_type = db_type

    def __enter__(self) -> 'ConnectionWrapper':
        return self

    def __exit__(self, exc_type, exc, tb) -> bool:
        try:
            if exc_type is None:
                self.commit()
            else:
                try:
                    self.rollback()
                except Exception:
                    # Don't let a failed rollback hide the exception that caused it.
                    logger.exception("Rollback failed; re-raising the original error")
        finally:
            self.close()
        return False  # never suppress the caller's exception

    def cursor(self) -> 'CursorWrapper':
        """Return a new CursorWrapper over a fresh driver cursor."""
        return CursorWrapper(self._conn.cursor(), self._db_type)

    def commit(self):
        """Commit the current transaction on the underlying connection."""
        self._conn.commit()

    def rollback(self):
        """Roll back the current transaction on the underlying connection."""
        self._conn.rollback()

    def close(self):
        """Close the underlying connection."""
        self._conn.close()

    def execute(self, sql: str, params: tuple = None):
        """Shortcut: run ``sql`` on a new cursor and return that cursor."""
        cursor = self.cursor()
        cursor.execute(sql, params)
        return cursor
@contextmanager
def get_db_connection() -> Generator['ConnectionWrapper', None, None]:
    """Yield a wrapped connection scoped to a ``with`` block.

    On a clean exit the transaction is committed. If the block raises
    (including on commit failure), the transaction is rolled back and the
    exception is re-raised. The connection is always closed. If the rollback
    itself fails, that failure is logged and the original exception still
    propagates, so callers see the real error.
    """
    wrapped = get_db()

    try:
        yield wrapped
        wrapped.commit()
    except BaseException:
        # BaseException so that KeyboardInterrupt etc. also roll back
        # explicitly rather than relying on close().
        try:
            wrapped.rollback()
        except Exception:
            logger.exception("Rollback failed; re-raising the original error")
        raise
    finally:
        wrapped.close()
def init_database() -> None:
    """Create the schema for whichever backend ``db_config`` selects.

    Every CREATE statement issued uses IF NOT EXISTS, so calling this on an
    already-initialised database is harmless.
    """
    initializer = (
        _init_mysql_database if db_config.db_type == 'mysql' else _init_sqlite_database
    )
    initializer()
def _init_mysql_database() -> None:
    """Create the MySQL database (if missing) and the application tables.

    First connects to the server without selecting a database, to run
    CREATE DATABASE. It then creates the tables through get_db_connection(),
    which commits on success and closes the connection.
    """
    import pymysql

    # Identifiers cannot be bound as query parameters, so quote the name by
    # hand; doubling embedded backticks keeps an unusual name from breaking
    # out of the quoted identifier.
    quoted_db_name = '`' + str(db_config.database).replace('`', '``') + '`'

    server_conn = pymysql.connect(
        host=db_config.host,
        port=db_config.port,
        user=db_config.user,
        password=db_config.password,
        charset='utf8mb4'
    )
    try:
        with server_conn.cursor() as cursor:
            cursor.execute(
                f"CREATE DATABASE IF NOT EXISTS {quoted_db_name} "
                "CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci"
            )
        server_conn.commit()
    finally:
        server_conn.close()

    with get_db_connection() as conn:
        cursor = conn.cursor()

        # Users. NOTE(review): the UNIQUE constraints already create indexes
        # on username/email, so idx_users_username/idx_users_email look
        # redundant - confirm before dropping them.
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS users (
                id VARCHAR(36) PRIMARY KEY,
                username VARCHAR(255) NOT NULL UNIQUE,
                email VARCHAR(255) NOT NULL UNIQUE,
                password_hash VARCHAR(255) NOT NULL,
                role VARCHAR(50) NOT NULL DEFAULT 'annotator',
                oauth_provider VARCHAR(50),
                oauth_id VARCHAR(255),
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                INDEX idx_users_username (username),
                INDEX idx_users_email (email),
                INDEX idx_users_oauth (oauth_provider, oauth_id)
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
        """)

        # Projects
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS projects (
                id VARCHAR(36) PRIMARY KEY,
                name VARCHAR(255) NOT NULL,
                description TEXT,
                config TEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
        """)

        # Tasks: deleted together with their project.
        # NOTE(review): assigned_to has no FK to users - confirm intended.
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS tasks (
                id VARCHAR(36) PRIMARY KEY,
                project_id VARCHAR(36) NOT NULL,
                name VARCHAR(255) NOT NULL,
                data LONGTEXT NOT NULL,
                status VARCHAR(50) DEFAULT 'pending',
                assigned_to VARCHAR(36),
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
                INDEX idx_tasks_project (project_id),
                INDEX idx_tasks_status (status)
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
        """)

        # Annotations: deleted together with their task.
        # NOTE(review): user_id has no FK to users - confirm intended.
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS annotations (
                id VARCHAR(36) PRIMARY KEY,
                task_id VARCHAR(36) NOT NULL,
                user_id VARCHAR(36) NOT NULL,
                result LONGTEXT NOT NULL,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
                FOREIGN KEY (task_id) REFERENCES tasks(id) ON DELETE CASCADE,
                INDEX idx_annotations_task (task_id),
                INDEX idx_annotations_user (user_id)
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
        """)

    logger.info("MySQL 数据库初始化完成")
def _init_sqlite_database() -> None:
    """Create the SQLite tables and indexes used by the application.

    Every statement uses IF NOT EXISTS, so re-running is harmless. The
    statements run in the order listed, through get_db_connection(), which
    commits on success and closes the connection.
    """
    schema_statements = (
        # Users
        """
        CREATE TABLE IF NOT EXISTS users (
            id TEXT PRIMARY KEY,
            username TEXT NOT NULL UNIQUE,
            email TEXT NOT NULL UNIQUE,
            password_hash TEXT NOT NULL,
            role TEXT NOT NULL DEFAULT 'annotator',
            oauth_provider TEXT,
            oauth_id TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """,
        "CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)",
        "CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)",
        "CREATE INDEX IF NOT EXISTS idx_users_oauth ON users(oauth_provider, oauth_id)",
        # Projects
        """
        CREATE TABLE IF NOT EXISTS projects (
            id TEXT PRIMARY KEY,
            name TEXT NOT NULL,
            description TEXT,
            config TEXT NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        )
        """,
        # Tasks (cascade-deleted with their project)
        """
        CREATE TABLE IF NOT EXISTS tasks (
            id TEXT PRIMARY KEY,
            project_id TEXT NOT NULL,
            name TEXT NOT NULL,
            data TEXT NOT NULL,
            status TEXT DEFAULT 'pending',
            assigned_to TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
        )
        """,
        # Annotations (cascade-deleted with their task)
        """
        CREATE TABLE IF NOT EXISTS annotations (
            id TEXT PRIMARY KEY,
            task_id TEXT NOT NULL,
            user_id TEXT NOT NULL,
            result TEXT NOT NULL,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            FOREIGN KEY (task_id) REFERENCES tasks(id) ON DELETE CASCADE
        )
        """,
    )

    with get_db_connection() as conn:
        cur = conn.cursor()
        for statement in schema_statements:
            cur.execute(statement)

    logger.info("SQLite 数据库初始化完成")
def get_db():
    """Open a wrapped connection whose lifetime the caller manages (legacy API).

    Nothing is committed, rolled back or closed automatically; the caller
    must call ``close()`` (and ``commit()`` as needed) itself.
    """
    open_connection = (
        _get_mysql_connection if db_config.db_type == 'mysql' else _get_sqlite_connection
    )
    return ConnectionWrapper(open_connection(), db_config.db_type)
|