Prechádzať zdrojové kódy

增加模型测试功能

lxylxy123321 2 týždňov pred
rodič
commit
3b8acd9769

+ 13 - 1
backend/app/api/models.py

@@ -1,7 +1,8 @@
 from fastapi import APIRouter, HTTPException
 
 from app.schemas.model import ModelDownloadRequest, ModelDownloadResponse, ModelInfo
-from app.services import model_service
+from app.schemas.model_test import ModelTestRequest, ModelTestResponse
+from app.services import model_service, model_test_service
 
 router = APIRouter()
 
@@ -62,3 +63,14 @@ async def delete_model(model_id: str):
     """删除已缓存的模型(数据库记录 + 本地文件)。"""
     result = await model_service.delete_model(model_id)
     return result
+
+
+@router.post("/test", response_model=ModelTestResponse)
+async def test_model(req: ModelTestRequest):
+    """快速测试已缓存模型的生成能力。"""
+    result = await model_test_service.test_model(
+        req.model_id, req.prompt, req.max_new_tokens, req.temperature, req.top_p,
+    )
+    if "error" in result:
+        raise HTTPException(status_code=400, detail=result["error"])
+    return ModelTestResponse(**result)

+ 16 - 0
backend/app/schemas/model_test.py

@@ -0,0 +1,16 @@
+from pydantic import BaseModel, Field
+
+
+class ModelTestRequest(BaseModel):
+    model_id: str = Field(..., description="已缓存模型的 ID")
+    prompt: str
+    max_new_tokens: int = 128
+    temperature: float = 0.8
+    top_p: float = 0.95
+
+
+class ModelTestResponse(BaseModel):
+    model_id: str
+    prompt: str
+    generated_text: str
+    error: str | None = None

+ 55 - 0
backend/app/services/model_test_service.py

@@ -0,0 +1,55 @@
+import json
+from pathlib import Path
+from typing import Any
+
+from app.config import get_settings
+from app.core.logging import logger
+
+settings = get_settings()
+
+
+async def test_model(model_id: str, prompt: str, max_new_tokens: int = 128, temperature: float = 0.8, top_p: float = 0.95) -> dict[str, Any]:
+    """加载已缓存模型并生成测试响应。"""
+    try:
+        import torch
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        # 查找模型路径
+        model_dir = settings.models_dir / model_id.replace("/", "_")
+        if not (model_dir / "config.json").exists():
+            return {"error": f"Model directory not found: {model_dir}"}
+
+        tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
+        if tokenizer.pad_token is None:
+            tokenizer.pad_token = tokenizer.eos_token
+
+        model = AutoModelForCausalLM.from_pretrained(
+            model_dir,
+            torch_dtype=torch.float16,
+            device_map="auto",
+        )
+        model.eval()
+
+        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+        with torch.no_grad():
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=max_new_tokens,
+                temperature=temperature,
+                top_p=top_p,
+                do_sample=temperature > 0,
+                pad_token_id=tokenizer.eos_token_id,
+            )
+
+        generated_text = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
+
+        return {
+            "model_id": model_id,
+            "prompt": prompt,
+            "generated_text": generated_text,
+        }
+
+    except Exception as e:
+        logger.error(f"Model test failed: {e}")
+        return {"error": str(e)}

+ 22 - 1
frontend/src/api/client.ts

@@ -30,6 +30,12 @@ const api = {
       apiFetch(`/api/v1/models/${encodeURIComponent(modelId)}`, { method: 'DELETE' }).then(r => r.json()),
     getInfo: (modelId: string) =>
       apiFetch(`/api/v1/models/${encodeURIComponent(modelId)}`).then(r => r.json()) as Promise<ModelInfo>,
+    test: (req: ModelTestRequest) =>
+      apiFetch('/api/v1/models/test', {
+        method: 'POST',
+        headers: { 'Content-Type': 'application/json' },
+        body: JSON.stringify(req),
+      }).then(r => r.json()) as Promise<ModelTestResponse>,
   },
 
   // --- Datasets ---
@@ -120,6 +126,21 @@ interface ModelInfo {
   supported_peft_methods: string[]
 }
 
+interface ModelTestRequest {
+  model_id: string
+  prompt: string
+  max_new_tokens?: number
+  temperature?: number
+  top_p?: number
+}
+
+interface ModelTestResponse {
+  model_id: string
+  prompt: string
+  generated_text: string
+  error?: string
+}
+
 interface ModelDownloadResponse {
   model_id: string
   status: string
@@ -249,4 +270,4 @@ interface InferenceResponse {
   error?: string
 }
 
-export type { ModelInfo, ModelDownloadResponse, DatasetInfo, DatasetDownloadResponse, DatasetPreview, DatasetValidation, TrainingJob, TrainingConfig, EvalConfig, EvalResult, DeployConfig, DeployResponse, AdapterInfo, InferenceRequest, InferenceResponse }
+export type { ModelInfo, ModelTestRequest, ModelTestResponse, ModelDownloadResponse, DatasetInfo, DatasetDownloadResponse, DatasetPreview, DatasetValidation, TrainingJob, TrainingConfig, EvalConfig, EvalResult, DeployConfig, DeployResponse, AdapterInfo, InferenceRequest, InferenceResponse }

+ 94 - 0
frontend/src/pages/Models.tsx

@@ -9,6 +9,13 @@ export function Models() {
   const [loading, setLoading] = useState(false)
   const [statusMsg, setStatusMsg] = useState('')
 
+  // Test state
+  const [testModelId, setTestModelId] = useState('')
+  const [testPrompt, setTestPrompt] = useState('')
+  const [testResult, setTestResult] = useState('')
+  const [testError, setTestError] = useState('')
+  const [testing, setTesting] = useState(false)
+
   const fetchModels = () => {
     setLoading(true)
     api.models.list()
@@ -38,6 +45,37 @@ export function Models() {
     }
   }
 
+  const handleTest = async (id: string) => {
+    setTestModelId(id)
+    setTestPrompt('')
+    setTestResult('')
+    setTestError('')
+    // Show test panel
+    setTestPrompt('你好,请简单介绍一下自己。')
+  }
+
+  const handleTestSubmit = async () => {
+    if (!testModelId.trim() || !testPrompt.trim()) return
+    setTesting(true)
+    setTestResult('')
+    setTestError('')
+    try {
+      const res = await api.models.test({
+        model_id: testModelId,
+        prompt: testPrompt,
+        max_new_tokens: 128,
+        temperature: 0.8,
+        top_p: 0.95,
+      })
+      setTestResult(res.generated_text)
+    } catch (err) {
+      const msg = err instanceof Error ? err.message : '测试失败'
+      setTestError(msg)
+    } finally {
+      setTesting(false)
+    }
+  }
+
   return (
     <div>
       <h1>模型注册</h1>
@@ -104,6 +142,9 @@ export function Models() {
                   </td>
                   <td>{m.supported_peft_methods.join(', ') || '-'}</td>
                   <td>
+                    {m.is_downloaded && (
+                      <button onClick={() => handleTest(m.id)} style={{ marginRight: 8, padding: '2px 8px', color: '#2196f3', border: '1px solid #2196f3', borderRadius: 4, background: 'transparent', cursor: 'pointer' }}>测试</button>
+                    )}
                     <button onClick={() => handleDelete(m.id, m.name)} style={{ padding: '2px 8px', color: '#e94560', border: '1px solid #e94560', borderRadius: 4, background: 'transparent', cursor: 'pointer' }}>删除</button>
                   </td>
                 </tr>
@@ -112,6 +153,59 @@ export function Models() {
           </table>
         )}
       </div>
+
+      {/* Test Panel */}
+      {testModelId && (
+        <div style={{ marginTop: 24, background: '#fff', borderRadius: 8, padding: 20, boxShadow: '0 1px 3px rgba(0,0,0,0.1)' }}>
+          <div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 16 }}>
+            <h2 style={{ margin: 0, fontSize: 16 }}>模型测试 — {testModelId}</h2>
+            <button onClick={() => { setTestModelId(''); setTestResult(''); setTestError(''); setTestPrompt('') }} style={{ padding: '4px 12px', borderRadius: 4, border: '1px solid #ccc', background: '#fff', cursor: 'pointer' }}>关闭</button>
+          </div>
+
+          {/* Chat-like input */}
+          <div style={{ display: 'flex', gap: 8 }}>
+            <input
+              value={testPrompt}
+              onChange={e => setTestPrompt(e.target.value)}
+              onKeyDown={e => { if (e.key === 'Enter') handleTestSubmit() }}
+              placeholder="输入提示词,按 Enter 发送..."
+              style={{ flex: 1, padding: '10px 12px', borderRadius: 4, border: '1px solid #ccc', fontSize: 14 }}
+            />
+            <button
+              onClick={handleTestSubmit}
+              disabled={testing}
+              style={{ padding: '10px 20px', borderRadius: 4, border: 'none', background: '#2196f3', color: '#fff', cursor: 'pointer', opacity: testing ? 0.6 : 1, whiteSpace: 'nowrap' }}
+            >
+              {testing ? '生成中...' : '发送'}
+            </button>
+          </div>
+
+          {/* Error */}
+          {testError && (
+            <div style={{ marginTop: 12, padding: 12, background: '#ffebee', borderRadius: 4, color: '#c62828', fontSize: 13 }}>
+              ❌ {testError}
+            </div>
+          )}
+
+          {/* Result */}
+          {testResult && (
+            <div style={{ marginTop: 12 }}>
+              <div style={{ display: 'flex', gap: 8, marginBottom: 8 }}>
+                <span style={{ fontSize: 12, color: '#999', background: '#f5f5f5', padding: '2px 8px', borderRadius: 4 }}>Prompt</span>
+              </div>
+              <div style={{ padding: 12, background: '#f0f7ff', borderRadius: 4, fontSize: 14, lineHeight: 1.6, marginBottom: 12 }}>
+                {testPrompt}
+              </div>
+              <div style={{ display: 'flex', gap: 8, marginBottom: 8 }}>
+                <span style={{ fontSize: 12, color: '#999', background: '#f5f5f5', padding: '2px 8px', borderRadius: 4 }}>Response</span>
+              </div>
+              <div style={{ padding: 12, background: '#f0fff0', borderRadius: 4, fontSize: 14, lineHeight: 1.6, whiteSpace: 'pre-wrap', wordBreak: 'break-word' }}>
+                {testResult}
+              </div>
+            </div>
+          )}
+        </div>
+      )}
     </div>
   )
 }