|
|
@@ -1,10 +1,194 @@
|
|
|
+import { useState, useEffect } from 'react'
|
|
|
+import api, { TrainingJob } from '../api/client'
|
|
|
+import { wsManager } from '../api/websocket'
|
|
|
+
|
|
|
+const MODEL_TYPES = [
|
|
|
+ { value: 'text', label: '文本 (LLaMA/Qwen)' },
|
|
|
+ { value: 'vision', label: '视觉 (ViT/CLIP)' },
|
|
|
+ { value: 'multimodal', label: '多模态 (LLaVA/Qwen-VL)' },
|
|
|
+]
|
|
|
+
|
|
|
+const PEFT_METHODS = [
|
|
|
+ { value: 'lora', label: 'LoRA' },
|
|
|
+ { value: 'qlora', label: 'QLoRA' },
|
|
|
+ { value: 'ia3', label: 'IA3' },
|
|
|
+ { value: 'adalora', label: 'AdaLoRA' },
|
|
|
+ { value: 'prefix_tuning', label: 'Prefix Tuning' },
|
|
|
+]
|
|
|
+
|
|
|
export function Training() {
|
|
|
+ // Form state
|
|
|
+ const [modelId, setModelId] = useState('')
|
|
|
+ const [modelType, setModelType] = useState('text')
|
|
|
+ const [datasetId, setDatasetId] = useState('')
|
|
|
+ const [peftMethod, setPeftMethod] = useState('lora')
|
|
|
+ const [epochs, setEpochs] = useState(3)
|
|
|
+ const [batchSize, setBatchSize] = useState(4)
|
|
|
+ const [lr, setLr] = useState('2e-4')
|
|
|
+ const [loraR, setLoraR] = useState(16)
|
|
|
+
|
|
|
+ // Job list
|
|
|
+ const [jobs, setJobs] = useState<TrainingJob[]>([])
|
|
|
+ const [loading, setLoading] = useState(false)
|
|
|
+ const [submitting, setSubmitting] = useState(false)
|
|
|
+
|
|
|
+ const fetchJobs = () => {
|
|
|
+ setLoading(true)
|
|
|
+ api.training.list()
|
|
|
+ .then(setJobs)
|
|
|
+ .catch(() => setJobs([]))
|
|
|
+ .finally(() => setLoading(false))
|
|
|
+ }
|
|
|
+
|
|
|
+ useEffect(() => {
|
|
|
+ fetchJobs()
|
|
|
+ }, [])
|
|
|
+
|
|
|
+ const handleCreate = () => {
|
|
|
+ if (!modelId.trim() || !datasetId.trim()) return
|
|
|
+ setSubmitting(true)
|
|
|
+ api.training.create({
|
|
|
+ model_id: modelId,
|
|
|
+ model_type: modelType,
|
|
|
+ dataset_id: datasetId,
|
|
|
+ peft_method: peftMethod,
|
|
|
+ epochs,
|
|
|
+ batch_size: batchSize,
|
|
|
+ learning_rate: parseFloat(lr),
|
|
|
+ lora_r: loraR,
|
|
|
+ lora_alpha: loraR * 2,
|
|
|
+ })
|
|
|
+ .then(() => {
|
|
|
+ setModelId('')
|
|
|
+ setDatasetId('')
|
|
|
+ fetchJobs()
|
|
|
+ })
|
|
|
+ .catch(console.error)
|
|
|
+ .finally(() => setSubmitting(false))
|
|
|
+ }
|
|
|
+
|
|
|
+ const handleCancel = (id: string) => {
|
|
|
+ api.training.cancel(id)
|
|
|
+ .then(() => fetchJobs())
|
|
|
+ .catch(console.error)
|
|
|
+ }
|
|
|
+
|
|
|
+ const statusColor = (status: string) => {
|
|
|
+ switch (status) {
|
|
|
+ case 'completed': return '#4caf50'
|
|
|
+ case 'failed': return '#e94560'
|
|
|
+ case 'training': return '#2196f3'
|
|
|
+ case 'pending': case 'queued': return '#ff9800'
|
|
|
+ case 'cancelled': return '#999'
|
|
|
+ default: return '#666'
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
return (
|
|
|
<div>
|
|
|
<h1>训练任务</h1>
|
|
|
- <p style={{ marginTop: 16, color: '#666', fontSize: 14 }}>
|
|
|
- Phase 3 将实现完整的训练任务创建、监控和 WebSocket 实时进度推送。
|
|
|
- </p>
|
|
|
+
|
|
|
+ {/* Create form */}
|
|
|
+ <div style={{ marginTop: 16, background: '#fff', borderRadius: 8, padding: 20, boxShadow: '0 1px 3px rgba(0,0,0,0.1)' }}>
|
|
|
+ <h2 style={{ margin: '0 0 16px', fontSize: 16 }}>创建训练任务</h2>
|
|
|
+ <div style={{ display: 'grid', gridTemplateColumns: 'repeat(3, 1fr)', gap: 12 }}>
|
|
|
+ <div>
|
|
|
+ <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>模型 ID</label>
|
|
|
+ <input value={modelId} onChange={e => setModelId(e.target.value)} placeholder="模型 ID" style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }} />
|
|
|
+ </div>
|
|
|
+ <div>
|
|
|
+ <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>模型类型</label>
|
|
|
+ <select value={modelType} onChange={e => setModelType(e.target.value)} style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }}>
|
|
|
+ {MODEL_TYPES.map(t => <option key={t.value} value={t.value}>{t.label}</option>)}
|
|
|
+ </select>
|
|
|
+ </div>
|
|
|
+ <div>
|
|
|
+ <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>数据集 ID</label>
|
|
|
+ <input value={datasetId} onChange={e => setDatasetId(e.target.value)} placeholder="数据集 ID" style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }} />
|
|
|
+ </div>
|
|
|
+ <div>
|
|
|
+ <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>PEFT 方法</label>
|
|
|
+ <select value={peftMethod} onChange={e => setPeftMethod(e.target.value)} style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }}>
|
|
|
+ {PEFT_METHODS.map(m => <option key={m.value} value={m.value}>{m.label}</option>)}
|
|
|
+ </select>
|
|
|
+ </div>
|
|
|
+ <div>
|
|
|
+ <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>Epochs</label>
|
|
|
+ <input type="number" value={epochs} onChange={e => setEpochs(Number(e.target.value))} min={1} style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }} />
|
|
|
+ </div>
|
|
|
+ <div>
|
|
|
+ <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>Batch Size</label>
|
|
|
+ <input type="number" value={batchSize} onChange={e => setBatchSize(Number(e.target.value))} min={1} style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }} />
|
|
|
+ </div>
|
|
|
+ <div>
|
|
|
+ <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>Learning Rate</label>
|
|
|
+ <input value={lr} onChange={e => setLr(e.target.value)} style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }} />
|
|
|
+ </div>
|
|
|
+ <div>
|
|
|
+ <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>LoRA R</label>
|
|
|
+ <input type="number" value={loraR} onChange={e => setLoraR(Number(e.target.value))} min={1} style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }} />
|
|
|
+ </div>
|
|
|
+ </div>
|
|
|
+ <button
|
|
|
+ onClick={handleCreate}
|
|
|
+ disabled={submitting}
|
|
|
+ style={{ marginTop: 16, padding: '8px 24px', borderRadius: 4, border: 'none', background: '#e94560', color: '#fff', cursor: 'pointer', opacity: submitting ? 0.6 : 1 }}
|
|
|
+ >
|
|
|
+ {submitting ? '创建中...' : '启动训练'}
|
|
|
+ </button>
|
|
|
+ </div>
|
|
|
+
|
|
|
+ {/* Job list */}
|
|
|
+ <div style={{ marginTop: 24 }}>
|
|
|
+ <div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 12 }}>
|
|
|
+ <h2 style={{ margin: 0 }}>任务列表</h2>
|
|
|
+ <button onClick={fetchJobs} style={{ padding: '4px 12px', borderRadius: 4, border: '1px solid #ccc', background: '#fff', cursor: 'pointer' }}>刷新</button>
|
|
|
+ </div>
|
|
|
+
|
|
|
+ {loading && <p style={{ color: '#999' }}>加载中...</p>}
|
|
|
+
|
|
|
+ {!loading && jobs.length === 0 && (
|
|
|
+ <p style={{ color: '#999', fontSize: 14 }}>暂无训练任务</p>
|
|
|
+ )}
|
|
|
+
|
|
|
+ {!loading && jobs.length > 0 && (
|
|
|
+ <table style={{ width: '100%', borderCollapse: 'collapse', fontSize: 14 }}>
|
|
|
+ <thead>
|
|
|
+ <tr style={{ borderBottom: '2px solid #eee', textAlign: 'left' }}>
|
|
|
+ <th style={{ padding: '8px 0' }}>任务 ID</th>
|
|
|
+ <th>模型</th>
|
|
|
+ <th>PEFT</th>
|
|
|
+ <th>状态</th>
|
|
|
+ <th>进度</th>
|
|
|
+ <th>Loss</th>
|
|
|
+ <th>操作</th>
|
|
|
+ </tr>
|
|
|
+ </thead>
|
|
|
+ <tbody>
|
|
|
+ {jobs.map(j => (
|
|
|
+ <tr key={j.id} style={{ borderBottom: '1px solid #eee' }}>
|
|
|
+ <td style={{ padding: '8px 0', fontFamily: 'monospace', fontSize: 12 }}>{j.id.slice(0, 8)}...</td>
|
|
|
+ <td>{j.model_id}</td>
|
|
|
+ <td>{j.peft_method}</td>
|
|
|
+ <td style={{ color: statusColor(j.status), fontWeight: 600 }}>{j.status}</td>
|
|
|
+ <td>
|
|
|
+ <div style={{ width: 120, height: 6, background: '#eee', borderRadius: 3, overflow: 'hidden' }}>
|
|
|
+ <div style={{ width: `${j.progress}%`, height: '100%', background: j.status === 'failed' ? '#e94560' : '#4caf50', transition: 'width 0.3s' }} />
|
|
|
+ </div>
|
|
|
+ <span style={{ fontSize: 11, color: '#999' }}>{j.progress.toFixed(1)}%</span>
|
|
|
+ </td>
|
|
|
+ <td>{j.loss?.toFixed(4) ?? '-'}</td>
|
|
|
+ <td>
|
|
|
+ {(j.status === 'training' || j.status === 'pending' || j.status === 'queued') && (
|
|
|
+ <button onClick={() => handleCancel(j.id)} style={{ padding: '2px 8px', color: '#e94560', border: '1px solid #e94560', borderRadius: 4, background: 'transparent', cursor: 'pointer' }}>取消</button>
|
|
|
+ )}
|
|
|
+ </td>
|
|
|
+ </tr>
|
|
|
+ ))}
|
|
|
+ </tbody>
|
|
|
+ </table>
|
|
|
+ )}
|
|
|
+ </div>
|
|
|
</div>
|
|
|
)
|
|
|
}
|