# remote_train.py
  1. """远程训练入口脚本 — 在算力节点上执行。"""
  2. import asyncio
  3. import json
  4. import os
  5. import sys
  6. import signal
  7. from pathlib import Path
  8. # 禁用 FlashAttention
  9. os.environ["PYTORCH_NO_FLASH"] = "1"
  10. os.environ["FLASH_ATTENTION_ENABLED"] = "0"
  11. async def run_training(job_id: str, model_id: str, model_type: str, dataset_id: str, config: dict):
  12. """执行单个训练任务(远程调用入口)。"""
  13. from app.config import get_settings
  14. from app.core.logging import logger
  15. settings = get_settings()
  16. # 查找数据集
  17. from app.core.db import async_session, DatasetRecord
  18. from sqlalchemy import select
  19. dataset_path = None
  20. async with async_session() as session:
  21. result = await session.execute(select(DatasetRecord).where(
  22. (DatasetRecord.id == dataset_id) | (DatasetRecord.name == dataset_id)
  23. ))
  24. record = result.scalar_one_or_none()
  25. if record:
  26. dataset_path = record.file_path
  27. if not dataset_path:
  28. # 尝试 uploads 目录
  29. upload_path = settings.uploads_dir / dataset_id
  30. if upload_path.exists():
  31. dataset_path = str(upload_path)
  32. if not dataset_path:
  33. raise FileNotFoundError(f"Dataset not found: {dataset_id}")
  34. # 预处理
  35. processed_path = str(settings.processed_dir / f"{job_id}_processed.jsonl")
  36. task_type = config.get("task_type", "sft")
  37. template = config.get("dataset_template", "alpaca")
  38. # 选择引擎
  39. if model_type == "vision":
  40. from app.engines.vision_engine import vision_engine
  41. engine = vision_engine
  42. elif model_type == "multimodal":
  43. from app.engines.multimodal_engine import multimodal_engine
  44. engine = multimodal_engine
  45. else:
  46. from app.engines.text_engine import text_engine
  47. engine = text_engine
  48. peft_method = config.get("peft_method", "lora")
  49. # 预处理数据集
  50. await engine.preprocess_dataset(dataset_path, processed_path, task_type=task_type, template=template)
  51. # 加载模型
  52. await engine.load_model(model_id, quantization="4bit" if peft_method == "qlora" else None)
  53. # 构建 PEFT 配置
  54. peft_config = engine.get_peft_config(peft_method, config)
  55. # 训练
  56. adapter_path = await engine.train(
  57. job_id=job_id,
  58. dataset_path=processed_path,
  59. peft_config=peft_config,
  60. training_args=config,
  61. )
  62. logger.info(f"Remote training completed: {job_id} -> {adapter_path}")
  63. return adapter_path
  64. def main():
  65. """命令行入口:python -m app.engines.remote_train <job_id> <model_id> <model_type> <dataset_id> <config_json>"""
  66. if len(sys.argv) < 6:
  67. print("Usage: python -m app.engines.remote_train <job_id> <model_id> <model_type> <dataset_id> <config_json>")
  68. sys.exit(1)
  69. job_id = sys.argv[1]
  70. model_id = sys.argv[2]
  71. model_type = sys.argv[3]
  72. dataset_id = sys.argv[4]
  73. config = json.loads(sys.argv[5])
  74. asyncio.run(run_training(job_id, model_id, model_type, dataset_id, config))
  75. if __name__ == "__main__":
  76. main()