|
|
@@ -1,30 +1,26 @@
|
|
|
"""
|
|
|
Database connection and initialization module.
|
|
|
-Supports both SQLite and MySQL databases.
|
|
|
+MySQL only - SQLite support removed.
|
|
|
"""
|
|
|
-import os
|
|
|
import logging
|
|
|
from contextlib import contextmanager
|
|
|
-from typing import Generator, Any, Optional
|
|
|
+from typing import Generator, Optional
|
|
|
from config import settings
|
|
|
|
|
|
+import pymysql
|
|
|
+
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
class DatabaseConfig:
|
|
|
- """Database configuration holder."""
|
|
|
+ """Database configuration holder (MySQL only)."""
|
|
|
|
|
|
def __init__(self):
|
|
|
- self.db_type = getattr(settings, 'DATABASE_TYPE', 'sqlite')
|
|
|
-
|
|
|
- if self.db_type == 'mysql':
|
|
|
- self.host = getattr(settings, 'MYSQL_HOST', 'localhost')
|
|
|
- self.port = getattr(settings, 'MYSQL_PORT', 3306)
|
|
|
- self.user = getattr(settings, 'MYSQL_USER', 'root')
|
|
|
- self.password = getattr(settings, 'MYSQL_PASSWORD', '')
|
|
|
- self.database = getattr(settings, 'MYSQL_DATABASE', 'annotation_platform')
|
|
|
- else:
|
|
|
- self.db_path = getattr(settings, 'DATABASE_PATH', 'annotation_platform.db')
|
|
|
+ self.host = getattr(settings, 'MYSQL_HOST', 'localhost')
|
|
|
+ self.port = getattr(settings, 'MYSQL_PORT', 3306)
|
|
|
+ self.user = getattr(settings, 'MYSQL_USER', 'root')
|
|
|
+ self.password = getattr(settings, 'MYSQL_PASSWORD', '')
|
|
|
+ self.database = getattr(settings, 'MYSQL_DATABASE', 'annotation_platform')
|
|
|
|
|
|
|
|
|
db_config = DatabaseConfig()
|
|
|
@@ -32,7 +28,6 @@ db_config = DatabaseConfig()
|
|
|
|
|
|
def _get_mysql_connection():
|
|
|
"""Get MySQL connection."""
|
|
|
- import pymysql
|
|
|
conn = pymysql.connect(
|
|
|
host=db_config.host,
|
|
|
port=db_config.port,
|
|
|
@@ -46,47 +41,29 @@ def _get_mysql_connection():
|
|
|
return conn
|
|
|
|
|
|
|
|
|
-def _get_sqlite_connection():
|
|
|
- """Get SQLite connection."""
|
|
|
- import sqlite3
|
|
|
- conn = sqlite3.connect(db_config.db_path)
|
|
|
- conn.row_factory = sqlite3.Row
|
|
|
- conn.execute("PRAGMA foreign_keys = ON")
|
|
|
- return conn
|
|
|
-
|
|
|
-
|
|
|
class RowWrapper:
|
|
|
- """Wrapper to provide consistent row access for both SQLite and MySQL."""
|
|
|
+ """Wrapper to provide consistent row access."""
|
|
|
|
|
|
- def __init__(self, row, db_type: str):
|
|
|
+ def __init__(self, row):
|
|
|
self._row = row
|
|
|
- self._db_type = db_type
|
|
|
|
|
|
def __getitem__(self, key):
|
|
|
- if self._db_type == 'mysql':
|
|
|
- return self._row[key]
|
|
|
- else:
|
|
|
- return self._row[key]
|
|
|
+ return self._row[key]
|
|
|
|
|
|
def keys(self):
|
|
|
- if self._db_type == 'mysql':
|
|
|
- return self._row.keys()
|
|
|
- else:
|
|
|
- return self._row.keys()
|
|
|
+ return self._row.keys()
|
|
|
|
|
|
|
|
|
class CursorWrapper:
|
|
|
- """Wrapper to provide consistent cursor interface for both databases."""
|
|
|
+ """Wrapper to provide consistent cursor interface."""
|
|
|
|
|
|
- def __init__(self, cursor, db_type: str):
|
|
|
+ def __init__(self, cursor):
|
|
|
self._cursor = cursor
|
|
|
- self._db_type = db_type
|
|
|
|
|
|
def execute(self, sql: str, params: tuple = None):
|
|
|
"""Execute SQL with parameter conversion."""
|
|
|
- if self._db_type == 'mysql':
|
|
|
- # Convert ? placeholders to %s for MySQL
|
|
|
- sql = sql.replace('?', '%s')
|
|
|
+ # Convert ? placeholders to %s for MySQL
|
|
|
+ sql = sql.replace('?', '%s')
|
|
|
|
|
|
if params:
|
|
|
self._cursor.execute(sql, params)
|
|
|
@@ -98,11 +75,11 @@ class CursorWrapper:
|
|
|
row = self._cursor.fetchone()
|
|
|
if row is None:
|
|
|
return None
|
|
|
- return RowWrapper(row, self._db_type)
|
|
|
+ return RowWrapper(row)
|
|
|
|
|
|
def fetchall(self) -> list:
|
|
|
rows = self._cursor.fetchall()
|
|
|
- return [RowWrapper(row, self._db_type) for row in rows]
|
|
|
+ return [RowWrapper(row) for row in rows]
|
|
|
|
|
|
@property
|
|
|
def lastrowid(self):
|
|
|
@@ -116,12 +93,11 @@ class CursorWrapper:
|
|
|
class ConnectionWrapper:
|
|
|
"""Wrapper to provide consistent connection interface."""
|
|
|
|
|
|
- def __init__(self, conn, db_type: str):
|
|
|
+ def __init__(self, conn):
|
|
|
self._conn = conn
|
|
|
- self._db_type = db_type
|
|
|
|
|
|
def cursor(self) -> CursorWrapper:
|
|
|
- return CursorWrapper(self._conn.cursor(), self._db_type)
|
|
|
+ return CursorWrapper(self._conn.cursor())
|
|
|
|
|
|
def commit(self):
|
|
|
self._conn.commit()
|
|
|
@@ -144,12 +120,8 @@ def get_db_connection() -> Generator[ConnectionWrapper, None, None]:
|
|
|
Context manager for database connections.
|
|
|
Ensures proper connection cleanup.
|
|
|
"""
|
|
|
- if db_config.db_type == 'mysql':
|
|
|
- conn = _get_mysql_connection()
|
|
|
- else:
|
|
|
- conn = _get_sqlite_connection()
|
|
|
-
|
|
|
- wrapped = ConnectionWrapper(conn, db_config.db_type)
|
|
|
+ conn = _get_mysql_connection()
|
|
|
+ wrapped = ConnectionWrapper(conn)
|
|
|
|
|
|
try:
|
|
|
yield wrapped
|
|
|
@@ -162,19 +134,7 @@ def get_db_connection() -> Generator[ConnectionWrapper, None, None]:
|
|
|
|
|
|
|
|
|
def init_database() -> None:
|
|
|
- """
|
|
|
- Initialize database and create tables if they don't exist.
|
|
|
- """
|
|
|
- if db_config.db_type == 'mysql':
|
|
|
- _init_mysql_database()
|
|
|
- else:
|
|
|
- _init_sqlite_database()
|
|
|
-
|
|
|
-
|
|
|
-def _init_mysql_database() -> None:
|
|
|
- """Initialize MySQL database tables."""
|
|
|
- import pymysql
|
|
|
-
|
|
|
+ """Initialize MySQL database and create tables if they don't exist."""
|
|
|
# First, create database if not exists
|
|
|
conn = pymysql.connect(
|
|
|
host=db_config.host,
|
|
|
@@ -287,110 +247,10 @@ def _init_mysql_database() -> None:
|
|
|
logger.info("MySQL 数据库初始化完成")
|
|
|
|
|
|
|
|
|
-def _init_sqlite_database() -> None:
|
|
|
- """Initialize SQLite database tables."""
|
|
|
- with get_db_connection() as conn:
|
|
|
- cursor = conn.cursor()
|
|
|
-
|
|
|
- # Create users table
|
|
|
- cursor.execute("""
|
|
|
- CREATE TABLE IF NOT EXISTS users (
|
|
|
- id TEXT PRIMARY KEY,
|
|
|
- username TEXT NOT NULL UNIQUE,
|
|
|
- email TEXT NOT NULL UNIQUE,
|
|
|
- password_hash TEXT NOT NULL,
|
|
|
- role TEXT NOT NULL DEFAULT 'annotator',
|
|
|
- oauth_provider TEXT,
|
|
|
- oauth_id TEXT,
|
|
|
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
|
- updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
|
|
- )
|
|
|
- """)
|
|
|
-
|
|
|
- cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)")
|
|
|
- cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)")
|
|
|
- cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_oauth ON users(oauth_provider, oauth_id)")
|
|
|
-
|
|
|
- # Create projects table
|
|
|
- cursor.execute("""
|
|
|
- CREATE TABLE IF NOT EXISTS projects (
|
|
|
- id TEXT PRIMARY KEY,
|
|
|
- name TEXT NOT NULL,
|
|
|
- description TEXT,
|
|
|
- config TEXT NOT NULL,
|
|
|
- status TEXT DEFAULT 'draft',
|
|
|
- source TEXT DEFAULT 'internal',
|
|
|
- task_type TEXT,
|
|
|
- external_id TEXT,
|
|
|
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
|
- updated_at TIMESTAMP
|
|
|
- )
|
|
|
- """)
|
|
|
-
|
|
|
- cursor.execute("CREATE INDEX IF NOT EXISTS idx_projects_status ON projects(status)")
|
|
|
- cursor.execute("CREATE INDEX IF NOT EXISTS idx_projects_source ON projects(source)")
|
|
|
- cursor.execute("CREATE INDEX IF NOT EXISTS idx_projects_external_id ON projects(external_id)")
|
|
|
-
|
|
|
- # Create tasks table
|
|
|
- cursor.execute("""
|
|
|
- CREATE TABLE IF NOT EXISTS tasks (
|
|
|
- id TEXT PRIMARY KEY,
|
|
|
- project_id TEXT NOT NULL,
|
|
|
- name TEXT NOT NULL,
|
|
|
- data TEXT NOT NULL,
|
|
|
- status TEXT DEFAULT 'pending',
|
|
|
- assigned_to TEXT,
|
|
|
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
|
- FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
|
|
|
- )
|
|
|
- """)
|
|
|
-
|
|
|
- # Create annotations table
|
|
|
- cursor.execute("""
|
|
|
- CREATE TABLE IF NOT EXISTS annotations (
|
|
|
- id TEXT PRIMARY KEY,
|
|
|
- task_id TEXT NOT NULL,
|
|
|
- user_id TEXT NOT NULL,
|
|
|
- result TEXT NOT NULL,
|
|
|
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
|
- updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
|
- FOREIGN KEY (task_id) REFERENCES tasks(id) ON DELETE CASCADE
|
|
|
- )
|
|
|
- """)
|
|
|
-
|
|
|
- # Create export_jobs table
|
|
|
- cursor.execute("""
|
|
|
- CREATE TABLE IF NOT EXISTS export_jobs (
|
|
|
- id TEXT PRIMARY KEY,
|
|
|
- project_id TEXT NOT NULL,
|
|
|
- format TEXT NOT NULL,
|
|
|
- status TEXT DEFAULT 'pending',
|
|
|
- status_filter TEXT DEFAULT 'all',
|
|
|
- include_metadata INTEGER DEFAULT 1,
|
|
|
- file_path TEXT,
|
|
|
- error_message TEXT,
|
|
|
- created_by TEXT,
|
|
|
- created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
|
|
- completed_at TIMESTAMP,
|
|
|
- total_tasks INTEGER DEFAULT 0,
|
|
|
- exported_tasks INTEGER DEFAULT 0,
|
|
|
- FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
|
|
|
- )
|
|
|
- """)
|
|
|
-
|
|
|
- cursor.execute("CREATE INDEX IF NOT EXISTS idx_export_jobs_project ON export_jobs(project_id)")
|
|
|
- cursor.execute("CREATE INDEX IF NOT EXISTS idx_export_jobs_status ON export_jobs(status)")
|
|
|
-
|
|
|
- logger.info("SQLite 数据库初始化完成")
|
|
|
-
|
|
|
-
|
|
|
def get_db():
|
|
|
"""
|
|
|
Get a database connection (legacy support).
|
|
|
Note: Caller is responsible for closing the connection.
|
|
|
"""
|
|
|
- if db_config.db_type == 'mysql':
|
|
|
- conn = _get_mysql_connection()
|
|
|
- else:
|
|
|
- conn = _get_sqlite_connection()
|
|
|
- return ConnectionWrapper(conn, db_config.db_type)
|
|
|
+ conn = _get_mysql_connection()
|
|
|
+ return ConnectionWrapper(conn)
|