import asyncio
import json
import uuid
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from sqlalchemy import select

from app.config import get_settings
from app.core.db import async_session, TrainingJobModel
from app.core.job_queue import JobStatus, TrainingJob, job_queue
from app.core.logging import logger

settings = get_settings()


def _utcnow() -> datetime:
    """Return the current UTC time as a naive ``datetime``.

    Yields the same value as ``datetime.utcnow()`` (which is deprecated since
    Python 3.12) so the timestamps written to the database keep their
    existing naive-UTC form.
    """
    return datetime.now(timezone.utc).replace(tzinfo=None)


async def create_training_job(config: dict[str, Any]) -> dict[str, Any]:
    """Create a training job record, then enqueue the job.

    Missing keys in ``config`` fall back to the defaults below. No further
    validation of the values is performed here.

    Args:
        config: Training configuration. If ``config["deepspeed"]`` is truthy,
            a DeepSpeed config file is generated and the job receives its
            path under the ``"deepspeed"`` key. The caller's dict is not
            modified.

    Returns:
        A summary dict with ``id``, ``model_id``, ``model_type``,
        ``peft_method``, ``status`` and ``created_at`` (ISO 8601 string).
    """
    job_id = str(uuid.uuid4())
    created_at = _utcnow()

    model_id = config.get("model_id", "")
    model_type = config.get("model_type", "text")
    dataset_id = config.get("dataset_id", "")
    peft_method = config.get("peft_method", "lora")
    task_type = config.get("task_type", "sft")
    dataset_template = config.get("dataset_template", "alpaca")

    # Work on a shallow copy so the caller's dict is never mutated.
    job_config = dict(config)

    # Generate the DeepSpeed file *before* touching the database: if writing
    # it fails, no "pending" row is left behind without a queued job.
    if job_config.get("deepspeed", False):
        job_config["deepspeed"] = _generate_deepspeed_config()

    # Persist the job record.
    record = TrainingJobModel(
        id=job_id,
        model_id=model_id,
        model_type=model_type,
        dataset_id=dataset_id,
        peft_method=peft_method,
        task_type=task_type,
        dataset_template=dataset_template,
        status="pending",
        epochs=config.get("epochs", 3),
        batch_size=config.get("batch_size", 4),
        gradient_accumulation=config.get("gradient_accumulation", 4),
        learning_rate=config.get("learning_rate", 2e-4),
        max_seq_length=config.get("max_seq_length", 2048),
        warmup_ratio=config.get("warmup_ratio", 0.05),
        save_strategy=config.get("save_strategy", "epoch"),
        eval_strategy=config.get("eval_strategy", "epoch"),
        eval_steps=config.get("eval_steps", 100),
        lora_r=config.get("lora_r", 16),
        lora_alpha=config.get("lora_alpha", 32),
        lora_dropout=config.get("lora_dropout", 0.05),
        lora_target_modules=config.get("lora_target_modules", "all-linear"),
        qlora_bits=config.get("qlora_bits", 4),
        created_at=created_at,
    )
    async with async_session() as session:
        session.add(record)
        await session.commit()

    # Hand the job over to the queue.
    # NOTE(review): if enqueue raises, the row committed above stays
    # "pending" — consider marking it failed in that case.
    job = TrainingJob(
        id=job_id,
        model_id=model_id,
        model_type=model_type,
        peft_method=peft_method,
        dataset_id=dataset_id,
        config=job_config,
        status=JobStatus.PENDING,
    )
    await job_queue.enqueue(job_id, job)

    logger.info(f"Training job created: {job_id}")
    return {
        "id": job_id,
        "model_id": model_id,
        "model_type": model_type,
        "peft_method": peft_method,
        "status": "pending",
        # Use the local value rather than ``record.created_at``: the record
        # is detached once the session closes and its attributes may have
        # been expired by the commit, depending on the session factory.
        "created_at": created_at.isoformat(),
    }


async def list_training_jobs() -> list[dict[str, Any]]:
    """Return all training jobs as dicts, newest ``created_at`` first."""
    query = select(TrainingJobModel).order_by(TrainingJobModel.created_at.desc())
    async with async_session() as session:
        result = await session.execute(query)
        records = result.scalars().all()
        return [_job_to_dict(record) for record in records]


async def get_training_job(job_id: str) -> dict[str, Any] | None:
    """Return the job with the given id as a dict, or ``None`` if not found."""
    query = select(TrainingJobModel).where(TrainingJobModel.id == job_id)
    async with async_session() as session:
        result = await session.execute(query)
        record = result.scalar_one_or_none()
        if record:
            return _job_to_dict(record)
        return None


async def cancel_training_job(job_id: str) -> dict[str, Any]:
    """Ask the job queue to cancel a job and mark its record as cancelled.

    Args:
        job_id: Identifier of the job to cancel.

    Returns:
        ``{"status": "cancelled"}``, whether or not a matching database
        record exists.
    """
    await job_queue.cancel(job_id)

    query = select(TrainingJobModel).where(TrainingJobModel.id == job_id)
    async with async_session() as session:
        result = await session.execute(query)
        record = result.scalar_one_or_none()
        if record:
            # NOTE(review): this overwrites the status regardless of the
            # record's current state — confirm whether already-finished jobs
            # should be left untouched.
            record.status = "cancelled"
            record.finished_at = _utcnow()
            await session.commit()

    logger.info(f"Job cancelled: {job_id}")
    return {"status": "cancelled"}


async def update_job_in_db(job: TrainingJob) -> None:
    """Copy a job's current state onto its database record.

    Intended as a JobQueue callback. It is best-effort: any exception is
    logged and swallowed so a database failure does not propagate to the
    caller.

    Args:
        job: The job whose status, progress and result fields are synced.
    """
    try:
        query = select(TrainingJobModel).where(TrainingJobModel.id == job.id)
        async with async_session() as session:
            result = await session.execute(query)
            record = result.scalar_one_or_none()
            if record:
                status = job.status
                record.status = status.value if hasattr(status, "value") else str(status)
                record.progress = job.progress
                record.current_epoch = job.current_epoch
                record.current_step = job.current_step
                record.total_steps = job.total_steps
                record.loss = job.loss
                record.adapter_path = job.adapter_path
                record.error_message = job.error_message

                # Record the start time only on the first TRAINING update.
                if status == JobStatus.TRAINING and not record.started_at:
                    record.started_at = _utcnow()
                # The status may not be a JobStatus (see the hasattr guard
                # above), so read ``is_terminal`` defensively instead of
                # letting an AttributeError discard the whole update.
                if getattr(status, "is_terminal", False):
                    record.finished_at = _utcnow()
                await session.commit()
    except Exception as e:
        logger.error(f"Failed to update job {job.id} in DB: {e}")


def _job_to_dict(r: TrainingJobModel) -> dict[str, Any]:
    """Convert a job record into a plain dict.

    Falsy numeric progress fields are reported as zero. Timestamps become
    ISO 8601 strings; a missing ``created_at`` becomes ``""`` and a missing
    ``started_at`` / ``finished_at`` becomes ``None``.
    """
    return {
        "id": r.id,
        "model_id": r.model_id,
        "model_type": r.model_type,
        "peft_method": r.peft_method,
        "status": r.status,
        "progress": r.progress or 0.0,
        "current_epoch": r.current_epoch or 0,
        "current_step": r.current_step or 0,
        "total_steps": r.total_steps or 0,
        "loss": r.loss,
        "created_at": r.created_at.isoformat() if r.created_at else "",
        "started_at": r.started_at.isoformat() if r.started_at else None,
        "finished_at": r.finished_at.isoformat() if r.finished_at else None,
        "error_message": r.error_message,
        "adapter_path": r.adapter_path,
    }


def _generate_deepspeed_config(stage: int = 2) -> str:
    """Write a DeepSpeed config file and return its path.

    The file is written to ``deepspeed_config.json`` inside the configured
    data directory, replacing any existing file of that name.

    Args:
        stage: ZeRO stage. ``3`` selects the variant with parameter offload;
            any other value uses the variant without it.

    Returns:
        The path of the written config file, as a string.
    """
    if stage == 3:
        zero_optimization: dict[str, Any] = {
            "stage": stage,
            "offload_optimizer": {"device": "cpu", "pin_memory": True},
            "offload_param": {"device": "cpu", "pin_memory": True},
            "overlap_comm": True,
            "contiguous_gradients": True,
            "reduce_bucket_size": "auto",
            "stage3_prefetch_bucket_size": "auto",
            "stage3_param_persistence_threshold": "auto",
        }
    else:
        zero_optimization = {
            "stage": stage,
            "offload_optimizer": {"device": "cpu", "pin_memory": True},
            "allgather_partitions": True,
            "allgather_bucket_size": 2e8,
            "overlap_comm": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 2e8,
            "contiguous_gradients": True,
        }

    ds_config = {
        "fp16": {"enabled": True},
        "zero_optimization": zero_optimization,
        "gradient_accumulation_steps": "auto",
        "gradient_clipping": "auto",
        "steps_per_print": 10,
        "train_batch_size": "auto",
        "train_micro_batch_size_per_gpu": "auto",
        "wall_clock_breakdown": False,
    }

    # Settings are fetched at call time, as before.
    config_path = Path(get_settings().data_dir / "deepspeed_config.json")
    # Make sure the target directory exists so the write cannot fail on a
    # fresh installation.
    config_path.parent.mkdir(parents=True, exist_ok=True)
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(ds_config, f, indent=2)

    logger.info(f"DeepSpeed config generated: {config_path}")
    return str(config_path)