Pārlūkot izejas kodu

支持模型训练查找模型按下拉按钮选择

lxylxy123321 1 nedēļu atpakaļ
vecāks
revīzija
4ac68efdbf
1 mainītis faili ar 200 papildinājumiem un 7 dzēšanām
  1. 200 7
      frontend/src/pages/Training.tsx

+ 200 - 7
frontend/src/pages/Training.tsx

@@ -1,5 +1,5 @@
-import { useState, useEffect } from 'react'
-import api, { TrainingJob } from '../api/client'
+import { useState, useEffect, useRef, useCallback } from 'react'
+import api, { TrainingJob, ModelInfo, DatasetInfo } from '../api/client'
 import { wsManager } from '../api/websocket'
 
 const MODEL_TYPES = [
@@ -31,6 +31,153 @@ const DATASET_TEMPLATES = [
   { value: 'raw', label: 'Raw (直接字段)' },
 ]
 
+// --- 可搜索下拉框组件 ---
+interface SearchableSelectProps {
+  options: { value: string; label: string; subtitle?: string }[]
+  value: string
+  onChange: (value: string) => void
+  placeholder: string
+  loading?: boolean
+}
+
+function SearchableSelect({ options, value, onChange, placeholder, loading }: SearchableSelectProps) {
+  const [open, setOpen] = useState(false)
+  const [filter, setFilter] = useState('')
+  const wrapperRef = useRef<HTMLDivElement>(null)
+  const inputRef = useRef<HTMLInputElement>(null)
+
+  // 点击外部关闭
+  useEffect(() => {
+    const handler = (e: MouseEvent) => {
+      if (wrapperRef.current && !wrapperRef.current.contains(e.target as Node)) {
+        setOpen(false)
+      }
+    }
+    if (open) document.addEventListener('mousedown', handler)
+    return () => document.removeEventListener('mousedown', handler)
+  }, [open])
+
+  // 打开时自动聚焦输入框
+  useEffect(() => {
+    if (open) {
+      setFilter('')
+      setTimeout(() => inputRef.current?.focus(), 0)
+    }
+  }, [open])
+
+  // 当前选中的 label
+  const selectedLabel = options.find(o => o.value === value)?.label ?? ''
+
+  const filtered = options.filter(o =>
+    o.label.toLowerCase().includes(filter.toLowerCase()) || o.value.toLowerCase().includes(filter.toLowerCase())
+  )
+
+  const handleKeyDown = (e: React.KeyboardEvent) => {
+    if (e.key === 'Enter' && filtered.length === 1) {
+      onChange(filtered[0].value)
+      setOpen(false)
+    } else if (e.key === 'Escape') {
+      setOpen(false)
+    }
+  }
+
+  return (
+    <div ref={wrapperRef} style={{ position: 'relative' }}>
+      {/* 显示框 */}
+      <div
+        onClick={() => setOpen(!open)}
+        style={{
+          padding: '6px 8px',
+          borderRadius: 4,
+          border: '1px solid #ccc',
+          cursor: 'pointer',
+          background: '#fff',
+          minHeight: 32,
+          display: 'flex',
+          alignItems: 'center',
+          justifyContent: 'space-between',
+          fontSize: 14,
+        }}
+      >
+        <span style={{ color: value ? '#333' : '#999' }}>
+          {value ? selectedLabel : placeholder}
+        </span>
+        <span style={{ color: '#999', fontSize: 12 }}>{open ? '▲' : '▼'}</span>
+      </div>
+
+      {/* 下拉列表 */}
+      {open && (
+        <div style={{
+          position: 'absolute',
+          top: '100%',
+          left: 0,
+          right: 0,
+          background: '#fff',
+          border: '1px solid #ccc',
+          borderRadius: 4,
+          boxShadow: '0 4px 12px rgba(0,0,0,0.15)',
+          zIndex: 1000,
+          marginTop: 2,
+          maxHeight: 240,
+          display: 'flex',
+          flexDirection: 'column',
+        }}>
+          {/* 搜索输入 */}
+          <input
+            ref={inputRef}
+            value={filter}
+            onChange={e => setFilter(e.target.value)}
+            onKeyDown={handleKeyDown}
+            placeholder="搜索..."
+            style={{
+              padding: '6px 8px',
+              border: 'none',
+              borderBottom: '1px solid #eee',
+              outline: 'none',
+              fontSize: 13,
+            }}
+          />
+          {/* 选项列表 */}
+          <div style={{ overflowY: 'auto', flex: 1 }}>
+            {loading && (
+              <div style={{ padding: '8px 12px', color: '#999', fontSize: 13 }}>加载中...</div>
+            )}
+            {!loading && filtered.length === 0 && (
+              <div style={{ padding: '8px 12px', color: '#999', fontSize: 13 }}>无匹配项</div>
+            )}
+            {!loading && filtered.map(opt => (
+              <div
+                key={opt.value}
+                onClick={() => { onChange(opt.value); setOpen(false) }}
+                style={{
+                  padding: '8px 12px',
+                  cursor: 'pointer',
+                  background: opt.value === value ? '#e94560' : 'transparent',
+                  color: opt.value === value ? '#fff' : '#333',
+                  fontSize: 13,
+                }}
+                onMouseEnter={e => {
+                  if (opt.value !== value) (e.currentTarget.style.background = '#f5f5f5')
+                }}
+                onMouseLeave={e => {
+                  if (opt.value !== value) (e.currentTarget.style.background = 'transparent')
+                }}
+              >
+                <div>{opt.label}</div>
+                {opt.subtitle && (
+                  <div style={{ fontSize: 11, color: opt.value === value ? 'rgba(255,255,255,0.7)' : '#999', marginTop: 2 }}>
+                    {opt.subtitle}
+                  </div>
+                )}
+              </div>
+            ))}
+          </div>
+        </div>
+      )}
+    </div>
+  )
+}
+
 export function Training() {
   const [modelId, setModelId] = useState('')
   const [modelType, setModelType] = useState('text')
@@ -48,6 +195,27 @@ export function Training() {
   const [loading, setLoading] = useState(false)
   const [submitting, setSubmitting] = useState(false)
 
+  // 模型和数据集列表
+  const [models, setModels] = useState<ModelInfo[]>([])
+  const [datasets, setDatasets] = useState<DatasetInfo[]>([])
+  const [loadingOptions, setLoadingOptions] = useState(true)
+
+  const fetchOptions = useCallback(() => {
+    setLoadingOptions(true)
+    Promise.all([
+      api.models.list().catch(() => []),
+      api.datasets.list().catch(() => []),
+    ]).then(([m, d]) => {
+      setModels(m)
+      setDatasets(d)
+    }).finally(() => setLoadingOptions(false))
+  }, [])
+
+  // 页面加载时获取选项
+  useEffect(() => {
+    fetchOptions()
+  }, [fetchOptions])
+
   // Connect WebSocket on mount
   useEffect(() => {
     wsManager.connect()
@@ -64,7 +232,6 @@ export function Training() {
 
   useEffect(() => {
     fetchJobs()
-    // 每 5 秒轮询一次更新状态
     const interval = setInterval(fetchJobs, 5000)
     return () => clearInterval(interval)
   }, [])
@@ -90,6 +257,7 @@ export function Training() {
         setModelId('')
         setDatasetId('')
         fetchJobs()
+        fetchOptions()
       })
       .catch(console.error)
       .finally(() => setSubmitting(false))
@@ -113,6 +281,19 @@ export function Training() {
     }
   }
 
+  // 构建下拉选项
+  const modelOptions = models.map(m => ({
+    value: m.id,
+    label: m.id,
+    subtitle: `${m.model_type}${m.is_downloaded ? ' ✓' : ''}`,
+  }))
+
+  const datasetOptions = datasets.map(d => ({
+    value: d.id,
+    label: d.name,
+    subtitle: `${d.format} · ${d.record_count} 条`,
+  }))
+
   return (
     <div>
       <h1>训练任务</h1>
@@ -122,8 +303,14 @@ export function Training() {
         <h2 style={{ margin: '0 0 16px', fontSize: 16 }}>创建训练任务</h2>
         <div style={{ display: 'grid', gridTemplateColumns: 'repeat(3, 1fr)', gap: 12 }}>
           <div>
-            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>模型 ID</label>
-            <input value={modelId} onChange={e => setModelId(e.target.value)} placeholder="模型 ID" style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }} />
+            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>模型</label>
+            <SearchableSelect
+              options={modelOptions}
+              value={modelId}
+              onChange={setModelId}
+              placeholder="选择模型"
+              loading={loadingOptions}
+            />
           </div>
           <div>
             <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>模型类型</label>
@@ -132,8 +319,14 @@ export function Training() {
             </select>
           </div>
           <div>
-            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>数据集 ID</label>
-            <input value={datasetId} onChange={e => setDatasetId(e.target.value)} placeholder="数据集 ID" style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }} />
+            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>数据集</label>
+            <SearchableSelect
+              options={datasetOptions}
+              value={datasetId}
+              onChange={setDatasetId}
+              placeholder="选择数据集"
+              loading={loadingOptions}
+            />
           </div>
           <div>
             <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>训练类型</label>