| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667 |
- import { useState, useEffect, useRef, useCallback, memo } from 'react'
- import api, { TrainingJob, ModelInfo, DatasetInfo } from '../api/client'
- import { wsManager } from '../api/websocket'
- import { Train } from 'lucide-react'
/** A dropdown choice whose value is sent to the backend verbatim as a string. */
interface StringOption {
  value: string
  label: string
}

/** Base-model families offered in the model-type dropdown. */
const MODEL_TYPES: StringOption[] = [
  { value: 'text', label: '文本 (LLaMA/Qwen)' },
  { value: 'vision', label: '视觉 (ViT/CLIP)' },
  { value: 'multimodal', label: '多模态 (LLaVA/Qwen-VL)' },
]

/** Parameter-efficient fine-tuning methods. */
const PEFT_METHODS: StringOption[] = [
  { value: 'lora', label: 'LoRA' },
  { value: 'qlora', label: 'QLoRA (推荐)' },
  { value: 'adalora', label: 'AdaLoRA' },
]

/** Training objectives: supervised fine-tuning, DPO or PPO. */
const TASK_TYPES: StringOption[] = [
  { value: 'sft', label: 'SFT (监督微调)' },
  { value: 'dpo', label: 'DPO (直接偏好优化)' },
  { value: 'ppo', label: 'PPO (近端策略优化)' },
]

/** Dataset record layouts; 'auto' requests automatic format detection. */
const DATASET_TEMPLATES: StringOption[] = [
  { value: 'auto', label: 'Auto (自动检测)' },
  { value: 'alpaca', label: 'Alpaca (instruction/input/output)' },
  { value: 'sharegpt', label: 'ShareGPT (conversations)' },
  { value: 'raw', label: 'Raw (text 字段)' },
]
// --- Hyper-parameter preset lists ---
// Entries labelled "(推荐)" match the initial value the Training form starts with.

/** A labelled preset value for a hyper-parameter dropdown. */
interface Preset<T extends string | number> {
  value: T
  label: string
}

/** Number of training epochs. */
const EPOCH_PRESETS: Preset<number>[] = [
  { value: 1, label: '1 (快速验证)' },
  { value: 2, label: '2' },
  { value: 3, label: '3 (推荐)' },
  { value: 5, label: '5' },
  { value: 10, label: '10 (充分训练)' },
]

/** Per-step batch size. */
const BATCH_SIZE_PRESETS: Preset<number>[] = [
  { value: 1, label: '1 (显存受限)' },
  { value: 2, label: '2' },
  { value: 4, label: '4' },
  { value: 8, label: '8' },
  { value: 16, label: '16 (推荐)' },
  { value: 32, label: '32' },
  { value: 64, label: '64' },
]

/** Learning rates, kept as strings; the form converts them with parseFloat on submit. */
const LR_PRESETS: Preset<string>[] = [
  { value: '1e-3', label: '1e-3 (较大)' },
  { value: '5e-4', label: '5e-4' },
  { value: '2e-4', label: '2e-4 (推荐)' },
  { value: '1e-4', label: '1e-4' },
  { value: '5e-5', label: '5e-5 (较小)' },
  { value: '1e-5', label: '1e-5' },
]

/** LoRA rank r. */
const LORA_R_PRESETS: Preset<number>[] = [
  { value: 4, label: '4 (轻量)' },
  { value: 8, label: '8' },
  { value: 16, label: '16 (推荐)' },
  { value: 32, label: '32' },
  { value: 64, label: '64 (高精度)' },
]

/** Maximum sequence length in tokens. */
const SEQ_LEN_PRESETS: Preset<number>[] = [
  { value: 512, label: '512 (短文本)' },
  { value: 1024, label: '1024' },
  { value: 2048, label: '2048 (推荐)' },
  { value: 4096, label: '4096 (长文本)' },
]

/** Gradient-accumulation steps (1 = no accumulation). */
const GRAD_ACC_PRESETS: Preset<number>[] = [
  { value: 1, label: '1 (无累积)' },
  { value: 2, label: '2' },
  { value: 4, label: '4 (推荐)' },
  { value: 8, label: '8' },
  { value: 16, label: '16' },
]

/** GPU count; the 2-GPU option is labelled as DDP data parallelism. */
const NUM_GPUS_PRESETS: Preset<number>[] = [
  { value: 1, label: '1 GPU (单卡)' },
  { value: 2, label: '2 GPU (DDP 数据并行)' },
]
// --- Generic Select component ---
interface SelectProps {
  /** Choices to render; values may be strings or numbers. */
  options: { value: string | number; label: string }[]
  /** Currently selected value. */
  value: string | number
  /** Receives the chosen option's value with its original type (numbers stay numbers). */
  onChange: (value: string | number) => void
  /** Rendered as a disabled first entry, selected while `value` is ''. */
  placeholder?: string
}

/**
 * Thin styled wrapper around a native <select>.
 *
 * The DOM always reports the selected value as a string. This component maps it
 * back to the matching option, so numeric options are delivered as numbers.
 */
function Select({ options, value, onChange, placeholder }: SelectProps) {
  const handleChange = (e: React.ChangeEvent<HTMLSelectElement>) => {
    const raw = e.target.value
    // Option values were stringified by the DOM; recover the original typed value.
    const match = options.find(o => String(o.value) === raw)
    onChange(match ? match.value : raw)
  }

  return (
    <select
      value={value}
      onChange={handleChange}
      style={{
        width: '100%', padding: '6px 8px', borderRadius: 4,
        border: '1px solid #d0d0d0', boxSizing: 'border-box',
        fontSize: 13, background: '#fff', cursor: 'pointer',
      }}
    >
      {placeholder && <option value="" disabled>{placeholder}</option>}
      {options.map(o => (
        <option key={o.value} value={o.value}>{o.label}</option>
      ))}
    </select>
  )
}
// --- Searchable dropdown component ---
interface SearchableSelectProps {
  /** Choices; `subtitle` is rendered as a smaller second line. */
  options: { value: string; label: string; subtitle?: string }[]
  /** Selected option value, or '' for none. */
  value: string
  /** Called with the chosen option's value; the dropdown then closes. */
  onChange: (value: string) => void
  /** Text shown in the trigger while `value` is empty. */
  placeholder: string
  /** When true, the list shows a loading message instead of options. */
  loading?: boolean
}

/**
 * Dropdown with a case-insensitive text filter over option labels and values.
 *
 * It opens on click and closes on an outside mousedown, on Escape, or on selection.
 * Pressing Enter picks the option when the filter narrows the list to exactly one.
 */
function SearchableSelect({ options, value, onChange, placeholder, loading }: SearchableSelectProps) {
  const [open, setOpen] = useState(false)
  const [filter, setFilter] = useState('')
  const wrapperRef = useRef<HTMLDivElement>(null)
  const inputRef = useRef<HTMLInputElement>(null)

  // While open: reset the filter, focus the search box once it has mounted, and
  // close on clicks outside the component. The cleanup cancels the pending focus too.
  useEffect(() => {
    if (!open) return
    setFilter('')
    const focusTimer = setTimeout(() => inputRef.current?.focus(), 0)
    const handleOutsideMouseDown = (e: MouseEvent) => {
      if (wrapperRef.current && !wrapperRef.current.contains(e.target as Node)) {
        setOpen(false)
      }
    }
    document.addEventListener('mousedown', handleOutsideMouseDown)
    return () => {
      clearTimeout(focusTimer)
      document.removeEventListener('mousedown', handleOutsideMouseDown)
    }
  }, [open])

  // Fall back to the raw value when it is not (yet) among the options, e.g. before they load.
  const selectedLabel = options.find(o => o.value === value)?.label || value
  const needle = filter.toLowerCase()
  const filtered = options.filter(o =>
    o.label.toLowerCase().includes(needle) || o.value.toLowerCase().includes(needle)
  )

  const handleSelect = useCallback((val: string) => {
    onChange(val)
    setOpen(false)
  }, [onChange])

  // Plain function: `filtered` is rebuilt every render, so memoising this would never hit.
  const handleKeyDown = (e: React.KeyboardEvent<HTMLInputElement>) => {
    if (e.key === 'Enter' && filtered.length === 1) {
      handleSelect(filtered[0].value)
    } else if (e.key === 'Escape') {
      setOpen(false)
    }
  }

  const toggleOpen = useCallback(() => setOpen(prev => !prev), [])

  return (
    <div ref={wrapperRef} style={{ position: 'relative' }}>
      <div
        onClick={toggleOpen}
        style={{
          padding: '6px 8px', borderRadius: 4,
          border: '1px solid #d0d0d0', cursor: 'pointer', background: '#fff',
          minHeight: 32, display: 'flex', alignItems: 'center', justifyContent: 'space-between',
          fontSize: 13, transition: 'border-color 0.2s',
        }}
      >
        <span style={{ color: value ? '#333' : '#999' }}>
          {value ? selectedLabel : placeholder}
        </span>
        <span style={{ color: '#999', fontSize: 11 }}>{open ? '▲' : '▼'}</span>
      </div>
      {open && (
        <div style={{
          position: 'absolute', top: '100%', left: 0, right: 0, background: '#fff',
          border: '1px solid #d0d0d0', borderRadius: 4, boxShadow: '0 4px 12px rgba(0,0,0,0.12)',
          zIndex: 1000, marginTop: 2, maxHeight: 240, display: 'flex', flexDirection: 'column',
        }}>
          <input
            ref={inputRef}
            value={filter}
            onChange={e => setFilter(e.target.value)}
            onKeyDown={handleKeyDown}
            placeholder="搜索..."
            style={{
              padding: '6px 8px', border: 'none', borderBottom: '1px solid #eee',
              outline: 'none', fontSize: 13,
            }}
          />
          <div style={{ overflowY: 'auto', flex: 1 }}>
            {loading && (
              <div style={{ padding: '8px 12px', color: '#999', fontSize: 13 }}>加载中...</div>
            )}
            {!loading && filtered.length === 0 && (
              <div style={{ padding: '8px 12px', color: '#999', fontSize: 13 }}>无匹配项</div>
            )}
            {!loading && filtered.map(opt => (
              <div
                key={opt.value}
                onClick={() => handleSelect(opt.value)}
                style={{
                  padding: '8px 12px', cursor: 'pointer',
                  background: opt.value === value ? '#14b8a6' : 'transparent',
                  color: opt.value === value ? '#fff' : '#333',
                  fontSize: 13,
                }}
                onMouseEnter={e => { if (opt.value !== value) e.currentTarget.style.background = '#f5f5f5' }}
                onMouseLeave={e => { if (opt.value !== value) e.currentTarget.style.background = 'transparent' }}
              >
                <div>{opt.label}</div>
                {opt.subtitle && (
                  <div style={{ fontSize: 11, color: opt.value === value ? 'rgba(255,255,255,0.7)' : '#999', marginTop: 2 }}>
                    {opt.subtitle}
                  </div>
                )}
              </div>
            ))}
          </div>
        </div>
      )}
    </div>
  )
}
// --- Job status colours ---
/** Badge / progress-bar colour for each known job status (7-character hex strings). */
const STATUS_COLORS = new Map<string, string>([
  ['completed', '#10b981'],
  ['failed', '#f43f5e'],
  ['training', '#0ea5e9'],
  ['pending', '#f59e0b'],
  ['queued', '#f59e0b'],
  ['preprocessing', '#a855f7'],
  ['cancelled', '#94a3b8'],
])

/**
 * Colour for a job status; unknown statuses get a neutral slate.
 * A Map (not an object literal) is used so keys like 'constructor' also fall back to the default.
 */
const statusColor = (status: string) => STATUS_COLORS.get(status) ?? '#64748b'
/** Chinese display label for each known job status. */
const STATUS_LABELS = new Map<string, string>([
  ['completed', '已完成'],
  ['failed', '失败'],
  ['training', '训练中'],
  ['pending', '等待中'],
  ['queued', '排队中'],
  ['preprocessing', '预处理'],
  ['cancelled', '已取消'],
])

/** Display label for a job status; unknown statuses are shown verbatim. */
const statusLabel = (status: string) => STATUS_LABELS.get(status) ?? status
// --- Job table row (memoised) ---
/** Statuses for which a row offers a cancel button. */
const CANCELLABLE_STATUSES = new Set(['training', 'pending', 'queued', 'preprocessing'])

/**
 * One row of the job table: model, PEFT method, dataset, status badge, progress bar,
 * loss, epoch and (for unfinished jobs) a cancel button.
 *
 * Wrapped in `memo`, so it re-renders only when `j`, `onCancel` or `datasets` change identity.
 */
const JobRow = memo(function JobRow({ j, onCancel, datasets }: { j: TrainingJob; onCancel: (id: string) => void; datasets: DatasetInfo[] }) {
  const modelShort = j.model_id.split('/').pop() || j.model_id
  // Prefer the dataset's display name, then the last path segment of its id.
  const dsName = datasets?.find(d => d.id === j.dataset_id)?.name
    || j.dataset_id?.split('/').pop()
    || '-'
  const color = statusColor(j.status)
  // Clamp once so the bar width and the percentage text always agree (NaN → 0).
  const rawProgress = j.progress ?? 0
  const progress = Number.isFinite(rawProgress) ? Math.min(100, Math.max(0, rawProgress)) : 0
  // Optimistic placeholder rows carry a client-side 'temp-' id the backend does not know,
  // so cancelling them would only send a bogus id to the API.
  const canCancel = !j.id.startsWith('temp-') && CANCELLABLE_STATUSES.has(j.status)

  return (
    <tr style={{
      borderBottom: '1px solid #f0f0f0',
      transition: 'background 0.15s ease',
    }}
      onMouseEnter={e => { e.currentTarget.style.background = '#f0fdfa' }}
      onMouseLeave={e => { e.currentTarget.style.background = 'transparent' }}
    >
      <td style={{ padding: '12px 12px' }}>
        <div style={{ fontSize: 13, fontWeight: 500, color: '#1e293b' }}>{modelShort}</div>
        <div style={{ fontFamily: 'monospace', fontSize: 11, color: '#94a3b8', marginTop: 2 }}>{j.id.slice(0, 8)}...</div>
      </td>
      <td style={{ padding: '12px 12px', fontSize: 13, textTransform: 'uppercase', color: '#666' }}>{j.peft_method}</td>
      <td style={{ padding: '12px 12px', fontSize: 13, color: '#64748b', maxWidth: 160, overflow: 'hidden', textOverflow: 'ellipsis', whiteSpace: 'nowrap' }} title={j.dataset_id}>{dsName}</td>
      <td style={{ padding: '12px 12px' }}>
        <span style={{
          display: 'inline-block', padding: '3px 10px', borderRadius: 12, fontSize: 12, fontWeight: 600,
          // 7-char hex + 2-digit alpha suffix gives a translucent badge background.
          background: color + '15', color,
        }}>
          {statusLabel(j.status)}
        </span>
      </td>
      <td style={{ padding: '12px 12px' }}>
        <div style={{ display: 'flex', alignItems: 'center', gap: 10 }}>
          <div style={{
            width: 120, height: 8, background: '#f0f0f0', borderRadius: 4, overflow: 'hidden',
          }}>
            <div style={{
              width: `${progress}%`, height: '100%', borderRadius: 4,
              background: `linear-gradient(90deg, ${color}, ${color}cc)`,
              transition: 'width 0.3s ease',
            }} />
          </div>
          <span style={{ fontSize: 12, color: '#666', minWidth: 45, fontWeight: 500 }}>{progress.toFixed(1)}%</span>
        </div>
      </td>
      <td style={{ padding: '12px 12px', fontSize: 13, fontFamily: 'monospace', fontWeight: 500 }}>{j.loss?.toFixed(4) ?? '-'}</td>
      <td style={{ padding: '12px 12px', fontSize: 12, color: '#888' }}>Epoch {j.current_epoch}</td>
      <td style={{ padding: '12px 12px' }}>
        {canCancel && (
          <button onClick={() => onCancel(j.id)} style={{
            padding: '4px 12px', color: '#f43f5e', border: '1px solid #f43f5e',
            borderRadius: 6, background: 'transparent', cursor: 'pointer',
            fontSize: 12, fontWeight: 500, transition: 'all 0.15s ease',
          }}
            onMouseEnter={e => { e.currentTarget.style.background = '#f43f5e'; e.currentTarget.style.color = '#fff' }}
            onMouseLeave={e => { e.currentTarget.style.background = 'transparent'; e.currentTarget.style.color = '#f43f5e' }}
          >取消</button>
        )}
      </td>
    </tr>
  )
})
/** Job statuses that stream live progress over a WebSocket. */
const ACTIVE_STATUSES = new Set(['training', 'preprocessing'])

/** Background refresh period for the job list, in milliseconds. */
const JOB_POLL_INTERVAL_MS = 5000

/**
 * Training page: a form for launching fine-tuning jobs plus a job table.
 *
 * The job list is polled every 5 s. Jobs in an active status (training /
 * preprocessing) additionally get a WebSocket subscription that patches progress,
 * loss, completion and failure into local state between polls.
 */
export function Training() {
  const [modelId, setModelId] = useState('')
  const [modelType, setModelType] = useState('text')
  const [datasetId, setDatasetId] = useState('')
  const [peftMethod, setPeftMethod] = useState('lora')
  const [taskType, setTaskType] = useState('sft')
  const [template, setTemplate] = useState('auto')
  const [epochs, setEpochs] = useState(3)
  const [batchSize, setBatchSize] = useState(16)
  const [lr, setLr] = useState('2e-4')
  const [loraR, setLoraR] = useState(16)
  const [seqLen, setSeqLen] = useState(2048)
  const [gradAcc, setGradAcc] = useState(4)
  const [numGpus, setNumGpus] = useState(1)
  const [jobs, setJobs] = useState<TrainingJob[]>([])
  // Starts true because the job list is fetched immediately on mount; this avoids
  // flashing the empty-state message before the first response arrives.
  const [loading, setLoading] = useState(true)
  const [submitting, setSubmitting] = useState(false)
  const [createError, setCreateError] = useState('')
  const [models, setModels] = useState<ModelInfo[]>([])
  const [datasets, setDatasets] = useState<DatasetInfo[]>([])
  const [loadingOptions, setLoadingOptions] = useState(true)

  /** Load model and dataset choices; a failed request simply yields an empty list. */
  const fetchOptions = useCallback(() => {
    setLoadingOptions(true)
    Promise.all([
      api.models.list().catch(() => []),
      api.datasets.list().catch(() => []),
    ]).then(([m, d]) => {
      setModels(m)
      setDatasets(d)
    }).finally(() => setLoadingOptions(false))
  }, [])
  useEffect(() => { fetchOptions() }, [fetchOptions])

  // Last job list received from the server, used to skip no-op state updates.
  const jobsRef = useRef<TrainingJob[]>([])
  // Job ids with an open WebSocket, and the unsubscribe function for each.
  const wsConnectedRef = useRef<Set<string>>(new Set())
  const wsUnsubsRef = useRef<Map<string, () => void>>(new Map())
  // False once unmounted, so late poll responses do not open sockets nobody will close.
  const mountedRef = useRef(false)

  /** Drop the subscription and socket for a single job. */
  const cleanupJobWs = useCallback((jobId: string) => {
    wsUnsubsRef.current.get(jobId)?.()
    wsUnsubsRef.current.delete(jobId)
    wsConnectedRef.current.delete(jobId)
    wsManager.disconnect(jobId)
  }, [])

  /** Connect newly active jobs and disconnect jobs that are no longer active. */
  const syncWsConnections = useCallback((currentJobs: TrainingJob[]) => {
    const activeIds = new Set(currentJobs.filter(j => ACTIVE_STATUSES.has(j.status)).map(j => j.id))

    for (const jobId of activeIds) {
      if (wsConnectedRef.current.has(jobId)) continue
      wsConnectedRef.current.add(jobId)
      wsManager.connect(jobId)
      const unsub = wsManager.subscribe(jobId, (msg) => {
        if (msg.type === 'progress') {
          setJobs(prev => prev.map(j => {
            if (j.id !== jobId) return j
            const step = (msg.step as number) ?? j.current_step
            // Use the freshest total, so the first message carrying total_steps
            // already produces a percentage instead of keeping the stale one.
            const totalSteps = (msg.total_steps as number) || j.total_steps
            return {
              ...j,
              current_step: step,
              current_epoch: (msg.epoch as number) ?? j.current_epoch,
              total_steps: totalSteps,
              loss: (msg.loss as number) ?? j.loss,
              progress: totalSteps ? (step / totalSteps) * 100 : j.progress,
            }
          }))
        } else if (msg.type === 'completed') {
          setJobs(prev => prev.map(j =>
            j.id === jobId ? { ...j, status: 'completed', progress: 100 } : j
          ))
          cleanupJobWs(jobId)
        } else if (msg.type === 'error') {
          setJobs(prev => prev.map(j =>
            j.id === jobId
              ? { ...j, status: 'failed', error_message: (msg.message as string) ?? '训练失败' }
              : j
          ))
          cleanupJobWs(jobId)
        }
      })
      wsUnsubsRef.current.set(jobId, unsub)
    }

    // Disconnect jobs that finished, failed or disappeared from the list.
    for (const jobId of [...wsConnectedRef.current]) {
      if (!activeIds.has(jobId)) {
        cleanupJobWs(jobId)
      }
    }
  }, [cleanupJobWs])

  /**
   * Reload the job list from the server and reconcile WebSocket subscriptions.
   * @param showSpinner pass false for background refreshes so the table is not
   *   swapped for the loading placeholder on every poll.
   */
  const fetchJobs = useCallback((showSpinner = true) => {
    if (showSpinner) setLoading(true)
    api.training.list()
      .then(newJobs => {
        if (!mountedRef.current) return
        if (JSON.stringify(jobsRef.current) !== JSON.stringify(newJobs)) {
          setJobs(newJobs)
          jobsRef.current = newJobs
        }
        syncWsConnections(newJobs)
      })
      .catch(() => {
        if (!mountedRef.current) return
        // A failed fetch clears the list and closes every live socket.
        if (jobsRef.current.length > 0) {
          setJobs([])
          jobsRef.current = []
        }
        syncWsConnections([])
      })
      .finally(() => {
        if (showSpinner) setLoading(false)
      })
  }, [syncWsConnections])

  useEffect(() => {
    mountedRef.current = true
    fetchJobs()
    const interval = setInterval(() => fetchJobs(false), JOB_POLL_INTERVAL_MS)
    // Capture the ref containers so the cleanup operates on the same objects.
    const unsubs = wsUnsubsRef.current
    const connected = wsConnectedRef.current
    return () => {
      mountedRef.current = false
      clearInterval(interval)
      // Release every subscription and forget the bookkeeping, so a remount
      // (e.g. React StrictMode's mount → unmount → mount) reconnects from scratch.
      unsubs.forEach(unsub => unsub())
      unsubs.clear()
      connected.clear()
      wsManager.disconnectAll()
    }
  }, [fetchJobs])

  /** Submit a new job, showing an optimistic placeholder row until the server responds. */
  const handleCreate = () => {
    if (!modelId.trim() || !datasetId.trim()) return
    setSubmitting(true)
    setCreateError('')
    const tempId = 'temp-' + Date.now()
    const tempJob: TrainingJob = {
      id: tempId, model_id: modelId, model_type: modelType, dataset_id: datasetId,
      peft_method: peftMethod, status: 'pending', progress: 0,
      loss: undefined, created_at: new Date().toISOString(),
      started_at: undefined, finished_at: undefined,
      error_message: undefined, adapter_path: undefined,
      current_epoch: 0, current_step: 0, total_steps: 0,
    }
    setJobs(prev => [tempJob, ...prev])
    // Make sure the placeholder row is visible even if a spinner refresh is in flight.
    setLoading(false)
    api.training.create({
      model_id: modelId, model_type: modelType, dataset_id: datasetId,
      peft_method: peftMethod, task_type: taskType, dataset_template: template,
      epochs, batch_size: batchSize, gradient_accumulation: gradAcc,
      max_seq_length: seqLen, learning_rate: parseFloat(lr),
      lora_r: loraR, lora_alpha: loraR * 2, num_gpus: numGpus,
    })
      .then(() => {
        setModelId('')
        setDatasetId('')
        setJobs(prev => prev.filter(j => j.id !== tempId))
        fetchJobs(false)
        fetchOptions()
      })
      .catch(err => {
        setJobs(prev => prev.filter(j => j.id !== tempId))
        setCreateError(err instanceof Error ? err.message : '创建失败')
      })
      .finally(() => setSubmitting(false))
  }

  // Stable identity so memo(JobRow) can actually skip re-rendering unchanged rows.
  const handleCancel = useCallback((id: string) => {
    api.training.cancel(id).then(() => fetchJobs(false)).catch(console.error)
  }, [fetchJobs])

  const modelOptions = models.map(m => ({
    value: m.id, label: m.id, subtitle: `${m.model_type}${m.is_downloaded ? ' ✓ 已下载' : ''}`,
  }))
  const datasetOptions = datasets.map(d => ({
    value: d.id, label: d.name, subtitle: `${d.format} · ${d.record_count} 条`,
  }))
  const canSubmit = !submitting && !!modelId && !!datasetId

  return (
    <div>
      <h1 style={{ margin: 0, fontSize: 22, fontWeight: 700 }}>训练任务</h1>
      <p style={{ color: '#888', fontSize: 13, margin: '4px 0 16px' }}>创建和管理模型微调任务</p>
      {/* Create form */}
      <div style={{
        background: '#fff', borderRadius: 12, padding: 24,
        boxShadow: '0 1px 3px rgba(0,0,0,0.06)', border: '1px solid rgba(0,0,0,0.04)',
      }}>
        <h2 style={{ margin: '0 0 20px', fontSize: 15, fontWeight: 600 }}>创建训练任务</h2>
        {/* Core configuration */}
        <div style={{
          fontSize: 13, fontWeight: 600, color: '#14b8a6', marginBottom: 12,
          paddingBottom: 6, borderBottom: '2px solid #ccfbf1',
        }}>核心配置</div>
        <div style={{ display: 'grid', gridTemplateColumns: 'repeat(3, 1fr)', gap: 16, marginBottom: 20 }}>
          <div>
            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>基础模型</label>
            <SearchableSelect options={modelOptions} value={modelId} onChange={setModelId} placeholder="选择已下载的模型" loading={loadingOptions} />
          </div>
          <div>
            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>模型类型</label>
            <Select options={MODEL_TYPES} value={modelType} onChange={v => setModelType(String(v))} />
          </div>
          <div>
            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>训练数据集</label>
            <SearchableSelect options={datasetOptions} value={datasetId} onChange={setDatasetId} placeholder="选择数据集" loading={loadingOptions} />
          </div>
          <div>
            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>训练方法</label>
            <Select options={TASK_TYPES} value={taskType} onChange={v => setTaskType(String(v))} />
          </div>
          <div>
            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>数据模板</label>
            <Select options={DATASET_TEMPLATES} value={template} onChange={v => setTemplate(String(v))} />
          </div>
          <div>
            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>PEFT 方法</label>
            <Select options={PEFT_METHODS} value={peftMethod} onChange={v => setPeftMethod(String(v))} />
          </div>
        </div>
        {/* Training hyper-parameters */}
        <div style={{
          fontSize: 13, fontWeight: 600, color: '#14b8a6', marginBottom: 12,
          paddingBottom: 6, borderBottom: '2px solid #ccfbf1',
        }}>训练超参数</div>
        <div style={{ display: 'grid', gridTemplateColumns: 'repeat(3, 1fr)', gap: 16, marginBottom: 20 }}>
          <div>
            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>训练轮数 (Epochs)</label>
            <Select options={EPOCH_PRESETS} value={String(epochs)} onChange={v => setEpochs(Number(v))} />
          </div>
          <div>
            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>批次大小 (Batch Size)</label>
            <Select options={BATCH_SIZE_PRESETS} value={String(batchSize)} onChange={v => setBatchSize(Number(v))} />
          </div>
          <div>
            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>梯度累积</label>
            <Select options={GRAD_ACC_PRESETS} value={String(gradAcc)} onChange={v => setGradAcc(Number(v))} />
          </div>
          <div>
            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>
              GPU 数量 {numGpus > 1 && <span style={{ color: '#2563eb', fontSize: 11 }}>(每卡 batch={batchSize})</span>}
            </label>
            <Select options={NUM_GPUS_PRESETS} value={String(numGpus)} onChange={v => setNumGpus(Number(v))} />
          </div>
          <div>
            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>学习率</label>
            <Select options={LR_PRESETS} value={lr} onChange={v => setLr(String(v))} />
          </div>
          <div>
            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>最大序列长度</label>
            <Select options={SEQ_LEN_PRESETS} value={String(seqLen)} onChange={v => setSeqLen(Number(v))} />
          </div>
          <div>
            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>LoRA Rank (R)</label>
            <Select options={LORA_R_PRESETS} value={String(loraR)} onChange={v => setLoraR(Number(v))} />
          </div>
        </div>
        {/* Advanced options: DeepSpeed is not supported yet (MetaX GPU compatibility still
            to be verified); DDP is used instead. */}
        {/* Error message */}
        {createError && (
          <div style={{
            marginBottom: 16, padding: 12, background: '#fff1f2', borderRadius: 8,
            fontSize: 13, color: '#e11d48', border: '1px solid #fecdd3',
          }}>
            {createError}
          </div>
        )}
        <button
          onClick={handleCreate}
          disabled={!canSubmit}
          style={{
            padding: '12px 36px', borderRadius: 8, border: 'none',
            background: '#14b8a6', color: '#fff',
            cursor: canSubmit ? 'pointer' : 'not-allowed',
            opacity: canSubmit ? 1 : 0.5,
            fontSize: 14, fontWeight: 600, transition: 'all 0.2s ease',
          }}
        >
          {submitting ? '创建中...' : '启动训练'}
        </button>
      </div>
      {/* Job list */}
      <div style={{ marginTop: 24 }}>
        <div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 12 }}>
          <h2 style={{ margin: 0, fontSize: 15, fontWeight: 600 }}>任务列表</h2>
          <button onClick={() => fetchJobs()} style={{
            padding: '6px 14px', borderRadius: 6, border: '1px solid #d0d0d0',
            background: '#fff', cursor: 'pointer', fontSize: 13, fontWeight: 500,
          }}
            onMouseEnter={e => { e.currentTarget.style.background = '#f5f5f5' }}
            onMouseLeave={e => { e.currentTarget.style.background = '#fff' }}
          >
            刷新
          </button>
        </div>
        {loading && <p style={{ color: '#999', fontSize: 13 }}>加载中...</p>}
        {!loading && jobs.length === 0 && (
          <div style={{
            padding: 40, textAlign: 'center', color: '#94a3b8', fontSize: 14,
            background: '#fff', borderRadius: 10, boxShadow: '0 1px 3px rgba(0,0,0,0.06)',
          }}>
            <div style={{ marginBottom: 8 }}><Train size={32} color="#94a3b8" strokeWidth={1.5} /></div>
            暂无训练任务,请先创建训练任务
          </div>
        )}
        {!loading && jobs.length > 0 && (
          <div style={{
            background: '#fff', borderRadius: 10, overflow: 'hidden',
            boxShadow: '0 1px 3px rgba(0,0,0,0.06)', border: '1px solid rgba(0,0,0,0.04)',
          }}>
            <table style={{ width: '100%', borderCollapse: 'collapse', fontSize: 13 }}>
              <thead>
                <tr style={{ background: '#f0fdfa', borderBottom: '2px solid #f1f5f9', textAlign: 'left' }}>
                  <th style={{ padding: '10px 12px', fontSize: 12, color: '#666', fontWeight: 600 }}>任务</th>
                  <th style={{ padding: '10px 12px', fontSize: 12, color: '#666', fontWeight: 600 }}>PEFT</th>
                  <th style={{ padding: '10px 12px', fontSize: 12, color: '#666', fontWeight: 600 }}>数据集</th>
                  <th style={{ padding: '10px 12px', fontSize: 12, color: '#666', fontWeight: 600 }}>状态</th>
                  <th style={{ padding: '10px 12px', fontSize: 12, color: '#666', fontWeight: 600 }}>进度</th>
                  <th style={{ padding: '10px 12px', fontSize: 12, color: '#666', fontWeight: 600 }}>Loss</th>
                  <th style={{ padding: '10px 12px', fontSize: 12, color: '#666', fontWeight: 600 }}>轮次</th>
                  <th style={{ padding: '10px 12px', fontSize: 12, color: '#666', fontWeight: 600 }}>操作</th>
                </tr>
              </thead>
              <tbody>
                {jobs.map(j => (
                  <JobRow key={j.id} j={j} onCancel={handleCancel} datasets={datasets} />
                ))}
              </tbody>
            </table>
          </div>
        )}
      </div>
    </div>
  )
}
|