"""远程训练入口脚本 — 在算力节点上执行。""" import asyncio import json import os import sys import signal from pathlib import Path # 禁用 FlashAttention os.environ["PYTORCH_NO_FLASH"] = "1" os.environ["FLASH_ATTENTION_ENABLED"] = "0" async def run_training(job_id: str, model_id: str, model_type: str, dataset_id: str, config: dict): """执行单个训练任务(远程调用入口)。""" from app.config import get_settings from app.core.logging import logger settings = get_settings() # 查找数据集 from app.core.db import async_session, DatasetRecord from sqlalchemy import select dataset_path = None async with async_session() as session: result = await session.execute(select(DatasetRecord).where( (DatasetRecord.id == dataset_id) | (DatasetRecord.name == dataset_id) )) record = result.scalar_one_or_none() if record: dataset_path = record.file_path if not dataset_path: # 尝试 uploads 目录 upload_path = settings.uploads_dir / dataset_id if upload_path.exists(): dataset_path = str(upload_path) if not dataset_path: raise FileNotFoundError(f"Dataset not found: {dataset_id}") # 预处理 processed_path = str(settings.processed_dir / f"{job_id}_processed.jsonl") task_type = config.get("task_type", "sft") template = config.get("dataset_template", "alpaca") # 选择引擎 if model_type == "vision": from app.engines.vision_engine import vision_engine engine = vision_engine elif model_type == "multimodal": from app.engines.multimodal_engine import multimodal_engine engine = multimodal_engine else: from app.engines.text_engine import text_engine engine = text_engine peft_method = config.get("peft_method", "lora") # 预处理数据集 await engine.preprocess_dataset(dataset_path, processed_path, task_type=task_type, template=template) # 加载模型 await engine.load_model(model_id, quantization="4bit" if peft_method == "qlora" else None) # 构建 PEFT 配置 peft_config = engine.get_peft_config(peft_method, config) # 训练 adapter_path = await engine.train( job_id=job_id, dataset_path=processed_path, peft_config=peft_config, training_args=config, ) logger.info(f"Remote training completed: {job_id} -> {adapter_path}") return adapter_path def main(): """命令行入口:python -m app.engines.remote_train """ if len(sys.argv) < 6: print("Usage: python -m app.engines.remote_train ") sys.exit(1) job_id = sys.argv[1] model_id = sys.argv[2] model_type = sys.argv[3] dataset_id = sys.argv[4] config = json.loads(sys.argv[5]) asyncio.run(run_training(job_id, model_id, model_type, dataset_id, config)) if __name__ == "__main__": main()