database.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. """
  2. Database connection and initialization module.
  3. Supports both SQLite and MySQL databases.
  4. """
  5. import os
  6. import logging
  7. from contextlib import contextmanager
  8. from typing import Generator, Any, Optional
  9. from config import settings
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
  11. class DatabaseConfig:
  12. """Database configuration holder."""
  13. def __init__(self):
  14. self.db_type = getattr(settings, 'DATABASE_TYPE', 'sqlite')
  15. if self.db_type == 'mysql':
  16. self.host = getattr(settings, 'MYSQL_HOST', 'localhost')
  17. self.port = getattr(settings, 'MYSQL_PORT', 3306)
  18. self.user = getattr(settings, 'MYSQL_USER', 'root')
  19. self.password = getattr(settings, 'MYSQL_PASSWORD', '')
  20. self.database = getattr(settings, 'MYSQL_DATABASE', 'annotation_platform')
  21. else:
  22. self.db_path = getattr(settings, 'DATABASE_PATH', 'annotation_platform.db')
# Shared configuration instance, built once when this module is imported.
db_config = DatabaseConfig()
  24. def _get_mysql_connection():
  25. """Get MySQL connection."""
  26. import pymysql
  27. conn = pymysql.connect(
  28. host=db_config.host,
  29. port=db_config.port,
  30. user=db_config.user,
  31. password=db_config.password,
  32. database=db_config.database,
  33. charset='utf8mb4',
  34. cursorclass=pymysql.cursors.DictCursor,
  35. autocommit=False
  36. )
  37. return conn
  38. def _get_sqlite_connection():
  39. """Get SQLite connection."""
  40. import sqlite3
  41. conn = sqlite3.connect(db_config.db_path)
  42. conn.row_factory = sqlite3.Row
  43. conn.execute("PRAGMA foreign_keys = ON")
  44. return conn
  45. class RowWrapper:
  46. """Wrapper to provide consistent row access for both SQLite and MySQL."""
  47. def __init__(self, row, db_type: str):
  48. self._row = row
  49. self._db_type = db_type
  50. def __getitem__(self, key):
  51. if self._db_type == 'mysql':
  52. return self._row[key]
  53. else:
  54. return self._row[key]
  55. def keys(self):
  56. if self._db_type == 'mysql':
  57. return self._row.keys()
  58. else:
  59. return self._row.keys()
  60. class CursorWrapper:
  61. """Wrapper to provide consistent cursor interface for both databases."""
  62. def __init__(self, cursor, db_type: str):
  63. self._cursor = cursor
  64. self._db_type = db_type
  65. def execute(self, sql: str, params: tuple = None):
  66. """Execute SQL with parameter conversion."""
  67. if self._db_type == 'mysql':
  68. # Convert ? placeholders to %s for MySQL
  69. sql = sql.replace('?', '%s')
  70. if params:
  71. self._cursor.execute(sql, params)
  72. else:
  73. self._cursor.execute(sql)
  74. return self
  75. def fetchone(self) -> Optional[RowWrapper]:
  76. row = self._cursor.fetchone()
  77. if row is None:
  78. return None
  79. return RowWrapper(row, self._db_type)
  80. def fetchall(self) -> list:
  81. rows = self._cursor.fetchall()
  82. return [RowWrapper(row, self._db_type) for row in rows]
  83. @property
  84. def lastrowid(self):
  85. return self._cursor.lastrowid
  86. @property
  87. def rowcount(self):
  88. return self._cursor.rowcount
  89. class ConnectionWrapper:
  90. """Wrapper to provide consistent connection interface."""
  91. def __init__(self, conn, db_type: str):
  92. self._conn = conn
  93. self._db_type = db_type
  94. def cursor(self) -> CursorWrapper:
  95. return CursorWrapper(self._conn.cursor(), self._db_type)
  96. def commit(self):
  97. self._conn.commit()
  98. def rollback(self):
  99. self._conn.rollback()
  100. def close(self):
  101. self._conn.close()
  102. def execute(self, sql: str, params: tuple = None):
  103. cursor = self.cursor()
  104. cursor.execute(sql, params)
  105. return cursor
  106. @contextmanager
  107. def get_db_connection() -> Generator[ConnectionWrapper, None, None]:
  108. """
  109. Context manager for database connections.
  110. Ensures proper connection cleanup.
  111. """
  112. if db_config.db_type == 'mysql':
  113. conn = _get_mysql_connection()
  114. else:
  115. conn = _get_sqlite_connection()
  116. wrapped = ConnectionWrapper(conn, db_config.db_type)
  117. try:
  118. yield wrapped
  119. wrapped.commit()
  120. except Exception:
  121. wrapped.rollback()
  122. raise
  123. finally:
  124. wrapped.close()
  125. def init_database() -> None:
  126. """
  127. Initialize database and create tables if they don't exist.
  128. """
  129. if db_config.db_type == 'mysql':
  130. _init_mysql_database()
  131. else:
  132. _init_sqlite_database()
  133. def _init_mysql_database() -> None:
  134. """Initialize MySQL database tables."""
  135. import pymysql
  136. # First, create database if not exists
  137. conn = pymysql.connect(
  138. host=db_config.host,
  139. port=db_config.port,
  140. user=db_config.user,
  141. password=db_config.password,
  142. charset='utf8mb4'
  143. )
  144. try:
  145. with conn.cursor() as cursor:
  146. cursor.execute(f"CREATE DATABASE IF NOT EXISTS `{db_config.database}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci")
  147. conn.commit()
  148. finally:
  149. conn.close()
  150. # Now create tables
  151. with get_db_connection() as conn:
  152. cursor = conn.cursor()
  153. # Create users table
  154. cursor.execute("""
  155. CREATE TABLE IF NOT EXISTS users (
  156. id VARCHAR(36) PRIMARY KEY,
  157. username VARCHAR(255) NOT NULL UNIQUE,
  158. email VARCHAR(255) NOT NULL UNIQUE,
  159. password_hash VARCHAR(255) NOT NULL,
  160. role VARCHAR(50) NOT NULL DEFAULT 'annotator',
  161. oauth_provider VARCHAR(50),
  162. oauth_id VARCHAR(255),
  163. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  164. updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  165. INDEX idx_users_username (username),
  166. INDEX idx_users_email (email),
  167. INDEX idx_users_oauth (oauth_provider, oauth_id)
  168. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
  169. """)
  170. # Create projects table
  171. cursor.execute("""
  172. CREATE TABLE IF NOT EXISTS projects (
  173. id VARCHAR(36) PRIMARY KEY,
  174. name VARCHAR(255) NOT NULL,
  175. description TEXT,
  176. config TEXT NOT NULL,
  177. status VARCHAR(20) DEFAULT 'draft',
  178. source VARCHAR(20) DEFAULT 'internal',
  179. task_type VARCHAR(50),
  180. external_id VARCHAR(100),
  181. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  182. updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  183. INDEX idx_projects_status (status),
  184. INDEX idx_projects_source (source),
  185. INDEX idx_projects_external_id (external_id)
  186. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
  187. """)
  188. # Create tasks table
  189. cursor.execute("""
  190. CREATE TABLE IF NOT EXISTS tasks (
  191. id VARCHAR(36) PRIMARY KEY,
  192. project_id VARCHAR(36) NOT NULL,
  193. name VARCHAR(255) NOT NULL,
  194. data LONGTEXT NOT NULL,
  195. status VARCHAR(50) DEFAULT 'pending',
  196. assigned_to VARCHAR(36),
  197. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  198. FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
  199. INDEX idx_tasks_project (project_id),
  200. INDEX idx_tasks_status (status)
  201. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
  202. """)
  203. # Create annotations table
  204. cursor.execute("""
  205. CREATE TABLE IF NOT EXISTS annotations (
  206. id VARCHAR(36) PRIMARY KEY,
  207. task_id VARCHAR(36) NOT NULL,
  208. user_id VARCHAR(36) NOT NULL,
  209. result LONGTEXT NOT NULL,
  210. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  211. updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP,
  212. FOREIGN KEY (task_id) REFERENCES tasks(id) ON DELETE CASCADE,
  213. INDEX idx_annotations_task (task_id),
  214. INDEX idx_annotations_user (user_id)
  215. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
  216. """)
  217. # Create export_jobs table
  218. cursor.execute("""
  219. CREATE TABLE IF NOT EXISTS export_jobs (
  220. id VARCHAR(36) PRIMARY KEY,
  221. project_id VARCHAR(36) NOT NULL,
  222. format VARCHAR(50) NOT NULL,
  223. status VARCHAR(50) DEFAULT 'pending',
  224. status_filter VARCHAR(50) DEFAULT 'all',
  225. include_metadata BOOLEAN DEFAULT TRUE,
  226. file_path TEXT,
  227. error_message TEXT,
  228. created_by VARCHAR(36),
  229. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  230. completed_at TIMESTAMP,
  231. total_tasks INT DEFAULT 0,
  232. exported_tasks INT DEFAULT 0,
  233. FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE,
  234. INDEX idx_export_jobs_project (project_id),
  235. INDEX idx_export_jobs_status (status)
  236. ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
  237. """)
  238. logger.info("MySQL 数据库初始化完成")
  239. def _init_sqlite_database() -> None:
  240. """Initialize SQLite database tables."""
  241. with get_db_connection() as conn:
  242. cursor = conn.cursor()
  243. # Create users table
  244. cursor.execute("""
  245. CREATE TABLE IF NOT EXISTS users (
  246. id TEXT PRIMARY KEY,
  247. username TEXT NOT NULL UNIQUE,
  248. email TEXT NOT NULL UNIQUE,
  249. password_hash TEXT NOT NULL,
  250. role TEXT NOT NULL DEFAULT 'annotator',
  251. oauth_provider TEXT,
  252. oauth_id TEXT,
  253. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  254. updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
  255. )
  256. """)
  257. cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_username ON users(username)")
  258. cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_email ON users(email)")
  259. cursor.execute("CREATE INDEX IF NOT EXISTS idx_users_oauth ON users(oauth_provider, oauth_id)")
  260. # Create projects table
  261. cursor.execute("""
  262. CREATE TABLE IF NOT EXISTS projects (
  263. id TEXT PRIMARY KEY,
  264. name TEXT NOT NULL,
  265. description TEXT,
  266. config TEXT NOT NULL,
  267. status TEXT DEFAULT 'draft',
  268. source TEXT DEFAULT 'internal',
  269. task_type TEXT,
  270. external_id TEXT,
  271. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  272. updated_at TIMESTAMP
  273. )
  274. """)
  275. cursor.execute("CREATE INDEX IF NOT EXISTS idx_projects_status ON projects(status)")
  276. cursor.execute("CREATE INDEX IF NOT EXISTS idx_projects_source ON projects(source)")
  277. cursor.execute("CREATE INDEX IF NOT EXISTS idx_projects_external_id ON projects(external_id)")
  278. # Create tasks table
  279. cursor.execute("""
  280. CREATE TABLE IF NOT EXISTS tasks (
  281. id TEXT PRIMARY KEY,
  282. project_id TEXT NOT NULL,
  283. name TEXT NOT NULL,
  284. data TEXT NOT NULL,
  285. status TEXT DEFAULT 'pending',
  286. assigned_to TEXT,
  287. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  288. FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
  289. )
  290. """)
  291. # Create annotations table
  292. cursor.execute("""
  293. CREATE TABLE IF NOT EXISTS annotations (
  294. id TEXT PRIMARY KEY,
  295. task_id TEXT NOT NULL,
  296. user_id TEXT NOT NULL,
  297. result TEXT NOT NULL,
  298. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  299. updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  300. FOREIGN KEY (task_id) REFERENCES tasks(id) ON DELETE CASCADE
  301. )
  302. """)
  303. # Create export_jobs table
  304. cursor.execute("""
  305. CREATE TABLE IF NOT EXISTS export_jobs (
  306. id TEXT PRIMARY KEY,
  307. project_id TEXT NOT NULL,
  308. format TEXT NOT NULL,
  309. status TEXT DEFAULT 'pending',
  310. status_filter TEXT DEFAULT 'all',
  311. include_metadata INTEGER DEFAULT 1,
  312. file_path TEXT,
  313. error_message TEXT,
  314. created_by TEXT,
  315. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  316. completed_at TIMESTAMP,
  317. total_tasks INTEGER DEFAULT 0,
  318. exported_tasks INTEGER DEFAULT 0,
  319. FOREIGN KEY (project_id) REFERENCES projects(id) ON DELETE CASCADE
  320. )
  321. """)
  322. cursor.execute("CREATE INDEX IF NOT EXISTS idx_export_jobs_project ON export_jobs(project_id)")
  323. cursor.execute("CREATE INDEX IF NOT EXISTS idx_export_jobs_status ON export_jobs(status)")
  324. logger.info("SQLite 数据库初始化完成")
  325. def get_db():
  326. """
  327. Get a database connection (legacy support).
  328. Note: Caller is responsible for closing the connection.
  329. """
  330. if db_config.db_type == 'mysql':
  331. conn = _get_mysql_connection()
  332. else:
  333. conn = _get_sqlite_connection()
  334. return ConnectionWrapper(conn, db_config.db_type)