浏览代码

将数据库改为pgsql

lxylxy123321 1 周之前
父节点
当前提交
ac9e80b429
共有 5 个文件被更改,包括 30 次插入15 次删除
  1. 3 3
      backend/.env.docker
  2. 1 1
      backend/app/config.py
  3. 3 10
      backend/app/core/db.py
  4. 1 0
      backend/requirements.txt
  5. 22 1
      docker-compose.yml

+ 3 - 3
backend/.env.docker

@@ -5,10 +5,10 @@ BACKEND_ENV=production
 BACKEND_LOG_LEVEL=INFO
 BACKEND_LOG_LEVEL=INFO
 BACKEND_CORS_ORIGINS=http://localhost:3000
 BACKEND_CORS_ORIGINS=http://localhost:3000
 
 
-# Docker 容器内数据库路径
-DATABASE_URL=sqlite+aiosqlite:///root/Fine-tuning/backend/data/finetuning.db
+# PostgreSQL 数据库
+DATABASE_URL=postgresql+asyncpg://finetune:finetune123@postgres:5432/finetuning
 
 
-# Docker 容器内数据目录
+# 容器内数据目录
 DATA_DIR=/root/Fine-tuning/backend/data
 DATA_DIR=/root/Fine-tuning/backend/data
 
 
 DEFAULT_PEFT_METHOD=lora
 DEFAULT_PEFT_METHOD=lora

+ 1 - 1
backend/app/config.py

@@ -60,7 +60,7 @@ class Settings(BaseSettings):
     backend_cors_origins: list[str] = ["http://192.168.91.253:5173"]
     backend_cors_origins: list[str] = ["http://192.168.91.253:5173"]
 
 
     # --- 数据库 ---
     # --- 数据库 ---
-    database_url: str = "sqlite+aiosqlite:///root/Fine-tuning/backend/data/finetuning.db"
+    database_url: str = "postgresql+asyncpg://finetune:finetune123@localhost:5432/finetuning"
 
 
     # --- 训练默认参数 ---
     # --- 训练默认参数 ---
     default_peft_method: str = "lora"
     default_peft_method: str = "lora"

+ 3 - 10
backend/app/core/db.py

@@ -4,8 +4,6 @@ from sqlalchemy import Column, DateTime, Float, Integer, String, Text
 from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
 from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
 from sqlalchemy.orm import declarative_base
 from sqlalchemy.orm import declarative_base
 
 
-from pathlib import Path
-
 from app.config import get_settings
 from app.config import get_settings
 
 
 settings = get_settings()
 settings = get_settings()
@@ -13,7 +11,6 @@ settings = get_settings()
 Base = declarative_base()
 Base = declarative_base()
 
 
 # 延迟创建 engine/session,在首次使用时再实例化
 # 延迟创建 engine/session,在首次使用时再实例化
-# 避免模块导入阶段目录还未创建就尝试连接数据库
 _engine = None
 _engine = None
 _async_session = None
 _async_session = None
 
 
@@ -24,6 +21,7 @@ def _get_engine():
         _engine = create_async_engine(
         _engine = create_async_engine(
             settings.database_url,
             settings.database_url,
             echo=settings.backend_env == "development",
             echo=settings.backend_env == "development",
+            pool_pre_ping=True,
         )
         )
     return _engine
     return _engine
 
 
@@ -39,12 +37,7 @@ async_session = _get_session
 
 
 
 
 async def init_db():
 async def init_db():
-    """创建数据库目录 + 创建所有表(首次启动时调用)。"""
-    # 确保数据库文件目录存在
-    db_path = settings.database_url.removeprefix("sqlite+aiosqlite://")
-    if db_path and not db_path.startswith(":memory"):
-        db_path_obj = Path(db_path) if db_path.startswith("/") else Path("/") / db_path
-        db_path_obj.parent.mkdir(parents=True, exist_ok=True)
+    """创建所有表(首次启动时调用)。"""
     async with _get_engine().begin() as conn:
     async with _get_engine().begin() as conn:
         await conn.run_sync(Base.metadata.create_all)
         await conn.run_sync(Base.metadata.create_all)
 
 
@@ -57,7 +50,7 @@ class TrainingJobModel(Base):
     model_type = Column(String(32), nullable=False)
     model_type = Column(String(32), nullable=False)
     dataset_id = Column(String(36), nullable=False)
     dataset_id = Column(String(36), nullable=False)
     peft_method = Column(String(32), nullable=False)
     peft_method = Column(String(32), nullable=False)
-    task_type = Column(String(32), default="sft")  # sft/dpo/kto/orpo/rm/ppo
+    task_type = Column(String(32), default="sft")
     dataset_template = Column(String(32), default="alpaca")
     dataset_template = Column(String(32), default="alpaca")
 
 
     status = Column(String(32), default="pending")
     status = Column(String(32), default="pending")

+ 1 - 0
backend/requirements.txt

@@ -5,6 +5,7 @@ pydantic-settings>=2.0
 python-dotenv>=1.0
 python-dotenv>=1.0
 sqlalchemy[asyncio]>=2.0
 sqlalchemy[asyncio]>=2.0
 aiosqlite>=0.20.0
 aiosqlite>=0.20.0
+asyncpg>=0.29.0
 alembic>=1.13.0
 alembic>=1.13.0
 python-multipart>=0.0.9
 python-multipart>=0.0.9
 websockets>=12.0
 websockets>=12.0

+ 22 - 1
docker-compose.yml

@@ -1,6 +1,21 @@
 version: "3.8"
 version: "3.8"
 
 
 services:
 services:
+  postgres:
+    image: postgres:16-alpine
+    container_name: finetune-postgres
+    restart: unless-stopped
+    environment:
+      POSTGRES_DB: finetuning
+      POSTGRES_USER: finetune
+      POSTGRES_PASSWORD: finetune123
+    volumes:
+      - pgdata:/var/lib/postgresql/data
+    ports:
+      - "5432:5432"
+    networks:
+      - finetune-net
+
   backend:
   backend:
     build:
     build:
       context: ./backend
       context: ./backend
@@ -10,17 +25,20 @@ services:
     ports:
     ports:
       - "8010:8010"
       - "8010:8010"
     volumes:
     volumes:
-      # 持久化数据和模型(使用绝对路径,避免重建容器后数据丢失)
+      # 持久化数据和模型
       - ./backend/data:/root/Fine-tuning/backend/data
       - ./backend/data:/root/Fine-tuning/backend/data
     env_file:
     env_file:
       - ./backend/.env.docker
       - ./backend/.env.docker
     environment:
     environment:
       - BACKEND_HOST=0.0.0.0
       - BACKEND_HOST=0.0.0.0
       - BACKEND_PORT=8010
       - BACKEND_PORT=8010
+      - DATABASE_URL=postgresql+asyncpg://finetune:finetune123@postgres:5432/finetuning
       # 沐曦 maca 环境变量
       # 沐曦 maca 环境变量
       - MACA_PATH=/opt/maca
       - MACA_PATH=/opt/maca
       - LD_LIBRARY_PATH=/opt/maca/lib:/opt/maca/mxgpu_llvm/lib:/opt/maca/ompi/lib
       - LD_LIBRARY_PATH=/opt/maca/lib:/opt/maca/mxgpu_llvm/lib:/opt/maca/ompi/lib
       - MACA_CLANG_PATH=/opt/maca/mxgpu_llvm/bin
       - MACA_CLANG_PATH=/opt/maca/mxgpu_llvm/bin
+    depends_on:
+      - postgres
     devices:
     devices:
       - /dev/mxcd:/dev/mxcd
       - /dev/mxcd:/dev/mxcd
     privileged: true
     privileged: true
@@ -47,3 +65,6 @@ services:
 networks:
 networks:
   finetune-net:
   finetune-net:
     driver: bridge
     driver: bridge
+
+volumes:
+  pgdata: