database.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. """
  2. Database connection and initialization module.
  3. Supports both SQLite and MySQL databases.
  4. """
  5. import os
  6. import logging
  7. from contextlib import contextmanager
  8. from typing import Generator, Any, Optional
  9. from config import settings
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
  11. class DatabaseConfig:
  12. """Database configuration holder."""
  13. def __init__(self):
  14. self.db_type = getattr(settings, 'DATABASE_TYPE', 'sqlite')
  15. if self.db_type == 'mysql':
  16. self.host = getattr(settings, 'MYSQL_HOST', 'localhost')
  17. self.port = getattr(settings, 'MYSQL_PORT', 3306)
  18. self.user = getattr(settings, 'MYSQL_USER', 'root')
  19. self.password = getattr(settings, 'MYSQL_PASSWORD', '')
  20. self.database = getattr(settings, 'MYSQL_DATABASE', 'annotation_platform')
  21. else:
  22. self.db_path = getattr(settings, 'DATABASE_PATH', 'annotation_platform.db')
# Shared configuration instance, created once when this module is imported.
db_config = DatabaseConfig()
  24. def _get_mysql_connection():
  25. """Get MySQL connection."""
  26. import pymysql
  27. conn = pymysql.connect(
  28. host=db_config.host,
  29. port=db_config.port,
  30. user=db_config.user,
  31. password=db_config.password,
  32. database=db_config.database,
  33. charset='utf8mb4',
  34. cursorclass=pymysql.cursors.DictCursor,
  35. autocommit=False
  36. )
  37. return conn
  38. def _get_sqlite_connection():
  39. """Get SQLite connection."""
  40. import sqlite3
  41. conn = sqlite3.connect(db_config.db_path)
  42. conn.row_factory = sqlite3.Row
  43. conn.execute("PRAGMA foreign_keys = ON")
  44. return conn
  45. class RowWrapper:
  46. """Wrapper to provide consistent row access for both SQLite and MySQL."""
  47. def __init__(self, row, db_type: str):
  48. self._row = row
  49. self._db_type = db_type
  50. def __getitem__(self, key):
  51. if self._db_type == 'mysql':
  52. return self._row[key]
  53. else:
  54. return self._row[key]
  55. def keys(self):
  56. if self._db_type == 'mysql':
  57. return self._row.keys()
  58. else:
  59. return self._row.keys()
  60. class CursorWrapper:
  61. """Wrapper to provide consistent cursor interface for both databases."""
  62. def __init__(self, cursor, db_type: str):
  63. self._cursor = cursor
  64. self._db_type = db_type
  65. def execute(self, sql: str, params: tuple = None):
  66. """Execute SQL with parameter conversion."""
  67. if self._db_type == 'mysql':
  68. # Convert ? placeholders to %s for MySQL
  69. sql = sql.replace('?', '%s')
  70. if params:
  71. self._cursor.execute(sql, params)
  72. else:
  73. self._cursor.execute(sql)
  74. return self
  75. def fetchone(self) -> Optional[RowWrapper]:
  76. row = self._cursor.fetchone()
  77. if row is None:
  78. return None
  79. return RowWrapper(row, self._db_type)
  80. def fetchall(self) -> list:
  81. rows = self._cursor.fetchall()
  82. return [RowWrapper(row, self._db_type) for row in rows]
  83. @property
  84. def lastrowid(self):
  85. return self._cursor.lastrowid
  86. @property
  87. def rowcount(self):
  88. return self._cursor.rowcount
  89. class ConnectionWrapper:
  90. """Wrapper to provide consistent connection interface."""
  91. def __init__(self, conn, db_type: str):
  92. self._conn = conn
  93. self._db_type = db_type
  94. def cursor(self) -> CursorWrapper:
  95. return CursorWrapper(self._conn.cursor(), self._db_type)
  96. def commit(self):
  97. self._conn.commit()
  98. def rollback(self):
  99. self._conn.rollback()
  100. def close(self):
  101. self._conn.close()
  102. def execute(self, sql: str, params: tuple = None):
  103. cursor = self.cursor()
  104. cursor.execute(sql, params)
  105. return cursor
  106. @contextmanager
  107. def get_db_connection() -> Generator[ConnectionWrapper, None, None]:
  108. """
  109. Context manager for database connections.
  110. Ensures proper connection cleanup.
  111. """
  112. if db_config.db_type == 'mysql':
  113. conn = _get_mysql_connection()
  114. else:
  115. conn = _get_sqlite_connection()
  116. wrapped = ConnectionWrapper(conn, db_config.db_type)
  117. try:
  118. yield wrapped
  119. wrapped.commit()
  120. except Exception:
  121. wrapped.rollback()
  122. raise
  123. finally:
  124. wrapped.close()
  125. def init_database() -> None:
  126. """
  127. Initialize database and create tables if they don't exist.
  128. """
  129. if db_config.db_type == 'mysql':
  130. _init_mysql_database()
  131. else:
  132. _init_sqlite_database()
  133. def _init_mysql_database() -> None:
  134. """Initialize MySQL database tables."""
  135. import pymysql
  136. # First, create database if not exists
  137. conn = pymysql.connect(
  138. host=db_config.host,
  139. port=db_config.port,
  140. user=db_config.user,
  141. password=db_config.password,
  142. charset='utf8mb4'
  143. )
  144. try:
  145. with conn.cursor() as cursor:
  146. cursor.execute(f"CREATE DATABASE IF NOT EXISTS `{db_config.database}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci")
  147. conn.commit()
  148. finally:
  149. conn.close()
  150. # Now create tables
  151. with get_db_connection() as conn:
  152. cursor = conn.cursor()
  153. # Create users table
  154. cursor.execute("""
  155. CREATE TABLE IF NOT EXISTS users (
  156. id VARCHAR(36) PRIMARY KEY,
  157. username VARCHAR(255) NOT NULL UNIQUE,
  158. email VARCHAR(255) NOT NULL UNIQUE,
  159. password_hash VARCHAR(255) NOT NULL,
  160. role VARCHAR(50) NOT NULL DEFAULT 'annotator',
  161. oauth_provider VARCHAR(50),
  162. oauth_id VARCHAR(255),
  163. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  164. updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  165. INDEX idx_users_username (username),
  166. INDEX idx_users_email (email),
  167. INDEX idx_users_oauth (oauth_provider, oauth_id)
  168. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
  169. """)
  170. # Create projects table
  171. cursor.execute("""
  172. CREATE TABLE IF NOT EXISTS projects (
  173. id VARCHAR(36) PRIMARY KEY,
  174. name VARCHAR(255) NOT NULL,
  175. description TEXT,
  176. config TEXT NOT NULL,
  177. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  178. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
  179. """)
  180. # Create tasks table
  181. cursor.execute("""
  182. CREATE TABLE IF NOT EXISTS tasks (
  183. id VARCHAR(36) PRIMARY KEY,
  184. project_id VARCHAR(36) NOT NULL,
  185. name VARCHAR(255) NOT NULL,
  186. data LONGTEXT NOT NULL,
  187. status VARCHAR(50) DEFAULT 'pending',
  188. assigned_to VARCHAR(36),
  189. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  190. FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
  191. INDEX idx_tasks_project (project_id),
  192. INDEX idx_tasks_status (status)
  193. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
  194. """)
  195. # Create annotations table
  196. cursor.execute("""
  197. CREATE TABLE IF NOT EXISTS annotations (
  198. id VARCHAR(36) PRIMARY KEY,
  199. task_id VARCHAR(36) NOT NULL,
  200. user_id VARCHAR(36) NOT NULL,
  201. result LONGTEXT NOT NULL,
  202. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  203. updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  204. FOREIGN KEY (task_id) REFERENCES tasks(id) ON DELETE CASCADE,
  205. INDEX idx_annotations_task (task_id),
  206. INDEX idx_annotations_user (user_id)
  207. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
  208. """)
  209. logger.info("MySQL 数据库初始化完成")
  210. def _init_sqlite_database() -> None:
  211. """Initialize SQLite database tables."""
  212. with get_db_connection() as conn:
  213. cursor = conn.cursor()
  214. # Create users table
  215. cursor.execute("""
  216. CREATE TABLE IF NOT EXISTS users (
  217. id TEXT PRIMARY KEY,
  218. username TEXT NOT NULL UNIQUE,
  219. email TEXT NOT NULL UNIQUE,
  220. password_hash TEXT NOT NULL,
  221. role TEXT NOT NULL DEFAULT 'annotator',
  222. oauth_provider TEXT,
  223. oauth_id TEXT,
  224. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  225. updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  226. )
  227. """)
  228. cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)")
  229. cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)")
  230. cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_oauth ON users(oauth_provider, oauth_id)")
  231. # Create projects table
  232. cursor.execute("""
  233. CREATE TABLE IF NOT EXISTS projects (
  234. id TEXT PRIMARY KEY,
  235. name TEXT NOT NULL,
  236. description TEXT,
  237. config TEXT NOT NULL,
  238. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  239. )
  240. """)
  241. # Create tasks table
  242. cursor.execute("""
  243. CREATE TABLE IF NOT EXISTS tasks (
  244. id TEXT PRIMARY KEY,
  245. project_id TEXT NOT NULL,
  246. name TEXT NOT NULL,
  247. data TEXT NOT NULL,
  248. status TEXT DEFAULT 'pending',
  249. assigned_to TEXT,
  250. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  251. FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
  252. )
  253. """)
  254. # Create annotations table
  255. cursor.execute("""
  256. CREATE TABLE IF NOT EXISTS annotations (
  257. id TEXT PRIMARY KEY,
  258. task_id TEXT NOT NULL,
  259. user_id TEXT NOT NULL,
  260. result TEXT NOT NULL,
  261. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  262. updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  263. FOREIGN KEY (task_id) REFERENCES tasks(id) ON DELETE CASCADE
  264. )
  265. """)
  266. logger.info("SQLite 数据库初始化完成")
  267. def get_db():
  268. """
  269. Get a database connection (legacy support).
  270. Note: Caller is responsible for closing the connection.
  271. """
  272. if db_config.db_type == 'mysql':
  273. conn = _get_mysql_connection()
  274. else:
  275. conn = _get_sqlite_connection()
  276. return ConnectionWrapper(conn, db_config.db_type)