database.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256
  1. """
  2. Database connection and initialization module.
  3. MySQL only - SQLite support removed.
  4. """
  5. import logging
  6. from contextlib import contextmanager
  7. from typing import Generator, Optional
  8. from config import settings
  9. import pymysql
  10. logger = logging.getLogger(__name__)
  11. class DatabaseConfig:
  12. """Database configuration holder (MySQL only)."""
  13. def __init__(self):
  14. self.host = getattr(settings, 'MYSQL_HOST', 'localhost')
  15. self.port = getattr(settings, 'MYSQL_PORT', 3306)
  16. self.user = getattr(settings, 'MYSQL_USER', 'root')
  17. self.password = getattr(settings, 'MYSQL_PASSWORD', '')
  18. self.database = getattr(settings, 'MYSQL_DATABASE', 'annotation_platform')
  19. db_config = DatabaseConfig()
  20. def _get_mysql_connection():
  21. """Get MySQL connection."""
  22. conn = pymysql.connect(
  23. host=db_config.host,
  24. port=db_config.port,
  25. user=db_config.user,
  26. password=db_config.password,
  27. database=db_config.database,
  28. charset='utf8mb4',
  29. cursorclass=pymysql.cursors.DictCursor,
  30. autocommit=False
  31. )
  32. return conn
  33. class RowWrapper:
  34. """Wrapper to provide consistent row access."""
  35. def __init__(self, row):
  36. self._row = row
  37. def __getitem__(self, key):
  38. return self._row[key]
  39. def keys(self):
  40. return self._row.keys()
  41. class CursorWrapper:
  42. """Wrapper to provide consistent cursor interface."""
  43. def __init__(self, cursor):
  44. self._cursor = cursor
  45. def execute(self, sql: str, params: tuple = None):
  46. """Execute SQL with parameter conversion."""
  47. # Convert ? placeholders to %s for MySQL
  48. sql = sql.replace('?', '%s')
  49. if params:
  50. self._cursor.execute(sql, params)
  51. else:
  52. self._cursor.execute(sql)
  53. return self
  54. def fetchone(self) -> Optional[RowWrapper]:
  55. row = self._cursor.fetchone()
  56. if row is None:
  57. return None
  58. return RowWrapper(row)
  59. def fetchall(self) -> list:
  60. rows = self._cursor.fetchall()
  61. return [RowWrapper(row) for row in rows]
  62. @property
  63. def lastrowid(self):
  64. return self._cursor.lastrowid
  65. @property
  66. def rowcount(self):
  67. return self._cursor.rowcount
  68. class ConnectionWrapper:
  69. """Wrapper to provide consistent connection interface."""
  70. def __init__(self, conn):
  71. self._conn = conn
  72. def cursor(self) -> CursorWrapper:
  73. return CursorWrapper(self._conn.cursor())
  74. def commit(self):
  75. self._conn.commit()
  76. def rollback(self):
  77. self._conn.rollback()
  78. def close(self):
  79. self._conn.close()
  80. def execute(self, sql: str, params: tuple = None):
  81. cursor = self.cursor()
  82. cursor.execute(sql, params)
  83. return cursor
  84. @contextmanager
  85. def get_db_connection() -> Generator[ConnectionWrapper, None, None]:
  86. """
  87. Context manager for database connections.
  88. Ensures proper connection cleanup.
  89. """
  90. conn = _get_mysql_connection()
  91. wrapped = ConnectionWrapper(conn)
  92. try:
  93. yield wrapped
  94. wrapped.commit()
  95. except Exception:
  96. wrapped.rollback()
  97. raise
  98. finally:
  99. wrapped.close()
  100. def init_database() -> None:
  101. """Initialize MySQL database and create tables if they don't exist."""
  102. # First, create database if not exists
  103. conn = pymysql.connect(
  104. host=db_config.host,
  105. port=db_config.port,
  106. user=db_config.user,
  107. password=db_config.password,
  108. charset='utf8mb4'
  109. )
  110. try:
  111. with conn.cursor() as cursor:
  112. cursor.execute(f"CREATE DATABASE IF NOT EXISTS `{db_config.database}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci")
  113. conn.commit()
  114. finally:
  115. conn.close()
  116. # Now create tables
  117. with get_db_connection() as conn:
  118. cursor = conn.cursor()
  119. # Create users table
  120. cursor.execute("""
  121. CREATE TABLE IF NOT EXISTS users (
  122. id VARCHAR(36) PRIMARY KEY,
  123. username VARCHAR(255) NOT NULL UNIQUE,
  124. email VARCHAR(255) NOT NULL UNIQUE,
  125. password_hash VARCHAR(255) NOT NULL,
  126. role VARCHAR(50) NOT NULL DEFAULT 'annotator',
  127. oauth_provider VARCHAR(50),
  128. oauth_id VARCHAR(255),
  129. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  130. updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  131. INDEX idx_users_username (username),
  132. INDEX idx_users_email (email),
  133. INDEX idx_users_oauth (oauth_provider, oauth_id)
  134. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
  135. """)
  136. # Create projects table
  137. cursor.execute("""
  138. CREATE TABLE IF NOT EXISTS projects (
  139. id VARCHAR(36) PRIMARY KEY,
  140. name VARCHAR(255) NOT NULL,
  141. description TEXT,
  142. config TEXT NOT NULL,
  143. status VARCHAR(20) DEFAULT 'draft',
  144. source VARCHAR(20) DEFAULT 'internal',
  145. task_type VARCHAR(50),
  146. external_id VARCHAR(100),
  147. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  148. updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  149. INDEX idx_projects_status (status),
  150. INDEX idx_projects_source (source),
  151. INDEX idx_projects_external_id (external_id)
  152. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
  153. """)
  154. # Create tasks table
  155. cursor.execute("""
  156. CREATE TABLE IF NOT EXISTS tasks (
  157. id VARCHAR(36) PRIMARY KEY,
  158. project_id VARCHAR(36) NOT NULL,
  159. name VARCHAR(255) NOT NULL,
  160. data LONGTEXT NOT NULL,
  161. status VARCHAR(50) DEFAULT 'pending',
  162. assigned_to VARCHAR(36),
  163. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  164. FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
  165. INDEX idx_tasks_project (project_id),
  166. INDEX idx_tasks_status (status)
  167. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
  168. """)
  169. # Create annotations table
  170. cursor.execute("""
  171. CREATE TABLE IF NOT EXISTS annotations (
  172. id VARCHAR(36) PRIMARY KEY,
  173. task_id VARCHAR(36) NOT NULL,
  174. user_id VARCHAR(36) NOT NULL,
  175. result LONGTEXT NOT NULL,
  176. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  177. updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  178. FOREIGN KEY (task_id) REFERENCES tasks(id) ON DELETE CASCADE,
  179. INDEX idx_annotations_task (task_id),
  180. INDEX idx_annotations_user (user_id)
  181. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
  182. """)
  183. # Create export_jobs table
  184. cursor.execute("""
  185. CREATE TABLE IF NOT EXISTS export_jobs (
  186. id VARCHAR(36) PRIMARY KEY,
  187. project_id VARCHAR(36) NOT NULL,
  188. format VARCHAR(50) NOT NULL,
  189. status VARCHAR(50) DEFAULT 'pending',
  190. status_filter VARCHAR(50) DEFAULT 'all',
  191. include_metadata BOOLEAN DEFAULT TRUE,
  192. file_path TEXT,
  193. error_message TEXT,
  194. created_by VARCHAR(36),
  195. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  196. completed_at TIMESTAMP,
  197. total_tasks INT DEFAULT 0,
  198. exported_tasks INT DEFAULT 0,
  199. FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
  200. INDEX idx_export_jobs_project (project_id),
  201. INDEX idx_export_jobs_status (status)
  202. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
  203. """)
  204. logger.info("MySQL 数据库初始化完成")
  205. def get_db():
  206. """
  207. Get a database connection (legacy support).
  208. Note: Caller is responsible for closing the connection.
  209. """
  210. conn = _get_mysql_connection()
  211. return ConnectionWrapper(conn)