|
@@ -0,0 +1,160 @@
|
|
|
|
|
+import { useState, useEffect } from 'react'
|
|
|
|
|
+import api, { AdapterInfo } from '../api/client'
|
|
|
|
|
+
|
|
|
|
|
+export function Inference() {
|
|
|
|
|
+ const [adapters, setAdapters] = useState<AdapterInfo[]>([])
|
|
|
|
|
+ const [adapterId, setAdapterId] = useState('')
|
|
|
|
|
+ const [prompt, setPrompt] = useState('')
|
|
|
|
|
+ const [maxTokens, setMaxTokens] = useState(256)
|
|
|
|
|
+ const [temperature, setTemperature] = useState(0.8)
|
|
|
|
|
+ const [topP, setTopP] = useState(0.95)
|
|
|
|
|
+ const [repetitionPenalty, setRepetitionPenalty] = useState(1.1)
|
|
|
|
|
+ const [doSample, setDoSample] = useState(true)
|
|
|
|
|
+ const [generating, setGenerating] = useState(false)
|
|
|
|
|
+ const [result, setResult] = useState<{ generated_text: string; tokens_generated: number } | null>(null)
|
|
|
|
|
+ const [error, setError] = useState('')
|
|
|
|
|
+ const [viewMode, setViewMode] = useState<'full' | 'new'>('new')
|
|
|
|
|
+
|
|
|
|
|
+ useEffect(() => {
|
|
|
|
|
+ api.inference.adapters()
|
|
|
|
|
+ .then(setAdapters)
|
|
|
|
|
+ .catch(() => setAdapters([]))
|
|
|
|
|
+ }, [])
|
|
|
|
|
+
|
|
|
|
|
+ useEffect(() => {
|
|
|
|
|
+ if (adapters.length > 0 && !adapterId) {
|
|
|
|
|
+ setAdapterId(adapters[0].id)
|
|
|
|
|
+ }
|
|
|
|
|
+ }, [adapters])
|
|
|
|
|
+
|
|
|
|
|
+ const handleGenerate = () => {
|
|
|
|
|
+ if (!adapterId.trim() || !prompt.trim()) return
|
|
|
|
|
+ setGenerating(true)
|
|
|
|
|
+ setError('')
|
|
|
|
|
+ setResult(null)
|
|
|
|
|
+ api.inference.generate({
|
|
|
|
|
+ adapter_id: adapterId,
|
|
|
|
|
+ prompt,
|
|
|
|
|
+ max_new_tokens: maxTokens,
|
|
|
|
|
+ temperature,
|
|
|
|
|
+ top_p: topP,
|
|
|
|
|
+ repetition_penalty: repetitionPenalty,
|
|
|
|
|
+ do_sample: doSample,
|
|
|
|
|
+ })
|
|
|
|
|
+ .then(res => {
|
|
|
|
|
+ if (res.error) {
|
|
|
|
|
+ setError(res.error)
|
|
|
|
|
+ } else {
|
|
|
|
|
+ setResult({ generated_text: res.generated_text, tokens_generated: res.tokens_generated })
|
|
|
|
|
+ }
|
|
|
|
|
+ })
|
|
|
|
|
+ .catch(err => setError(err.message))
|
|
|
|
|
+ .finally(() => setGenerating(false))
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ return (
|
|
|
|
|
+ <div>
|
|
|
|
|
+ <h1>模型推理</h1>
|
|
|
|
|
+
|
|
|
|
|
+ {/* Adapter selector */}
|
|
|
|
|
+ <div style={{ marginTop: 16, background: '#fff', borderRadius: 8, padding: 20, boxShadow: '0 1px 3px rgba(0,0,0,0.1)' }}>
|
|
|
|
|
+ <h2 style={{ margin: '0 0 16px', fontSize: 16 }}>选择 Adapter</h2>
|
|
|
|
|
+ {adapters.length === 0 ? (
|
|
|
|
|
+ <p style={{ color: '#999', fontSize: 14 }}>暂无可用的 adapter,请先完成训练任务</p>
|
|
|
|
|
+ ) : (
|
|
|
|
|
+ <select
|
|
|
|
|
+ value={adapterId}
|
|
|
|
|
+ onChange={e => setAdapterId(e.target.value)}
|
|
|
|
|
+ style={{ padding: '6px 12px', borderRadius: 4, border: '1px solid #ccc', width: '100%', maxWidth: 500 }}
|
|
|
|
|
+ >
|
|
|
|
|
+ {adapters.map(a => (
|
|
|
|
|
+ <option key={a.id} value={a.id}>{a.id} (base: {a.base_model})</option>
|
|
|
|
|
+ ))}
|
|
|
|
|
+ </select>
|
|
|
|
|
+ )}
|
|
|
|
|
+ </div>
|
|
|
|
|
+
|
|
|
|
|
+ {/* Prompt input */}
|
|
|
|
|
+ <div style={{ marginTop: 16, background: '#fff', borderRadius: 8, padding: 20, boxShadow: '0 1px 3px rgba(0,0,0,0.1)' }}>
|
|
|
|
|
+ <h2 style={{ margin: '0 0 16px', fontSize: 16 }}>输入提示词</h2>
|
|
|
|
|
+ <textarea
|
|
|
|
|
+ value={prompt}
|
|
|
|
|
+ onChange={e => setPrompt(e.target.value)}
|
|
|
|
|
+ placeholder="输入你的问题或指令..."
|
|
|
|
|
+ rows={4}
|
|
|
|
|
+ style={{ width: '100%', padding: 12, borderRadius: 4, border: '1px solid #ccc', fontSize: 14, boxSizing: 'border-box', resize: 'vertical' }}
|
|
|
|
|
+ />
|
|
|
|
|
+
|
|
|
|
|
+ {/* Generation params */}
|
|
|
|
|
+ <div style={{ marginTop: 12, display: 'grid', gridTemplateColumns: 'repeat(4, 1fr)', gap: 12 }}>
|
|
|
|
|
+ <div>
|
|
|
|
|
+ <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>Max Tokens</label>
|
|
|
|
|
+ <input type="number" value={maxTokens} onChange={e => setMaxTokens(Number(e.target.value))} min={1} max={4096} style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }} />
|
|
|
|
|
+ </div>
|
|
|
|
|
+ <div>
|
|
|
|
|
+ <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>Temperature</label>
|
|
|
|
|
+ <input type="number" value={temperature} onChange={e => setTemperature(Number(e.target.value))} min={0} max={2} step={0.1} style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }} />
|
|
|
|
|
+ </div>
|
|
|
|
|
+ <div>
|
|
|
|
|
+ <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>Top P</label>
|
|
|
|
|
+ <input type="number" value={topP} onChange={e => setTopP(Number(e.target.value))} min={0} max={1} step={0.05} style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }} />
|
|
|
|
|
+ </div>
|
|
|
|
|
+ <div>
|
|
|
|
|
+ <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>Repetition Penalty</label>
|
|
|
|
|
+ <input type="number" value={repetitionPenalty} onChange={e => setRepetitionPenalty(Number(e.target.value))} min={1} max={2} step={0.1} style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }} />
|
|
|
|
|
+ </div>
|
|
|
|
|
+ </div>
|
|
|
|
|
+
|
|
|
|
|
+ <label style={{ display: 'flex', alignItems: 'center', gap: 6, fontSize: 13, cursor: 'pointer', marginTop: 12 }}>
|
|
|
|
|
+ <input type="checkbox" checked={doSample} onChange={e => setDoSample(e.target.checked)} />
|
|
|
|
|
+ 启用采样 (关闭则为 greedy decoding)
|
|
|
|
|
+ </label>
|
|
|
|
|
+
|
|
|
|
|
+ <button
|
|
|
|
|
+ onClick={handleGenerate}
|
|
|
|
|
+ disabled={generating || !adapterId}
|
|
|
|
|
+ style={{ marginTop: 16, padding: '8px 24px', borderRadius: 4, border: 'none', background: '#e94560', color: '#fff', cursor: 'pointer', opacity: generating || !adapterId ? 0.6 : 1 }}
|
|
|
|
|
+ >
|
|
|
|
|
+ {generating ? '生成中...' : '生成'}
|
|
|
|
|
+ </button>
|
|
|
|
|
+ </div>
|
|
|
|
|
+
|
|
|
|
|
+ {/* Error */}
|
|
|
|
|
+ {error && (
|
|
|
|
|
+ <div style={{ marginTop: 16, padding: 16, background: '#ffebee', borderRadius: 8, color: '#c62828' }}>
|
|
|
|
|
+ <strong>错误:</strong> {error}
|
|
|
|
|
+ </div>
|
|
|
|
|
+ )}
|
|
|
|
|
+
|
|
|
|
|
+ {/* Result */}
|
|
|
|
|
+ {result && (
|
|
|
|
|
+ <div style={{ marginTop: 16, background: '#fff', borderRadius: 8, padding: 20, boxShadow: '0 1px 3px rgba(0,0,0,0.1)' }}>
|
|
|
|
|
+ <div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 12 }}>
|
|
|
|
|
+ <h2 style={{ margin: 0, fontSize: 16 }}>生成结果</h2>
|
|
|
|
|
+ <span style={{ fontSize: 12, color: '#999' }}>{result.tokens_generated} tokens</span>
|
|
|
|
|
+ </div>
|
|
|
|
|
+
|
|
|
|
|
+ {/* View mode toggle */}
|
|
|
|
|
+ <div style={{ marginBottom: 12 }}>
|
|
|
|
|
+ <button
|
|
|
|
|
+ onClick={() => setViewMode('full')}
|
|
|
|
|
+ style={{ padding: '4px 12px', borderRadius: 4, border: `1px solid ${viewMode === 'full' ? '#e94560' : '#ccc'}`, background: viewMode === 'full' ? '#e94560' : '#fff', color: viewMode === 'full' ? '#fff' : '#333', cursor: 'pointer', marginRight: 8, fontSize: 13 }}
|
|
|
|
|
+ >
|
|
|
|
|
+ 完整输出
|
|
|
|
|
+ </button>
|
|
|
|
|
+ <button
|
|
|
|
|
+ onClick={() => setViewMode('new')}
|
|
|
|
|
+ style={{ padding: '4px 12px', borderRadius: 4, border: `1px solid ${viewMode === 'new' ? '#e94560' : '#ccc'}`, background: viewMode === 'new' ? '#e94560' : '#fff', color: viewMode === 'new' ? '#fff' : '#333', cursor: 'pointer', fontSize: 13 }}
|
|
|
|
|
+ >
|
|
|
|
|
+ 仅新生成部分
|
|
|
|
|
+ </button>
|
|
|
|
|
+ </div>
|
|
|
|
|
+
|
|
|
|
|
+ <pre style={{ whiteSpace: 'pre-wrap', wordBreak: 'break-word', background: '#f5f5f5', padding: 16, borderRadius: 4, fontSize: 14, lineHeight: 1.6, maxHeight: 400, overflow: 'auto' }}>
|
|
|
|
|
+ {viewMode === 'full' ? result.generated_text : result.generated_text.replace(prompt, '').trim()}
|
|
|
|
|
+ </pre>
|
|
|
|
|
+ </div>
|
|
|
|
|
+ )}
|
|
|
|
|
+ </div>
|
|
|
|
|
+ )
|
|
|
|
|
+}
|