| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203 |
- import asyncio
- import json
- import uuid
- from datetime import datetime, timezone
- from typing import Any
- from app.config import get_settings
- from app.core.db import async_session, TrainingJobModel
- from app.core.job_queue import JobStatus, TrainingJob, job_queue
- from app.core.logging import logger
- from sqlalchemy import select
- settings = get_settings()
async def create_training_job(config: dict[str, Any]) -> dict[str, Any]:
    """Persist a new training job and hand it to the job queue.

    Reads the job parameters from ``config`` (missing keys fall back to the
    defaults below), stores a ``TrainingJobModel`` row with status
    ``"pending"``, optionally generates a DeepSpeed config file, and enqueues
    a ``TrainingJob`` on ``job_queue``.

    Args:
        config: Training options. A truthy ``"deepspeed"`` entry triggers
            generation of a DeepSpeed config file; the job then receives the
            generated file path under that key. The caller's dict is left
            untouched.

    Returns:
        A summary dict with ``id``, ``model_id``, ``model_type``,
        ``peft_method``, ``status`` and ``created_at`` (ISO 8601 string).
    """
    job_id = str(uuid.uuid4())
    # Kept in a local so the response never has to read an ORM attribute
    # after the session has committed and closed (attributes may be expired
    # on commit, and the instance is detached by then).
    created_at = datetime.now(timezone.utc)

    model_id = config.get("model_id", "")
    model_type = config.get("model_type", "text")
    dataset_id = config.get("dataset_id", "")
    peft_method = config.get("peft_method", "lora")
    task_type = config.get("task_type", "sft")
    dataset_template = config.get("dataset_template", "alpaca")

    # Persist the job record.
    record = TrainingJobModel(
        id=job_id,
        model_id=model_id,
        model_type=model_type,
        dataset_id=dataset_id,
        peft_method=peft_method,
        task_type=task_type,
        dataset_template=dataset_template,
        status="pending",
        epochs=config.get("epochs", 3),
        batch_size=config.get("batch_size", 4),
        gradient_accumulation=config.get("gradient_accumulation", 4),
        learning_rate=config.get("learning_rate", 2e-4),
        max_seq_length=config.get("max_seq_length", 2048),
        warmup_ratio=config.get("warmup_ratio", 0.05),
        save_strategy=config.get("save_strategy", "epoch"),
        eval_strategy=config.get("eval_strategy", "epoch"),
        eval_steps=config.get("eval_steps", 100),
        lora_r=config.get("lora_r", 16),
        lora_alpha=config.get("lora_alpha", 32),
        lora_dropout=config.get("lora_dropout", 0.05),
        lora_target_modules=config.get("lora_target_modules", "all-linear"),
        qlora_bits=config.get("qlora_bits", 4),
        created_at=created_at,
    )
    async with async_session() as session:
        session.add(record)
        await session.commit()

    # Work on a shallow copy: replacing the "deepspeed" flag with a file path
    # must not leak back into the dict owned by the caller.
    job_config = dict(config)
    if job_config.get("deepspeed", False):
        # DeepSpeed enabled: generate the config file and pass its path on.
        job_config["deepspeed"] = _generate_deepspeed_config()

    # Hand the job to the JobQueue.
    job = TrainingJob(
        id=job_id,
        model_id=model_id,
        model_type=model_type,
        peft_method=peft_method,
        dataset_id=dataset_id,
        config=job_config,
        status=JobStatus.PENDING,
    )
    await job_queue.enqueue(job_id, job)
    logger.info(f"Training job created: {job_id}")

    return {
        "id": job_id,
        "model_id": model_id,
        "model_type": model_type,
        "peft_method": peft_method,
        "status": "pending",
        "created_at": created_at.isoformat(),
    }
async def list_training_jobs() -> list[dict[str, Any]]:
    """Return all training jobs as dicts, most recently created first."""
    newest_first = select(TrainingJobModel).order_by(
        TrainingJobModel.created_at.desc()
    )
    async with async_session() as db:
        rows = (await db.execute(newest_first)).scalars().all()
        return list(map(_job_to_dict, rows))
async def get_training_job(job_id: str) -> dict[str, Any] | None:
    """Look up one training job by id.

    Returns the job as a dict, or ``None`` when no record matches ``job_id``.
    """
    by_id = select(TrainingJobModel).where(TrainingJobModel.id == job_id)
    async with async_session() as db:
        row = (await db.execute(by_id)).scalar_one_or_none()
        if not row:
            return None
        return _job_to_dict(row)
async def cancel_training_job(job_id: str) -> dict[str, Any]:
    """Ask the job queue to cancel a job and mark its record as cancelled.

    The queue is asked to cancel first; then, if a database record with this
    id exists, its status is set to ``"cancelled"`` and its finish time to
    the current UTC time. The same ``{"status": "cancelled"}`` payload is
    returned whether or not a record was found.
    """
    await job_queue.cancel(job_id)

    by_id = select(TrainingJobModel).where(TrainingJobModel.id == job_id)
    async with async_session() as db:
        row = (await db.execute(by_id)).scalar_one_or_none()
        if row:
            row.status = "cancelled"
            row.finished_at = datetime.now(timezone.utc)
            await db.commit()

    logger.info(f"Job cancelled: {job_id}")
    return {"status": "cancelled"}
async def update_job_in_db(job: "TrainingJob") -> None:
    """JobQueue callback: mirror the in-memory job state into the database.

    Copies status, progress counters, loss, adapter path and error message
    from ``job`` onto the matching ``TrainingJobModel`` row. ``started_at`` is
    stamped the first time the job is seen in the TRAINING state, and
    ``finished_at`` whenever the job reports a terminal status.

    This is best-effort: any exception is logged and swallowed so that a
    database problem does not propagate into the caller.

    Args:
        job: The queue-side job object whose state should be persisted.
    """
    try:
        async with async_session() as session:
            result = await session.execute(
                select(TrainingJobModel).where(TrainingJobModel.id == job.id)
            )
            record = result.scalar_one_or_none()
            if not record:
                # Nothing to sync for an unknown job id.
                return

            status = job.status
            # The status may be an enum member or a plain value; store its
            # string form either way.
            record.status = status.value if hasattr(status, "value") else str(status)
            record.progress = job.progress
            record.current_epoch = job.current_epoch
            record.current_step = job.current_step
            record.total_steps = job.total_steps
            record.loss = job.loss
            record.adapter_path = job.adapter_path
            record.error_message = job.error_message

            if status == JobStatus.TRAINING and not record.started_at:
                record.started_at = datetime.now(timezone.utc)
            # Same defensiveness as the ``.value`` lookup above: a status
            # without ``is_terminal`` must not raise and discard the whole
            # update; it is simply treated as non-terminal.
            if getattr(status, "is_terminal", False):
                record.finished_at = datetime.now(timezone.utc)

            await session.commit()
    except Exception as e:
        logger.error(f"Failed to update job {job.id} in DB: {e}")
- def _job_to_dict(r) -> dict[str, Any]:
- return {
- "id": r.id,
- "model_id": r.model_id,
- "model_type": r.model_type,
- "peft_method": r.peft_method,
- "status": r.status,
- "progress": r.progress or 0.0,
- "current_epoch": r.current_epoch or 0,
- "current_step": r.current_step or 0,
- "total_steps": r.total_steps or 0,
- "loss": r.loss,
- "created_at": r.created_at.isoformat() if r.created_at else "",
- "started_at": r.started_at.isoformat() if r.started_at else None,
- "finished_at": r.finished_at.isoformat() if r.finished_at else None,
- "error_message": r.error_message,
- "adapter_path": r.adapter_path,
- }
def _generate_deepspeed_config(stage: int = 2) -> str:
    """Write a DeepSpeed JSON config for the given ZeRO stage.

    Stage 3 offloads both optimizer state and parameters to CPU and uses the
    ``stage3_*`` tuning keys; any other stage value gets the stage-2 style
    layout (optimizer offload with explicit all-gather / reduce-scatter
    settings). The file is always written to ``deepspeed_config.json`` inside
    the configured data directory, overwriting any previous file.

    Args:
        stage: ZeRO optimization stage written into the config.

    Returns:
        The path of the written config file, as a string.
    """
    from pathlib import Path

    cpu_offload = {"device": "cpu", "pin_memory": True}
    if stage == 3:
        zero_optimization = {
            "stage": stage,
            "offload_optimizer": dict(cpu_offload),
            "offload_param": dict(cpu_offload),
            "overlap_comm": True,
            "contiguous_gradients": True,
            "reduce_bucket_size": "auto",
            "stage3_prefetch_bucket_size": "auto",
            "stage3_param_persistence_threshold": "auto",
        }
    else:
        zero_optimization = {
            "stage": stage,
            "offload_optimizer": dict(cpu_offload),
            "allgather_partitions": True,
            "allgather_bucket_size": 2e8,
            "overlap_comm": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 2e8,
            "contiguous_gradients": True,
        }

    ds_config = {
        "fp16": {"enabled": True},
        "zero_optimization": zero_optimization,
        "gradient_accumulation_steps": "auto",
        "gradient_clipping": "auto",
        "steps_per_print": 10,
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "wall_clock_breakdown": False,
    }

    config_path = Path(get_settings().data_dir) / "deepspeed_config.json"
    # The data directory may not exist yet on a fresh install; create it
    # rather than failing on open().
    config_path.parent.mkdir(parents=True, exist_ok=True)
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(ds_config, f, indent=2)

    logger.info(f"DeepSpeed config generated: {config_path}")
    return str(config_path)
|