|
|
@@ -1,5 +1,5 @@
|
|
|
-import { useState, useEffect } from 'react'
|
|
|
-import api, { TrainingJob } from '../api/client'
|
|
|
+import { useState, useEffect, useRef, useCallback } from 'react'
|
|
|
+import api, { TrainingJob, ModelInfo, DatasetInfo } from '../api/client'
|
|
|
import { wsManager } from '../api/websocket'
|
|
|
|
|
|
const MODEL_TYPES = [
|
|
|
@@ -31,6 +31,153 @@ const DATASET_TEMPLATES = [
|
|
|
{ value: 'raw', label: 'Raw (直接字段)' },
|
|
|
]
|
|
|
|
|
|
+// --- 可搜索下拉框组件 ---
|
|
|
+interface SearchableSelectProps {
|
|
|
+ options: { value: string; label: string; subtitle?: string }[]
|
|
|
+ value: string
|
|
|
+ onChange: (value: string) => void
|
|
|
+ placeholder: string
|
|
|
+ loading?: boolean
|
|
|
+}
|
|
|
+
|
|
|
+function SearchableSelect({ options, value, onChange, placeholder, loading }: SearchableSelectProps) {
|
|
|
+ const [open, setOpen] = useState(false)
|
|
|
+ const [filter, setFilter] = useState('')
|
|
|
+ const wrapperRef = useRef<HTMLDivElement>(null)
|
|
|
+ const inputRef = useRef<HTMLInputElement>(null)
|
|
|
+
|
|
|
+ // 点击外部关闭
|
|
|
+ useEffect(() => {
|
|
|
+ const handler = (e: MouseEvent) => {
|
|
|
+ if (wrapperRef.current && !wrapperRef.current.contains(e.target as Node)) {
|
|
|
+ setOpen(false)
|
|
|
+ }
|
|
|
+ }
|
|
|
+ if (open) document.addEventListener('mousedown', handler)
|
|
|
+ return () => document.removeEventListener('mousedown', handler)
|
|
|
+ }, [open])
|
|
|
+
|
|
|
+ // 打开时自动聚焦输入框
|
|
|
+ useEffect(() => {
|
|
|
+ if (open) {
|
|
|
+ setFilter('')
|
|
|
+ setTimeout(() => inputRef.current?.focus(), 0)
|
|
|
+ }
|
|
|
+ }, [open])
|
|
|
+
|
|
|
+ // 当前选中的 label
|
|
|
+ const selectedLabel = options.find(o => o.value === value)?.label ?? ''
|
|
|
+
|
|
|
+ const filtered = options.filter(o =>
|
|
|
+ o.label.toLowerCase().includes(filter.toLowerCase()) || o.value.toLowerCase().includes(filter.toLowerCase())
|
|
|
+ )
|
|
|
+
|
|
|
+ const handleKeyDown = (e: React.KeyboardEvent) => {
|
|
|
+ if (e.key === 'Enter' && filtered.length === 1) {
|
|
|
+ onChange(filtered[0].value)
|
|
|
+ setOpen(false)
|
|
|
+ } else if (e.key === 'Escape') {
|
|
|
+ setOpen(false)
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ return (
|
|
|
+ <div ref={wrapperRef} style={{ position: 'relative' }}>
|
|
|
+ {/* 显示框 */}
|
|
|
+ <div
|
|
|
+ onClick={() => setOpen(!open)}
|
|
|
+ style={{
|
|
|
+ padding: '6px 8px',
|
|
|
+ borderRadius: 4,
|
|
|
+ border: '1px solid #ccc',
|
|
|
+ cursor: 'pointer',
|
|
|
+ background: '#fff',
|
|
|
+ minHeight: 32,
|
|
|
+ display: 'flex',
|
|
|
+ alignItems: 'center',
|
|
|
+ justifyContent: 'space-between',
|
|
|
+ fontSize: 14,
|
|
|
+ }}
|
|
|
+ >
|
|
|
+ <span style={{ color: value ? '#333' : '#999' }}>
|
|
|
+ {value ? selectedLabel : placeholder}
|
|
|
+ </span>
|
|
|
+ <span style={{ color: '#999', fontSize: 12 }}>{open ? '▲' : '▼'}</span>
|
|
|
+ </div>
|
|
|
+
|
|
|
+ {/* 下拉列表 */}
|
|
|
+ {open && (
|
|
|
+ <div style={{
|
|
|
+ position: 'absolute',
|
|
|
+ top: '100%',
|
|
|
+ left: 0,
|
|
|
+ right: 0,
|
|
|
+ background: '#fff',
|
|
|
+ border: '1px solid #ccc',
|
|
|
+ borderRadius: 4,
|
|
|
+ boxShadow: '0 4px 12px rgba(0,0,0,0.15)',
|
|
|
+ zIndex: 1000,
|
|
|
+ marginTop: 2,
|
|
|
+ maxHeight: 240,
|
|
|
+ display: 'flex',
|
|
|
+ flexDirection: 'column',
|
|
|
+ }}>
|
|
|
+ {/* 搜索输入 */}
|
|
|
+ <input
|
|
|
+ ref={inputRef}
|
|
|
+ value={filter}
|
|
|
+ onChange={e => setFilter(e.target.value)}
|
|
|
+ onKeyDown={handleKeyDown}
|
|
|
+ placeholder="搜索..."
|
|
|
+ style={{
|
|
|
+ padding: '6px 8px',
|
|
|
+ border: 'none',
|
|
|
+ borderBottom: '1px solid #eee',
|
|
|
+ outline: 'none',
|
|
|
+ fontSize: 13,
|
|
|
+ }}
|
|
|
+ />
|
|
|
+ {/* 选项列表 */}
|
|
|
+ <div style={{ overflowY: 'auto', flex: 1 }}>
|
|
|
+ {loading && (
|
|
|
+ <div style={{ padding: '8px 12px', color: '#999', fontSize: 13 }}>加载中...</div>
|
|
|
+ )}
|
|
|
+ {!loading && filtered.length === 0 && (
|
|
|
+ <div style={{ padding: '8px 12px', color: '#999', fontSize: 13 }}>无匹配项</div>
|
|
|
+ )}
|
|
|
+ {!loading && filtered.map(opt => (
|
|
|
+ <div
|
|
|
+ key={opt.value}
|
|
|
+ onClick={() => { onChange(opt.value); setOpen(false) }}
|
|
|
+ style={{
|
|
|
+ padding: '8px 12px',
|
|
|
+ cursor: 'pointer',
|
|
|
+ background: opt.value === value ? '#e94560' : 'transparent',
|
|
|
+ color: opt.value === value ? '#fff' : '#333',
|
|
|
+ fontSize: 13,
|
|
|
+ }}
|
|
|
+ onMouseEnter={e => {
|
|
|
+ if (opt.value !== value) (e.currentTarget.style.background = '#f5f5f5')
|
|
|
+ }}
|
|
|
+ onMouseLeave={e => {
|
|
|
+ if (opt.value !== value) (e.currentTarget.style.background = 'transparent')
|
|
|
+ }}
|
|
|
+ >
|
|
|
+ <div>{opt.label}</div>
|
|
|
+ {opt.subtitle && (
|
|
|
+ <div style={{ fontSize: 11, color: opt.value === value ? 'rgba(255,255,255,0.7)' : '#999', marginTop: 2 }}>
|
|
|
+ {opt.subtitle}
|
|
|
+ </div>
|
|
|
+ )}
|
|
|
+ </div>
|
|
|
+ ))}
|
|
|
+ </div>
|
|
|
+ </div>
|
|
|
+ )}
|
|
|
+ </div>
|
|
|
+ )
|
|
|
+}
|
|
|
+
|
|
|
export function Training() {
|
|
|
const [modelId, setModelId] = useState('')
|
|
|
const [modelType, setModelType] = useState('text')
|
|
|
@@ -48,6 +195,27 @@ export function Training() {
|
|
|
const [loading, setLoading] = useState(false)
|
|
|
const [submitting, setSubmitting] = useState(false)
|
|
|
|
|
|
+ // 模型和数据集列表
|
|
|
+ const [models, setModels] = useState<ModelInfo[]>([])
|
|
|
+ const [datasets, setDatasets] = useState<DatasetInfo[]>([])
|
|
|
+ const [loadingOptions, setLoadingOptions] = useState(true)
|
|
|
+
|
|
|
+ const fetchOptions = useCallback(() => {
|
|
|
+ setLoadingOptions(true)
|
|
|
+ Promise.all([
|
|
|
+ api.models.list().catch(() => []),
|
|
|
+ api.datasets.list().catch(() => []),
|
|
|
+ ]).then(([m, d]) => {
|
|
|
+ setModels(m)
|
|
|
+ setDatasets(d)
|
|
|
+ }).finally(() => setLoadingOptions(false))
|
|
|
+ }, [])
|
|
|
+
|
|
|
+ // 页面加载时获取选项
|
|
|
+ useEffect(() => {
|
|
|
+ fetchOptions()
|
|
|
+ }, [fetchOptions])
|
|
|
+
|
|
|
// Connect WebSocket on mount
|
|
|
useEffect(() => {
|
|
|
wsManager.connect()
|
|
|
@@ -64,7 +232,6 @@ export function Training() {
|
|
|
|
|
|
useEffect(() => {
|
|
|
fetchJobs()
|
|
|
- // 每 5 秒轮询一次更新状态
|
|
|
const interval = setInterval(fetchJobs, 5000)
|
|
|
return () => clearInterval(interval)
|
|
|
}, [])
|
|
|
@@ -90,6 +257,7 @@ export function Training() {
|
|
|
setModelId('')
|
|
|
setDatasetId('')
|
|
|
fetchJobs()
|
|
|
+ fetchOptions()
|
|
|
})
|
|
|
.catch(console.error)
|
|
|
.finally(() => setSubmitting(false))
|
|
|
@@ -113,6 +281,19 @@ export function Training() {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ // 构建下拉选项
|
|
|
+ const modelOptions = models.map(m => ({
|
|
|
+ value: m.id,
|
|
|
+ label: m.id,
|
|
|
+ subtitle: `${m.model_type}${m.is_downloaded ? ' ✓' : ''}`,
|
|
|
+ }))
|
|
|
+
|
|
|
+ const datasetOptions = datasets.map(d => ({
|
|
|
+ value: d.id,
|
|
|
+ label: d.name,
|
|
|
+ subtitle: `${d.format} · ${d.record_count} 条`,
|
|
|
+ }))
|
|
|
+
|
|
|
return (
|
|
|
<div>
|
|
|
<h1>训练任务</h1>
|
|
|
@@ -122,8 +303,14 @@ export function Training() {
|
|
|
<h2 style={{ margin: '0 0 16px', fontSize: 16 }}>创建训练任务</h2>
|
|
|
<div style={{ display: 'grid', gridTemplateColumns: 'repeat(3, 1fr)', gap: 12 }}>
|
|
|
<div>
|
|
|
- <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>模型 ID</label>
|
|
|
- <input value={modelId} onChange={e => setModelId(e.target.value)} placeholder="模型 ID" style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }} />
|
|
|
+ <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>模型</label>
|
|
|
+ <SearchableSelect
|
|
|
+ options={modelOptions}
|
|
|
+ value={modelId}
|
|
|
+ onChange={setModelId}
|
|
|
+ placeholder="选择模型"
|
|
|
+ loading={loadingOptions}
|
|
|
+ />
|
|
|
</div>
|
|
|
<div>
|
|
|
<label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>模型类型</label>
|
|
|
@@ -132,8 +319,14 @@ export function Training() {
|
|
|
</select>
|
|
|
</div>
|
|
|
<div>
|
|
|
- <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>数据集 ID</label>
|
|
|
- <input value={datasetId} onChange={e => setDatasetId(e.target.value)} placeholder="数据集 ID" style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }} />
|
|
|
+ <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>数据集</label>
|
|
|
+ <SearchableSelect
|
|
|
+ options={datasetOptions}
|
|
|
+ value={datasetId}
|
|
|
+ onChange={setDatasetId}
|
|
|
+ placeholder="选择数据集"
|
|
|
+ loading={loadingOptions}
|
|
|
+ />
|
|
|
</div>
|
|
|
<div>
|
|
|
<label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>训练类型</label>
|