Training.tsx 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667
  1. import { useState, useEffect, useRef, useCallback, memo } from 'react'
  2. import api, { TrainingJob, ModelInfo, DatasetInfo } from '../api/client'
  3. import { wsManager } from '../api/websocket'
  4. import { Train } from 'lucide-react'
  // --- Option lists for the categorical form fields (value sent to the API, label shown in the UI) ---

  /** Base-model families offered in the "model type" picker. */
  const MODEL_TYPES: Array<{ value: string; label: string }> = [
    { value: 'text', label: '文本 (LLaMA/Qwen)' },
    { value: 'vision', label: '视觉 (ViT/CLIP)' },
    { value: 'multimodal', label: '多模态 (LLaVA/Qwen-VL)' },
  ]

  /** Parameter-efficient fine-tuning methods. */
  const PEFT_METHODS: Array<{ value: string; label: string }> = [
    { value: 'lora', label: 'LoRA' },
    { value: 'qlora', label: 'QLoRA (推荐)' },
    { value: 'adalora', label: 'AdaLoRA' },
  ]

  /** Training objectives. */
  const TASK_TYPES: Array<{ value: string; label: string }> = [
    { value: 'sft', label: 'SFT (监督微调)' },
    { value: 'dpo', label: 'DPO (直接偏好优化)' },
    { value: 'ppo', label: 'PPO (近端策略优化)' },
  ]

  /** Dataset record layouts; 'auto' leaves detection to the backend. */
  const DATASET_TEMPLATES: Array<{ value: string; label: string }> = [
    { value: 'auto', label: 'Auto (自动检测)' },
    { value: 'alpaca', label: 'Alpaca (instruction/input/output)' },
    { value: 'sharegpt', label: 'ShareGPT (conversations)' },
    { value: 'raw', label: 'Raw (text 字段)' },
  ]
  // --- Preset values for the hyper-parameter pickers (label text marks the suggested default) ---

  /** Number of passes over the dataset. */
  const EPOCH_PRESETS: Array<{ value: number; label: string }> = [
    { value: 1, label: '1 (快速验证)' },
    { value: 2, label: '2' },
    { value: 3, label: '3 (推荐)' },
    { value: 5, label: '5' },
    { value: 10, label: '10 (充分训练)' },
  ]

  /** Per-device batch size. */
  const BATCH_SIZE_PRESETS: Array<{ value: number; label: string }> = [
    { value: 1, label: '1 (显存受限)' },
    { value: 2, label: '2' },
    { value: 4, label: '4' },
    { value: 8, label: '8' },
    { value: 16, label: '16 (推荐)' },
    { value: 32, label: '32' },
    { value: 64, label: '64' },
  ]

  /** Learning rates, kept as strings and parsed with parseFloat at submit time. */
  const LR_PRESETS: Array<{ value: string; label: string }> = [
    { value: '1e-3', label: '1e-3 (较大)' },
    { value: '5e-4', label: '5e-4' },
    { value: '2e-4', label: '2e-4 (推荐)' },
    { value: '1e-4', label: '1e-4' },
    { value: '5e-5', label: '5e-5 (较小)' },
    { value: '1e-5', label: '1e-5' },
  ]

  /** LoRA rank r. */
  const LORA_R_PRESETS: Array<{ value: number; label: string }> = [
    { value: 4, label: '4 (轻量)' },
    { value: 8, label: '8' },
    { value: 16, label: '16 (推荐)' },
    { value: 32, label: '32' },
    { value: 64, label: '64 (高精度)' },
  ]

  /** Maximum tokenized sequence length. */
  const SEQ_LEN_PRESETS: Array<{ value: number; label: string }> = [
    { value: 512, label: '512 (短文本)' },
    { value: 1024, label: '1024' },
    { value: 2048, label: '2048 (推荐)' },
    { value: 4096, label: '4096 (长文本)' },
  ]

  /** Gradient-accumulation steps. */
  const GRAD_ACC_PRESETS: Array<{ value: number; label: string }> = [
    { value: 1, label: '1 (无累积)' },
    { value: 2, label: '2' },
    { value: 4, label: '4 (推荐)' },
    { value: 8, label: '8' },
    { value: 16, label: '16' },
  ]

  /** GPU count; more than one GPU selects DDP data parallelism. */
  const NUM_GPUS_PRESETS: Array<{ value: number; label: string }> = [
    { value: 1, label: '1 GPU (单卡)' },
    { value: 2, label: '2 GPU (DDP 数据并行)' },
  ]
  // --- Generic styled <select> ---
  interface SelectProps {
    options: { value: string | number; label: string }[]
    value: string | number
    /** Receives the chosen option's original `value` (numeric options yield numbers). */
    onChange: (value: string | number) => void
    /** Optional disabled first entry shown while `value` is ''. */
    placeholder?: string
  }

  /**
   * Thin styling wrapper around a native `<select>`.
   *
   * The DOM always reports the selection as a string. This component maps it back to
   * the matching option's original value so that numeric options round-trip as numbers,
   * as the `string | number` type of `onChange` promises. If no option matches, the raw
   * DOM string is passed through unchanged.
   */
  function Select({ options, value, onChange, placeholder }: SelectProps) {
    const handleChange = (e: React.ChangeEvent<HTMLSelectElement>) => {
      const raw = e.target.value
      const match = options.find(o => String(o.value) === raw)
      onChange(match ? match.value : raw)
    }

    return (
      <select
        value={value}
        onChange={handleChange}
        style={{
          width: '100%', padding: '6px 8px', borderRadius: 4,
          border: '1px solid #d0d0d0', boxSizing: 'border-box',
          fontSize: 13, background: '#fff', cursor: 'pointer',
        }}
      >
        {placeholder && <option value="" disabled>{placeholder}</option>}
        {options.map(o => (
          <option key={o.value} value={o.value}>{o.label}</option>
        ))}
      </select>
    )
  }
  // --- Searchable dropdown ---
  interface SearchableSelectProps {
    options: { value: string; label: string; subtitle?: string }[]
    value: string
    onChange: (value: string) => void
    placeholder: string
    /** When true the list area shows a loading message instead of options. */
    loading?: boolean
  }

  /**
   * Dropdown with a free-text filter over option labels and values (case-insensitive).
   *
   * - Clicking outside the component closes it.
   * - Escape closes it.
   * - Enter picks the option when exactly one remains (ignored while `loading`).
   *
   * If `value` is set but not present in `options` (e.g. the list is still loading or
   * was refreshed), the trigger shows the raw value instead of an empty label.
   */
  function SearchableSelect({ options, value, onChange, placeholder, loading }: SearchableSelectProps) {
    const [open, setOpen] = useState(false)
    const [filter, setFilter] = useState('')
    const wrapperRef = useRef<HTMLDivElement>(null)
    const inputRef = useRef<HTMLInputElement>(null)

    // Close on any mousedown outside the wrapper; the listener is attached only while open.
    useEffect(() => {
      const handler = (e: MouseEvent) => {
        if (wrapperRef.current && !wrapperRef.current.contains(e.target as Node)) {
          setOpen(false)
        }
      }
      if (open) document.addEventListener('mousedown', handler)
      return () => document.removeEventListener('mousedown', handler)
    }, [open])

    // On open: clear the previous search and focus the search box once it has mounted.
    useEffect(() => {
      if (!open) return
      setFilter('')
      const timer = setTimeout(() => inputRef.current?.focus(), 0)
      return () => clearTimeout(timer)
    }, [open])

    const selectedLabel = options.find(o => o.value === value)?.label ?? value
    const needle = filter.toLowerCase()
    const filtered = options.filter(o =>
      o.label.toLowerCase().includes(needle) || o.value.toLowerCase().includes(needle)
    )

    const handleSelect = useCallback((val: string) => {
      onChange(val)
      setOpen(false)
    }, [onChange])

    const handleKeyDown = useCallback((e: React.KeyboardEvent) => {
      // While loading, the option list is hidden, so Enter must not pick an invisible entry.
      if (e.key === 'Enter' && !loading && filtered.length === 1) {
        handleSelect(filtered[0].value)
      } else if (e.key === 'Escape') {
        setOpen(false)
      }
    }, [filtered, handleSelect, loading])

    const toggleOpen = useCallback(() => setOpen(prev => !prev), [])

    return (
      <div ref={wrapperRef} style={{ position: 'relative' }}>
        <div
          onClick={toggleOpen}
          style={{
            padding: '6px 8px', borderRadius: 4,
            border: '1px solid #d0d0d0', cursor: 'pointer', background: '#fff',
            minHeight: 32, display: 'flex', alignItems: 'center', justifyContent: 'space-between',
            fontSize: 13, transition: 'border-color 0.2s',
          }}
        >
          <span style={{ color: value ? '#333' : '#999' }}>
            {value ? selectedLabel : placeholder}
          </span>
          <span style={{ color: '#999', fontSize: 11 }}>{open ? '▲' : '▼'}</span>
        </div>
        {open && (
          <div style={{
            position: 'absolute', top: '100%', left: 0, right: 0, background: '#fff',
            border: '1px solid #d0d0d0', borderRadius: 4, boxShadow: '0 4px 12px rgba(0,0,0,0.12)',
            zIndex: 1000, marginTop: 2, maxHeight: 240, display: 'flex', flexDirection: 'column',
          }}>
            <input
              ref={inputRef}
              value={filter}
              onChange={e => setFilter(e.target.value)}
              onKeyDown={handleKeyDown}
              placeholder="搜索..."
              style={{
                padding: '6px 8px', border: 'none', borderBottom: '1px solid #eee',
                outline: 'none', fontSize: 13,
              }}
            />
            <div style={{ overflowY: 'auto', flex: 1 }}>
              {loading && (
                <div style={{ padding: '8px 12px', color: '#999', fontSize: 13 }}>加载中...</div>
              )}
              {!loading && filtered.length === 0 && (
                <div style={{ padding: '8px 12px', color: '#999', fontSize: 13 }}>无匹配项</div>
              )}
              {!loading && filtered.map(opt => (
                <div
                  key={opt.value}
                  onClick={() => handleSelect(opt.value)}
                  style={{
                    padding: '8px 12px', cursor: 'pointer',
                    background: opt.value === value ? '#14b8a6' : 'transparent',
                    color: opt.value === value ? '#fff' : '#333',
                    fontSize: 13,
                  }}
                  onMouseEnter={e => { if (opt.value !== value) e.currentTarget.style.background = '#f5f5f5' }}
                  onMouseLeave={e => { if (opt.value !== value) e.currentTarget.style.background = 'transparent' }}
                >
                  <div>{opt.label}</div>
                  {opt.subtitle && (
                    <div style={{ fontSize: 11, color: opt.value === value ? 'rgba(255,255,255,0.7)' : '#999', marginTop: 2 }}>
                      {opt.subtitle}
                    </div>
                  )}
                </div>
              ))}
            </div>
          </div>
        )}
      </div>
    )
  }
  // --- Job status presentation helpers ---

  /** Accent colour per job status; 'pending' and 'queued' share amber. */
  const STATUS_COLOR_BY_STATUS = new Map<string, string>([
    ['completed', '#10b981'],
    ['failed', '#f43f5e'],
    ['training', '#0ea5e9'],
    ['pending', '#f59e0b'],
    ['queued', '#f59e0b'],
    ['preprocessing', '#a855f7'],
    ['cancelled', '#94a3b8'],
  ])

  /** Display text per job status. */
  const STATUS_LABEL_BY_STATUS = new Map<string, string>([
    ['completed', '已完成'],
    ['failed', '失败'],
    ['training', '训练中'],
    ['pending', '等待中'],
    ['queued', '排队中'],
    ['preprocessing', '预处理'],
    ['cancelled', '已取消'],
  ])

  /** Hex colour for a status; unknown statuses get slate '#64748b'. */
  const statusColor = (status: string) => STATUS_COLOR_BY_STATUS.get(status) ?? '#64748b'

  /** Display label for a status; unknown statuses are shown verbatim. */
  const statusLabel = (status: string) => STATUS_LABEL_BY_STATUS.get(status) ?? status
  // --- Job table row (memoized) ---

  /** Statuses for which the row offers a cancel button. */
  const CANCELLABLE_JOB_STATUSES = new Set(['training', 'pending', 'queued', 'preprocessing'])

  /**
   * One row of the job table. Wrapped in `memo`, so it re-renders only when its props change.
   *
   * The dataset column shows the dataset's display name from `datasets` when the ID is found,
   * otherwise the last '/'-segment of `dataset_id`, otherwise '-'.
   */
  const JobRow = memo(function JobRow({ j, onCancel, datasets }: { j: TrainingJob; onCancel: (id: string) => void; datasets: DatasetInfo[] }) {
    const accent = statusColor(j.status)
    const percent = j.progress ?? 0
    const modelShort = j.model_id.split('/').pop() || j.model_id
    const knownDataset = datasets?.find(d => d.id === j.dataset_id)
    const dsName = knownDataset?.name || j.dataset_id?.split('/').pop() || '-'
    const cancellable = CANCELLABLE_JOB_STATUSES.has(j.status)

    const highlightRow = (e: React.MouseEvent<HTMLTableRowElement>) => { e.currentTarget.style.background = '#f0fdfa' }
    const unhighlightRow = (e: React.MouseEvent<HTMLTableRowElement>) => { e.currentTarget.style.background = 'transparent' }
    const fillCancel = (e: React.MouseEvent<HTMLButtonElement>) => { e.currentTarget.style.background = '#f43f5e'; e.currentTarget.style.color = '#fff' }
    const outlineCancel = (e: React.MouseEvent<HTMLButtonElement>) => { e.currentTarget.style.background = 'transparent'; e.currentTarget.style.color = '#f43f5e' }

    return (
      <tr
        style={{ borderBottom: '1px solid #f0f0f0', transition: 'background 0.15s ease' }}
        onMouseEnter={highlightRow}
        onMouseLeave={unhighlightRow}
      >
        <td style={{ padding: '12px 12px' }}>
          <div style={{ fontSize: 13, fontWeight: 500, color: '#1e293b' }}>{modelShort}</div>
          <div style={{ fontFamily: 'monospace', fontSize: 11, color: '#94a3b8', marginTop: 2 }}>{j.id.slice(0, 8)}...</div>
        </td>
        <td style={{ padding: '12px 12px', fontSize: 13, textTransform: 'uppercase', color: '#666' }}>{j.peft_method}</td>
        <td
          style={{ padding: '12px 12px', fontSize: 13, color: '#64748b', maxWidth: 160, overflow: 'hidden', textOverflow: 'ellipsis', whiteSpace: 'nowrap' }}
          title={j.dataset_id}
        >
          {dsName}
        </td>
        <td style={{ padding: '12px 12px' }}>
          <span style={{
            display: 'inline-block', padding: '3px 10px', borderRadius: 12, fontSize: 12, fontWeight: 600,
            background: accent + '15', color: accent,
          }}>
            {statusLabel(j.status)}
          </span>
        </td>
        <td style={{ padding: '12px 12px' }}>
          <div style={{ display: 'flex', alignItems: 'center', gap: 10 }}>
            <div style={{ width: 120, height: 8, background: '#f0f0f0', borderRadius: 4, overflow: 'hidden' }}>
              <div style={{
                width: `${Math.min(100, Math.max(0, percent))}%`, height: '100%', borderRadius: 4,
                background: `linear-gradient(90deg, ${accent}, ${accent}cc)`,
                transition: 'width 0.3s ease',
              }} />
            </div>
            <span style={{ fontSize: 12, color: '#666', minWidth: 45, fontWeight: 500 }}>{percent.toFixed(1)}%</span>
          </div>
        </td>
        <td style={{ padding: '12px 12px', fontSize: 13, fontFamily: 'monospace', fontWeight: 500 }}>{j.loss?.toFixed(4) ?? '-'}</td>
        <td style={{ padding: '12px 12px', fontSize: 12, color: '#888' }}>Epoch {j.current_epoch}</td>
        <td style={{ padding: '12px 12px' }}>
          {cancellable && (
            <button
              onClick={() => onCancel(j.id)}
              style={{
                padding: '4px 12px', color: '#f43f5e', border: '1px solid #f43f5e',
                borderRadius: 6, background: 'transparent', cursor: 'pointer',
                fontSize: 12, fontWeight: 500, transition: 'all 0.15s ease',
              }}
              onMouseEnter={fillCancel}
              onMouseLeave={outlineCancel}
            >取消</button>
          )}
        </td>
      </tr>
    )
  })
  /**
   * Training page: a form for launching fine-tuning jobs plus a live job table.
   *
   * The job list is polled every 5 s via `api.training.list()`. Jobs in an active status
   * ('training' / 'preprocessing') also get a per-job `wsManager` subscription. That
   * subscription patches progress, completion and failure into local state between polls.
   */
  export function Training() {
    // --- Create-form state ---
    const [modelId, setModelId] = useState('')
    const [modelType, setModelType] = useState('text')
    const [datasetId, setDatasetId] = useState('')
    const [peftMethod, setPeftMethod] = useState('lora')
    const [taskType, setTaskType] = useState('sft')
    const [template, setTemplate] = useState('auto')
    const [epochs, setEpochs] = useState(3)
    const [batchSize, setBatchSize] = useState(16)
    const [lr, setLr] = useState('2e-4')
    const [loraR, setLoraR] = useState(16)
    const [seqLen, setSeqLen] = useState(2048)
    const [gradAcc, setGradAcc] = useState(4)
    const [numGpus, setNumGpus] = useState(1)

    // --- Job list / request state ---
    const [jobs, setJobs] = useState<TrainingJob[]>([])
    const [loading, setLoading] = useState(false)
    const [submitting, setSubmitting] = useState(false)
    const [createError, setCreateError] = useState('')
    const [models, setModels] = useState<ModelInfo[]>([])
    const [datasets, setDatasets] = useState<DatasetInfo[]>([])
    const [loadingOptions, setLoadingOptions] = useState(true)

    // Load the model and dataset pickers; a failing endpoint degrades to an empty list.
    const fetchOptions = useCallback(() => {
      setLoadingOptions(true)
      Promise.all([
        api.models.list().catch(() => []),
        api.datasets.list().catch(() => []),
      ]).then(([m, d]) => {
        setModels(m)
        setDatasets(d)
      }).finally(() => setLoadingOptions(false))
    }, [])

    useEffect(() => { fetchOptions() }, [fetchOptions])

    // Last job list received from the server. It is compared against each poll so that
    // unchanged data neither re-renders nor overwrites WebSocket-patched progress.
    const jobsRef = useRef<TrainingJob[]>([])
    // Job IDs with an open WebSocket, and the matching unsubscribe callbacks.
    const wsConnectedRef = useRef<Set<string>>(new Set())
    const wsUnsubsRef = useRef<Map<string, () => void>>(new Map())
    // Cleared on unmount so that a poll resolving late cannot open sockets nobody will close.
    const mountedRef = useRef(false)

    const fetchJobs = () => {
      setLoading(true)
      api.training.list()
        .then(newJobs => {
          if (!mountedRef.current) return
          if (JSON.stringify(jobsRef.current) !== JSON.stringify(newJobs)) {
            setJobs(newJobs)
            jobsRef.current = newJobs
          }
          syncWsConnections(newJobs)
        })
        .catch(() => {
          if (!mountedRef.current) return
          if (jobsRef.current.length > 0) {
            setJobs([])
            jobsRef.current = []
          }
          syncWsConnections([])
        })
        .finally(() => setLoading(false))
    }

    // Keep exactly one WebSocket per active job: connect new active jobs, drop the rest.
    const ACTIVE_STATUSES = new Set(['training', 'preprocessing'])
    function syncWsConnections(currentJobs: TrainingJob[]) {
      const activeIds = new Set(currentJobs.filter(j => ACTIVE_STATUSES.has(j.status)).map(j => j.id))

      // Connect and subscribe for newly active jobs.
      for (const jobId of activeIds) {
        if (wsConnectedRef.current.has(jobId)) continue
        wsConnectedRef.current.add(jobId)
        wsManager.connect(jobId)
        const unsub = wsManager.subscribe(jobId, (msg) => {
          if (msg.type === 'progress') {
            setJobs(prev => prev.map(j => {
              if (j.id !== jobId) return j
              const step = (msg.step as number) ?? j.current_step
              const totalSteps = (msg.total_steps as number) || j.total_steps
              return {
                ...j,
                current_step: step,
                current_epoch: (msg.epoch as number) ?? j.current_epoch,
                total_steps: totalSteps,
                loss: (msg.loss as number) ?? j.loss,
                // Use the total carried by this same message when present; otherwise the
                // first progress event (before total_steps is known) would not move the bar.
                progress: totalSteps ? (step / totalSteps) * 100 : j.progress,
              }
            }))
          } else if (msg.type === 'completed') {
            setJobs(prev => prev.map(j =>
              j.id === jobId ? { ...j, status: 'completed', progress: 100 } : j
            ))
            cleanupJobWs(jobId)
          } else if (msg.type === 'error') {
            setJobs(prev => prev.map(j =>
              j.id === jobId
                ? { ...j, status: 'failed', error_message: (msg.message as string) ?? '训练失败' }
                : j
            ))
            cleanupJobWs(jobId)
          }
        })
        wsUnsubsRef.current.set(jobId, unsub)
      }

      // Disconnect jobs that are no longer active (finished, failed, cancelled, or gone).
      for (const jobId of [...wsConnectedRef.current]) {
        if (!activeIds.has(jobId)) {
          cleanupJobWs(jobId)
        }
      }
    }

    // Unsubscribe, forget, and close the socket for one job.
    function cleanupJobWs(jobId: string) {
      wsUnsubsRef.current.get(jobId)?.()
      wsUnsubsRef.current.delete(jobId)
      wsConnectedRef.current.delete(jobId)
      wsManager.disconnect(jobId)
    }

    // Poll on mount. fetchJobs only touches refs and state setters, so the
    // first-render closure stays valid for the component's lifetime.
    useEffect(() => {
      mountedRef.current = true
      fetchJobs()
      const interval = setInterval(fetchJobs, 5000)
      return () => {
        mountedRef.current = false
        clearInterval(interval)
        // Refs survive an unmount/remount cycle (e.g. React StrictMode). Stale IDs left in
        // wsConnectedRef would make syncWsConnections skip reconnecting those jobs.
        wsUnsubsRef.current.forEach(unsub => unsub())
        wsUnsubsRef.current.clear()
        wsConnectedRef.current.clear()
        wsManager.disconnectAll()
      }
    }, [])

    const handleCreate = () => {
      if (!modelId.trim() || !datasetId.trim()) return
      setSubmitting(true)
      setCreateError('')
      // Optimistic placeholder row. It is removed once the request settles; on success
      // the real job arrives through the fetchJobs() refresh.
      const tempId = 'temp-' + Date.now()
      const tempJob: TrainingJob = {
        id: tempId, model_id: modelId, model_type: modelType,
        peft_method: peftMethod, status: 'pending', progress: 0,
        loss: undefined, created_at: new Date().toISOString(),
        started_at: undefined, finished_at: undefined,
        error_message: undefined, adapter_path: undefined,
        current_epoch: 0, current_step: 0, total_steps: 0,
      }
      setJobs(prev => [tempJob, ...prev])
      api.training.create({
        model_id: modelId, model_type: modelType, dataset_id: datasetId,
        peft_method: peftMethod, task_type: taskType, dataset_template: template,
        epochs, batch_size: batchSize, gradient_accumulation: gradAcc,
        max_seq_length: seqLen, learning_rate: parseFloat(lr),
        lora_r: loraR, lora_alpha: loraR * 2, num_gpus: numGpus,
      })
        .then(() => {
          setModelId('')
          setDatasetId('')
          setJobs(prev => prev.filter(j => j.id !== tempId))
          fetchJobs()
          fetchOptions()
        })
        .catch(err => {
          setJobs(prev => prev.filter(j => j.id !== tempId))
          setCreateError(err instanceof Error ? err.message : '创建失败')
        })
        .finally(() => setSubmitting(false))
    }

    // Stable identity so that memo(JobRow) can actually skip re-rendering unchanged rows.
    // It depends only on fetchJobs, which uses only refs and state setters.
    const handleCancel = useCallback((id: string) => {
      api.training.cancel(id).then(() => fetchJobs()).catch(console.error)
    }, [])

    const modelOptions = models.map(m => ({
      value: m.id, label: m.id, subtitle: `${m.model_type}${m.is_downloaded ? ' ✓ 已下载' : ''}`,
    }))
    const datasetOptions = datasets.map(d => ({
      value: d.id, label: d.name, subtitle: `${d.format} · ${d.record_count} 条`,
    }))

    return (
      <div>
        <h1 style={{ margin: 0, fontSize: 22, fontWeight: 700 }}>训练任务</h1>
        <p style={{ color: '#888', fontSize: 13, margin: '4px 0 16px' }}>创建和管理模型微调任务</p>
        {/* Create form */}
        <div style={{
          background: '#fff', borderRadius: 12, padding: 24,
          boxShadow: '0 1px 3px rgba(0,0,0,0.06)', border: '1px solid rgba(0,0,0,0.04)',
        }}>
          <h2 style={{ margin: '0 0 20px', fontSize: 15, fontWeight: 600 }}>创建训练任务</h2>
          {/* Core configuration */}
          <div style={{
            fontSize: 13, fontWeight: 600, color: '#14b8a6', marginBottom: 12,
            paddingBottom: 6, borderBottom: '2px solid #ccfbf1',
          }}>核心配置</div>
          <div style={{ display: 'grid', gridTemplateColumns: 'repeat(3, 1fr)', gap: 16, marginBottom: 20 }}>
            <div>
              <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>基础模型</label>
              <SearchableSelect options={modelOptions} value={modelId} onChange={setModelId} placeholder="选择已下载的模型" loading={loadingOptions} />
            </div>
            <div>
              <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>模型类型</label>
              <Select options={MODEL_TYPES} value={modelType} onChange={v => setModelType(String(v))} />
            </div>
            <div>
              <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>训练数据集</label>
              <SearchableSelect options={datasetOptions} value={datasetId} onChange={setDatasetId} placeholder="选择数据集" loading={loadingOptions} />
            </div>
            <div>
              <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>训练方法</label>
              <Select options={TASK_TYPES} value={taskType} onChange={v => setTaskType(String(v))} />
            </div>
            <div>
              <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>数据模板</label>
              <Select options={DATASET_TEMPLATES} value={template} onChange={v => setTemplate(String(v))} />
            </div>
            <div>
              <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>PEFT 方法</label>
              <Select options={PEFT_METHODS} value={peftMethod} onChange={v => setPeftMethod(String(v))} />
            </div>
          </div>
          {/* Training hyper-parameters */}
          <div style={{
            fontSize: 13, fontWeight: 600, color: '#14b8a6', marginBottom: 12,
            paddingBottom: 6, borderBottom: '2px solid #ccfbf1',
          }}>训练超参数</div>
          <div style={{ display: 'grid', gridTemplateColumns: 'repeat(3, 1fr)', gap: 16, marginBottom: 20 }}>
            <div>
              <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>训练轮数 (Epochs)</label>
              <Select options={EPOCH_PRESETS} value={String(epochs)} onChange={v => setEpochs(Number(v))} />
            </div>
            <div>
              <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>批次大小 (Batch Size)</label>
              <Select options={BATCH_SIZE_PRESETS} value={String(batchSize)} onChange={v => setBatchSize(Number(v))} />
            </div>
            <div>
              <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>梯度累积</label>
              <Select options={GRAD_ACC_PRESETS} value={String(gradAcc)} onChange={v => setGradAcc(Number(v))} />
            </div>
            <div>
              <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>
                GPU 数量 {numGpus > 1 && <span style={{ color: '#2563eb', fontSize: 11 }}>(每卡 batch={batchSize})</span>}
              </label>
              <Select options={NUM_GPUS_PRESETS} value={String(numGpus)} onChange={v => setNumGpus(Number(v))} />
            </div>
            <div>
              <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>学习率</label>
              <Select options={LR_PRESETS} value={lr} onChange={v => setLr(String(v))} />
            </div>
            <div>
              <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>最大序列长度</label>
              <Select options={SEQ_LEN_PRESETS} value={String(seqLen)} onChange={v => setSeqLen(Number(v))} />
            </div>
            <div>
              <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>LoRA Rank (R)</label>
              <Select options={LORA_R_PRESETS} value={String(loraR)} onChange={v => setLoraR(Number(v))} />
            </div>
          </div>
          {/* Advanced options: DeepSpeed is not offered yet (MetaX GPU compatibility unverified); DDP is used instead */}
          {/* Error message */}
          {createError && (
            <div style={{
              marginBottom: 16, padding: 12, background: '#fff1f2', borderRadius: 8,
              fontSize: 13, color: '#e11d48', border: '1px solid #fecdd3',
            }}>
              {createError}
            </div>
          )}
          <button
            onClick={handleCreate}
            disabled={submitting || !modelId || !datasetId}
            style={{
              padding: '12px 36px', borderRadius: 8, border: 'none',
              background: '#14b8a6', color: '#fff', cursor: 'pointer',
              opacity: (submitting || !modelId || !datasetId) ? 0.5 : 1,
              fontSize: 14, fontWeight: 600, transition: 'all 0.2s ease',
            }}
          >
            {submitting ? '创建中...' : '启动训练'}
          </button>
        </div>
        {/* Job list */}
        <div style={{ marginTop: 24 }}>
          <div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 12 }}>
            <h2 style={{ margin: 0, fontSize: 15, fontWeight: 600 }}>任务列表</h2>
            <button onClick={fetchJobs} style={{
              padding: '6px 14px', borderRadius: 6, border: '1px solid #d0d0d0',
              background: '#fff', cursor: 'pointer', fontSize: 13, fontWeight: 500,
            }}
              onMouseEnter={e => { e.currentTarget.style.background = '#f5f5f5' }}
              onMouseLeave={e => { e.currentTarget.style.background = '#fff' }}
            >
              刷新
            </button>
          </div>
          {/* The spinner shows only while there is nothing to display. Background polls
              keep the table mounted instead of blanking it every 5 s. */}
          {loading && jobs.length === 0 && <p style={{ color: '#999', fontSize: 13 }}>加载中...</p>}
          {!loading && jobs.length === 0 && (
            <div style={{
              padding: 40, textAlign: 'center', color: '#94a3b8', fontSize: 14,
              background: '#fff', borderRadius: 10, boxShadow: '0 1px 3px rgba(0,0,0,0.06)',
            }}>
              <div style={{ marginBottom: 8 }}><Train size={32} color="#94a3b8" strokeWidth={1.5} /></div>
              暂无训练任务,请先创建训练任务
            </div>
          )}
          {jobs.length > 0 && (
            <div style={{
              background: '#fff', borderRadius: 10, overflow: 'hidden',
              boxShadow: '0 1px 3px rgba(0,0,0,0.06)', border: '1px solid rgba(0,0,0,0.04)',
            }}>
              <table style={{ width: '100%', borderCollapse: 'collapse', fontSize: 13 }}>
                <thead>
                  <tr style={{ background: '#f0fdfa', borderBottom: '2px solid #f1f5f9', textAlign: 'left' }}>
                    <th style={{ padding: '10px 12px', fontSize: 12, color: '#666', fontWeight: 600 }}>任务</th>
                    <th style={{ padding: '10px 12px', fontSize: 12, color: '#666', fontWeight: 600 }}>PEFT</th>
                    <th style={{ padding: '10px 12px', fontSize: 12, color: '#666', fontWeight: 600 }}>数据集</th>
                    <th style={{ padding: '10px 12px', fontSize: 12, color: '#666', fontWeight: 600 }}>状态</th>
                    <th style={{ padding: '10px 12px', fontSize: 12, color: '#666', fontWeight: 600 }}>进度</th>
                    <th style={{ padding: '10px 12px', fontSize: 12, color: '#666', fontWeight: 600 }}>Loss</th>
                    <th style={{ padding: '10px 12px', fontSize: 12, color: '#666', fontWeight: 600 }}>轮次</th>
                    <th style={{ padding: '10px 12px', fontSize: 12, color: '#666', fontWeight: 600 }}>操作</th>
                  </tr>
                </thead>
                <tbody>
                  {jobs.map(j => (
                    <JobRow key={j.id} j={j} onCancel={handleCancel} datasets={datasets} />
                  ))}
                </tbody>
              </table>
            </div>
          )}
        </div>
      </div>
    )
  }