Ver Fonte

-dev:修复数据库问题

LuoChinWen há 3 semanas atrás
pai
commit
f7119cf009
7 ficheiros alterados com 36 adições e 184 exclusões
  1. 2 1
      backend/Dockerfile
  2. 1 4
      backend/config.dev.yaml
  3. 1 4
      backend/config.prod.yaml
  4. 1 3
      backend/config.py
  5. 27 167
      backend/database.py
  6. 2 3
      backend/docker-compose.yml
  7. 2 2
      deploy.sh

+ 2 - 1
backend/Dockerfile

@@ -21,7 +21,8 @@ RUN pip install --no-cache-dir -r requirements.txt
 # 再复制应用代码(代码变更不会重新安装依赖)
 COPY . .
 
-RUN mkdir -p /app/data
+# 创建导出目录
+RUN mkdir -p /app/exports
 
 EXPOSE 8000
 

+ 1 - 4
backend/config.dev.yaml

@@ -22,11 +22,8 @@ oauth:
   userinfo_endpoint: "/oauth/userinfo"
   revoke_endpoint: "/oauth/revoke"
 
-# 数据库配置
+# 数据库配置 (MySQL)
 database:
-  type: "mysql"
-  path: "annotation_platform.db"
-  
   mysql:
     host: "192.168.92.61"
     port: 13306

+ 1 - 4
backend/config.prod.yaml

@@ -23,11 +23,8 @@ oauth:
   userinfo_endpoint: "/oauth/userinfo"
   revoke_endpoint: "/oauth/revoke"
 
-# 数据库配置
+# 数据库配置 (MySQL)
 database:
-  type: "mysql"
-  path: "/app/data/annotation_platform.db"
-  
   mysql:
     host: "192.168.92.61"
     port: 13306

+ 1 - 3
backend/config.py

@@ -63,10 +63,8 @@ class Settings:
         self.ACCESS_TOKEN_EXPIRE_MINUTES = jwt_config.get('access_token_expire_minutes', 15)
         self.REFRESH_TOKEN_EXPIRE_DAYS = jwt_config.get('refresh_token_expire_days', 7)
         
-        # Database Settings
+        # Database Settings (MySQL only)
         db_config = config.get('database', {})
-        self.DATABASE_TYPE = db_config.get('type', 'sqlite')
-        self.DATABASE_PATH = db_config.get('path', 'annotation_platform.db')
         
         # MySQL Settings
         mysql_config = db_config.get('mysql', {})

+ 27 - 167
backend/database.py

@@ -1,30 +1,26 @@
 """
 Database connection and initialization module.
-Supports both SQLite and MySQL databases.
+MySQL only - SQLite support removed.
 """
-import os
 import logging
 from contextlib import contextmanager
-from typing import Generator, Any, Optional
+from typing import Generator, Optional
 from config import settings
 
+import pymysql
+
 logger = logging.getLogger(__name__)
 
 
 class DatabaseConfig:
-    """Database configuration holder."""
+    """Database configuration holder (MySQL only)."""
     
     def __init__(self):
-        self.db_type = getattr(settings, 'DATABASE_TYPE', 'sqlite')
-        
-        if self.db_type == 'mysql':
-            self.host = getattr(settings, 'MYSQL_HOST', 'localhost')
-            self.port = getattr(settings, 'MYSQL_PORT', 3306)
-            self.user = getattr(settings, 'MYSQL_USER', 'root')
-            self.password = getattr(settings, 'MYSQL_PASSWORD', '')
-            self.database = getattr(settings, 'MYSQL_DATABASE', 'annotation_platform')
-        else:
-            self.db_path = getattr(settings, 'DATABASE_PATH', 'annotation_platform.db')
+        self.host = getattr(settings, 'MYSQL_HOST', 'localhost')
+        self.port = getattr(settings, 'MYSQL_PORT', 3306)
+        self.user = getattr(settings, 'MYSQL_USER', 'root')
+        self.password = getattr(settings, 'MYSQL_PASSWORD', '')
+        self.database = getattr(settings, 'MYSQL_DATABASE', 'annotation_platform')
 
 
 db_config = DatabaseConfig()
@@ -32,7 +28,6 @@ db_config = DatabaseConfig()
 
 def _get_mysql_connection():
     """Get MySQL connection."""
-    import pymysql
     conn = pymysql.connect(
         host=db_config.host,
         port=db_config.port,
@@ -46,47 +41,29 @@ def _get_mysql_connection():
     return conn
 
 
-def _get_sqlite_connection():
-    """Get SQLite connection."""
-    import sqlite3
-    conn = sqlite3.connect(db_config.db_path)
-    conn.row_factory = sqlite3.Row
-    conn.execute("PRAGMA foreign_keys = ON")
-    return conn
-
-
 class RowWrapper:
-    """Wrapper to provide consistent row access for both SQLite and MySQL."""
+    """Wrapper to provide consistent row access."""
     
-    def __init__(self, row, db_type: str):
+    def __init__(self, row):
         self._row = row
-        self._db_type = db_type
     
     def __getitem__(self, key):
-        if self._db_type == 'mysql':
-            return self._row[key]
-        else:
-            return self._row[key]
+        return self._row[key]
     
     def keys(self):
-        if self._db_type == 'mysql':
-            return self._row.keys()
-        else:
-            return self._row.keys()
+        return self._row.keys()
 
 
 class CursorWrapper:
-    """Wrapper to provide consistent cursor interface for both databases."""
+    """Wrapper to provide consistent cursor interface."""
     
-    def __init__(self, cursor, db_type: str):
+    def __init__(self, cursor):
         self._cursor = cursor
-        self._db_type = db_type
     
     def execute(self, sql: str, params: tuple = None):
         """Execute SQL with parameter conversion."""
-        if self._db_type == 'mysql':
-            # Convert ? placeholders to %s for MySQL
-            sql = sql.replace('?', '%s')
+        # Convert ? placeholders to %s for MySQL
+        sql = sql.replace('?', '%s')
         
         if params:
             self._cursor.execute(sql, params)
@@ -98,11 +75,11 @@ class CursorWrapper:
         row = self._cursor.fetchone()
         if row is None:
             return None
-        return RowWrapper(row, self._db_type)
+        return RowWrapper(row)
     
     def fetchall(self) -> list:
         rows = self._cursor.fetchall()
-        return [RowWrapper(row, self._db_type) for row in rows]
+        return [RowWrapper(row) for row in rows]
     
     @property
     def lastrowid(self):
@@ -116,12 +93,11 @@ class CursorWrapper:
 class ConnectionWrapper:
     """Wrapper to provide consistent connection interface."""
     
-    def __init__(self, conn, db_type: str):
+    def __init__(self, conn):
         self._conn = conn
-        self._db_type = db_type
     
     def cursor(self) -> CursorWrapper:
-        return CursorWrapper(self._conn.cursor(), self._db_type)
+        return CursorWrapper(self._conn.cursor())
     
     def commit(self):
         self._conn.commit()
@@ -144,12 +120,8 @@ def get_db_connection() -> Generator[ConnectionWrapper, None, None]:
     Context manager for database connections.
     Ensures proper connection cleanup.
     """
-    if db_config.db_type == 'mysql':
-        conn = _get_mysql_connection()
-    else:
-        conn = _get_sqlite_connection()
-    
-    wrapped = ConnectionWrapper(conn, db_config.db_type)
+    conn = _get_mysql_connection()
+    wrapped = ConnectionWrapper(conn)
     
     try:
         yield wrapped
@@ -162,19 +134,7 @@ def get_db_connection() -> Generator[ConnectionWrapper, None, None]:
 
 
 def init_database() -> None:
-    """
-    Initialize database and create tables if they don't exist.
-    """
-    if db_config.db_type == 'mysql':
-        _init_mysql_database()
-    else:
-        _init_sqlite_database()
-
-
-def _init_mysql_database() -> None:
-    """Initialize MySQL database tables."""
-    import pymysql
-    
+    """Initialize MySQL database and create tables if they don't exist."""
     # First, create database if not exists
     conn = pymysql.connect(
         host=db_config.host,
@@ -287,110 +247,10 @@ def _init_mysql_database() -> None:
         logger.info("MySQL 数据库初始化完成")
 
 
-def _init_sqlite_database() -> None:
-    """Initialize SQLite database tables."""
-    with get_db_connection() as conn:
-        cursor = conn.cursor()
-        
-        # Create users table
-        cursor.execute("""
-            CREATE TABLE IF NOT EXISTS users (
-                id TEXT PRIMARY KEY,
-                username TEXT NOT NULL UNIQUE,
-                email TEXT NOT NULL UNIQUE,
-                password_hash TEXT NOT NULL,
-                role TEXT NOT NULL DEFAULT 'annotator',
-                oauth_provider TEXT,
-                oauth_id TEXT,
-                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
-            )
-        """)
-        
-        cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)")
-        cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)")
-        cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_oauth ON users(oauth_provider, oauth_id)")
-        
-        # Create projects table
-        cursor.execute("""
-            CREATE TABLE IF NOT EXISTS projects (
-                id TEXT PRIMARY KEY,
-                name TEXT NOT NULL,
-                description TEXT,
-                config TEXT NOT NULL,
-                status TEXT DEFAULT 'draft',
-                source TEXT DEFAULT 'internal',
-                task_type TEXT,
-                external_id TEXT,
-                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                updated_at TIMESTAMP
-            )
-        """)
-        
-        cursor.execute("CREATE INDEX IF NOT EXISTS idx_projects_status ON projects(status)")
-        cursor.execute("CREATE INDEX IF NOT EXISTS idx_projects_source ON projects(source)")
-        cursor.execute("CREATE INDEX IF NOT EXISTS idx_projects_external_id ON projects(external_id)")
-        
-        # Create tasks table
-        cursor.execute("""
-            CREATE TABLE IF NOT EXISTS tasks (
-                id TEXT PRIMARY KEY,
-                project_id TEXT NOT NULL,
-                name TEXT NOT NULL,
-                data TEXT NOT NULL,
-                status TEXT DEFAULT 'pending',
-                assigned_to TEXT,
-                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
-            )
-        """)
-        
-        # Create annotations table
-        cursor.execute("""
-            CREATE TABLE IF NOT EXISTS annotations (
-                id TEXT PRIMARY KEY,
-                task_id TEXT NOT NULL,
-                user_id TEXT NOT NULL,
-                result TEXT NOT NULL,
-                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                FOREIGN KEY (task_id) REFERENCES tasks(id) ON DELETE CASCADE
-            )
-        """)
-        
-        # Create export_jobs table
-        cursor.execute("""
-            CREATE TABLE IF NOT EXISTS export_jobs (
-                id TEXT PRIMARY KEY,
-                project_id TEXT NOT NULL,
-                format TEXT NOT NULL,
-                status TEXT DEFAULT 'pending',
-                status_filter TEXT DEFAULT 'all',
-                include_metadata INTEGER DEFAULT 1,
-                file_path TEXT,
-                error_message TEXT,
-                created_by TEXT,
-                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
-                completed_at TIMESTAMP,
-                total_tasks INTEGER DEFAULT 0,
-                exported_tasks INTEGER DEFAULT 0,
-                FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
-            )
-        """)
-        
-        cursor.execute("CREATE INDEX IF NOT EXISTS idx_export_jobs_project ON export_jobs(project_id)")
-        cursor.execute("CREATE INDEX IF NOT EXISTS idx_export_jobs_status ON export_jobs(status)")
-        
-        logger.info("SQLite 数据库初始化完成")
-
-
 def get_db():
     """
     Get a database connection (legacy support).
     Note: Caller is responsible for closing the connection.
     """
-    if db_config.db_type == 'mysql':
-        conn = _get_mysql_connection()
-    else:
-        conn = _get_sqlite_connection()
-    return ConnectionWrapper(conn, db_config.db_type)
+    conn = _get_mysql_connection()
+    return ConnectionWrapper(conn)

+ 2 - 3
backend/docker-compose.yml

@@ -12,13 +12,12 @@ services:
     volumes:
       # 挂载代码目录(代码更新无需重新构建镜像)
       - .:/app
-      # 持久化数据库文件
-      - ./data:/app/data
+      # 挂载导出文件目录
+      - ./exports:/app/exports
       # 挂载生产环境配置文件
       - ./config.prod.yaml:/app/config.prod.yaml:ro
     environment:
       - APP_ENV=prod
-      - DATABASE_PATH=/app/data/annotation_platform.db
     restart: unless-stopped
     healthcheck:
       test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')"]

+ 2 - 2
deploy.sh

@@ -55,8 +55,8 @@ if [ ! -f "config.yaml" ]; then
     echo -e "${GREEN}已使用生产环境配置${NC}"
 fi
 
-# 创建数据目录
-mkdir -p data
+# 创建导出目录
+mkdir -p exports
 
 # 检查是否需要重新构建镜像(只有依赖变更才需要)
 NEED_BUILD=false