Kaynağa Gözat

增加推理功能,支持modelscope下载

lxylxy123321 2 hafta önce
ebeveyn
işleme
be25d3fc33

+ 55 - 0
backend/app/api/inference.py

@@ -0,0 +1,55 @@
+from pydantic import BaseModel, Field
+
+from app.services import inference_service
+
+from fastapi import APIRouter
+
+router = APIRouter()
+
+
+# Request body accepted by the POST /generate endpoint in this module.
+class GenerateRequest(BaseModel):
+    # Required. Name of the adapter directory; the endpoint joins it onto
+    # settings.adapters_dir to locate the adapter on disk.
+    adapter_id: str = Field(..., description="训练任务 ID(adapter 目录名)")
+    # Text passed to the inference service as the generation prompt.
+    prompt: str
+    # The remaining fields are forwarded unchanged to inference_service.generate.
+    max_new_tokens: int = 256
+    temperature: float = 0.8
+    top_p: float = 0.95
+    repetition_penalty: float = 1.1
+    do_sample: bool = True
+
+
+# Response body of POST /generate. The endpoint fills each field from the
+# dict returned by inference_service.generate, falling back to the defaults
+# noted below when a key is absent.
+class GenerateResponse(BaseModel):
+    # Prompt reported by the service, or the request prompt if it returned none.
+    prompt: str
+    # Decoded full output sequence; "" when the service returned no text.
+    generated_text: str
+    # Decoded tokens that follow the prompt tokens; "" when absent.
+    generated_only: str
+    # Number of tokens produced beyond the prompt; 0 when absent.
+    tokens_generated: int
+    # Error message from the service, or None when it reported no error.
+    error: str | None = None
+
+
+@router.post("/generate", response_model=GenerateResponse)
+async def generate(req: GenerateRequest):
+    """Generate text with a trained adapter.
+
+    ``req.adapter_id`` is supplied by the client and joined onto the adapters
+    directory, so it must be a single directory name. Anything else (path
+    separators, an absolute path, ``.``, ``..`` or an empty string) is
+    rejected before the filesystem is touched.
+
+    Failures are reported through the ``error`` field of the response body,
+    which is the same channel the inference service uses for its own errors.
+
+    Args:
+        req: Adapter id, prompt and generation parameters.
+
+    Returns:
+        A ``GenerateResponse``; ``error`` is set when the adapter id is
+        rejected or the service reports a failure.
+    """
+    from pathlib import PurePath
+
+    from app.config import get_settings
+
+    # A valid id is exactly one path component. ".." also parses as a single
+    # component, so it needs its own check.
+    if PurePath(req.adapter_id).parts != (req.adapter_id,) or req.adapter_id == "..":
+        return GenerateResponse(
+            prompt=req.prompt,
+            generated_text="",
+            generated_only="",
+            tokens_generated=0,
+            error="无效的 adapter_id",
+        )
+
+    settings = get_settings()
+    adapter_path = str(settings.adapters_dir / req.adapter_id)
+    result = await inference_service.generate(
+        adapter_path,
+        req.prompt,
+        req.max_new_tokens,
+        req.temperature,
+        req.top_p,
+        req.repetition_penalty,
+        req.do_sample,
+    )
+    # The service returns a plain dict that may carry only an "error" key,
+    # so every other field needs a fallback.
+    return GenerateResponse(
+        prompt=result.get("prompt", req.prompt),
+        generated_text=result.get("generated_text", ""),
+        generated_only=result.get("generated_only", ""),
+        tokens_generated=result.get("tokens_generated", 0),
+        error=result.get("error"),
+    )
+
+
+@router.get("/adapters")
+async def list_adapters():
+    """Return the adapters reported by the inference service."""
+    available = await inference_service.get_available_adapters()
+    return available

+ 110 - 0
backend/app/services/inference_service.py

@@ -0,0 +1,110 @@
+import json
+from pathlib import Path
+from typing import Any
+
+from app.config import get_settings
+from app.core.db import async_session, TrainingJobModel
+from app.core.logging import logger
+from sqlalchemy import select
+
+settings = get_settings()
+
+
+async def generate(
+    adapter_path: str,
+    prompt: str,
+    max_new_tokens: int = 256,
+    temperature: float = 0.8,
+    top_p: float = 0.95,
+    repetition_penalty: float = 1.1,
+    do_sample: bool = True,
+) -> dict[str, Any]:
+    """Generate text with a trained adapter.
+
+    Args:
+        adapter_path: Directory holding the adapter weights, tokenizer files
+            and ``adapter_config.json``.
+        prompt: Text to continue.
+        max_new_tokens: Upper bound on the number of generated tokens.
+        temperature: Sampling temperature; only used when sampling.
+        top_p: Nucleus sampling threshold; only used when sampling.
+        repetition_penalty: Penalty applied to repeated tokens.
+        do_sample: Sample when True, otherwise decode greedily.
+
+    Returns:
+        On success a dict with ``prompt``, ``generated_text`` (full decoded
+        sequence), ``generated_only`` (text after the prompt tokens) and
+        ``tokens_generated``. On any failure a dict with a single ``error``
+        key; this function does not raise.
+    """
+    import asyncio
+
+    try:
+        # Loading the model and decoding are synchronous and long-running.
+        # Running them in a worker thread keeps the event loop free to serve
+        # other requests in the meantime.
+        return await asyncio.to_thread(
+            _generate_blocking,
+            adapter_path,
+            prompt,
+            max_new_tokens,
+            temperature,
+            top_p,
+            repetition_penalty,
+            do_sample,
+        )
+    except Exception as e:
+        # Top-level boundary: log with traceback and report the failure in-band.
+        logger.exception(f"Inference failed: {e}")
+        return {"error": str(e)}
+
+
+def _generate_blocking(
+    adapter_path: str,
+    prompt: str,
+    max_new_tokens: int,
+    temperature: float,
+    top_p: float,
+    repetition_penalty: float,
+    do_sample: bool,
+) -> dict[str, Any]:
+    """Load the base model plus adapter and run one generation synchronously."""
+    import torch
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    # The base model id is recorded in the adapter's own config file.
+    base_model_id = _get_base_model_id(adapter_path)
+    if not base_model_id:
+        return {"error": "无法找到基础模型信息,请确保训练任务已完成"}
+
+    # NOTE(review): trust_remote_code=True allows Python code shipped in the
+    # adapter directory to run; only safe while adapter directories are trusted.
+    tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    from peft import PeftModel
+
+    base_model = AutoModelForCausalLM.from_pretrained(
+        base_model_id,
+        torch_dtype=torch.float16,
+        device_map="auto",
+    )
+    model = PeftModel.from_pretrained(base_model, adapter_path)
+    model.eval()
+
+    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+    prompt_len = inputs["input_ids"].shape[1]
+
+    gen_kwargs: dict[str, Any] = {
+        "max_new_tokens": max_new_tokens,
+        "repetition_penalty": repetition_penalty,
+        "pad_token_id": tokenizer.eos_token_id,
+    }
+    # Sampling needs a strictly positive temperature. A temperature of 0 is
+    # treated as a request for greedy decoding instead of being passed on,
+    # and sampling-only parameters are omitted when decoding greedily.
+    if do_sample and temperature > 0:
+        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
+    else:
+        gen_kwargs["do_sample"] = False
+
+    with torch.no_grad():
+        outputs = model.generate(**inputs, **gen_kwargs)
+
+    # outputs[0] holds prompt tokens followed by new tokens; slicing at the
+    # prompt length isolates the newly generated part.
+    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    generated_only = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
+
+    return {
+        "prompt": prompt,
+        "generated_text": generated_text,
+        "generated_only": generated_only,
+        "tokens_generated": outputs.shape[1] - prompt_len,
+    }
+
+
+def _get_base_model_id(adapter_path: str) -> str | None:
+    """Read the base model id from an adapter's ``adapter_config.json``.
+
+    Args:
+        adapter_path: Directory expected to contain ``adapter_config.json``.
+
+    Returns:
+        The ``base_model_name_or_path`` value, or None when the file is
+        missing, is not a JSON object, or lacks that key.
+
+    Raises:
+        json.JSONDecodeError: If the file exists but is not valid JSON.
+    """
+    config_path = Path(adapter_path) / "adapter_config.json"
+    try:
+        # Open directly instead of checking exists() first, so a file removed
+        # in between is still reported as "missing". JSON is read as UTF-8
+        # rather than with the platform's locale encoding.
+        with open(config_path, encoding="utf-8") as f:
+            cfg = json.load(f)
+    except FileNotFoundError:
+        return None
+    if not isinstance(cfg, dict):
+        return None
+    return cfg.get("base_model_name_or_path")
+
+
+async def get_available_adapters() -> list[dict[str, Any]]:
+    """List the trained adapters found under ``settings.adapters_dir``.
+
+    A subdirectory counts as an adapter when it contains a readable
+    ``adapter_config.json`` holding a JSON object. Entries that cannot be
+    read or parsed are logged and skipped, so one broken adapter does not
+    hide the rest.
+
+    Returns:
+        One dict per adapter with ``id`` (directory name), ``path``,
+        ``base_model`` and ``peft_type``, sorted by directory name. Empty
+        when the adapters directory does not exist.
+    """
+    adapters_dir = settings.adapters_dir
+    if not adapters_dir.is_dir():
+        return []
+
+    result = []
+    # iterdir() yields entries in arbitrary order; sort for a stable listing.
+    for d in sorted(adapters_dir.iterdir(), key=lambda p: p.name):
+        adapter_config = d / "adapter_config.json"
+        if not d.is_dir() or not adapter_config.exists():
+            continue
+        try:
+            with open(adapter_config, encoding="utf-8") as f:
+                cfg = json.load(f)
+        except (OSError, ValueError) as e:
+            # ValueError covers both JSON syntax errors and decoding errors.
+            logger.warning(f"Skipping adapter {d.name}: cannot read adapter_config.json: {e}")
+            continue
+        if not isinstance(cfg, dict):
+            logger.warning(f"Skipping adapter {d.name}: adapter_config.json is not a JSON object")
+            continue
+        result.append({
+            "id": d.name,
+            "path": str(d),
+            "base_model": cfg.get("base_model_name_or_path", "unknown"),
+            "peft_type": cfg.get("peft_type", "unknown"),
+        })
+    return result

+ 4 - 1
backend/app/services/model_service.py

@@ -55,8 +55,11 @@ async def download_model(model_id: str, use_modelscope: bool = False) -> dict[st
         logger.info(f"Model downloaded: {model_id} -> {local_path}")
         return {"model_id": model_id, "status": "completed", "path": local_path}
     except Exception as e:
+        error_msg = str(e)
+        if "Connection" in error_msg or "timeout" in error_msg.lower() or "network" in error_msg.lower():
+            error_msg += "\n提示: 可能是 HuggingFace 网络问题。尝试使用 ModelScope 下载。"
         logger.error(f"Model download failed: {e}")
-        return {"model_id": model_id, "status": "failed", "error": str(e)}
+        return {"model_id": model_id, "status": "failed", "error": error_msg}
 
 
 def list_cached_models() -> list[dict[str, Any]]:

+ 3 - 1
backend/main.py

@@ -44,18 +44,20 @@ def create_app() -> FastAPI:
         allow_headers=["*"],
     )
 
-    # 挂载路由 (Phase 2 起逐步填充)
+    # 挂载路由
     from app.api import models as models_api
     from app.api import datasets as datasets_api
     from app.api import training as training_api
     from app.api import evaluation as evaluation_api
     from app.api import deployment as deployment_api
+    from app.api import inference as inference_api
 
     app.include_router(models_api.router, prefix="/api/v1/models", tags=["models"])
     app.include_router(datasets_api.router, prefix="/api/v1/datasets", tags=["datasets"])
     app.include_router(training_api.router, prefix="/api/v1/training", tags=["training"])
     app.include_router(evaluation_api.router, prefix="/api/v1/evaluation", tags=["evaluation"])
     app.include_router(deployment_api.router, prefix="/api/v1/deployment", tags=["deployment"])
+    app.include_router(inference_api.router, prefix="/api/v1/inference", tags=["inference"])
 
     # WebSocket
     from app.core.websocket import router as ws_router

+ 1 - 0
backend/requirements.txt

@@ -19,6 +19,7 @@ scipy>=1.14.0
 scikit-learn>=1.5.0
 pillow>=10.4.0
 huggingface_hub>=0.25.0
+modelscope>=1.15.0
 pandas>=2.2.0
 pyarrow>=17.0.0
 sentencepiece>=0.2.0

+ 2 - 0
frontend/src/App.tsx

@@ -6,6 +6,7 @@ import { Datasets } from './pages/Datasets'
 import { Training } from './pages/Training'
 import { Evaluation } from './pages/Evaluation'
 import { Deployment } from './pages/Deployment'
+import { Inference } from './pages/Inference'
 
 export default function App() {
   return (
@@ -17,6 +18,7 @@ export default function App() {
         <Route path="/training" element={<Training />} />
         <Route path="/evaluation" element={<Evaluation />} />
         <Route path="/deployment" element={<Deployment />} />
+        <Route path="/inference" element={<Inference />} />
       </Routes>
     </Layout>
   )

+ 38 - 1
frontend/src/api/client.ts

@@ -75,6 +75,18 @@ const api = {
     status: (id: string) =>
       fetch(`/api/v1/deployment/${id}/status`).then(r => r.json()) as Promise<DeployResponse>,
   },
+
+  // --- Inference ---
+  // Calls the routes mounted under /api/v1/inference.
+  inference: {
+    // POSTs the generation request as JSON and parses the reply as JSON.
+    // The HTTP status is not inspected; callers should check `error` on the result.
+    generate: (req: InferenceRequest) =>
+      fetch('/api/v1/inference/generate', {
+        method: 'POST',
+        headers: { 'Content-Type': 'application/json' },
+        body: JSON.stringify(req),
+      }).then(r => r.json()) as Promise<InferenceResponse>,
+    // Fetches the list of adapters available for inference.
+    adapters: () =>
+      fetch('/api/v1/inference/adapters').then(r => r.json()) as Promise<AdapterInfo[]>,
+  },
 }
 
 export default api
@@ -195,4 +207,29 @@ interface DeployResponse {
   error?: string
 }
 
-export type { ModelInfo, ModelDownloadResponse, DatasetInfo, DatasetDownloadResponse, DatasetPreview, DatasetValidation, TrainingJob, TrainingConfig, EvalConfig, EvalResult, DeployConfig, DeployResponse }
+// One entry of the list returned by GET /api/v1/inference/adapters.
+interface AdapterInfo {
+  // Adapter directory name; sent back as `adapter_id` when generating.
+  id: string
+  // Adapter directory path as reported by the backend.
+  path: string
+  // Base model named in the adapter config, or "unknown".
+  base_model: string
+  // PEFT type named in the adapter config, or "unknown".
+  peft_type: string
+}
+
+// Body of POST /api/v1/inference/generate. Optional fields fall back to the
+// backend defaults when omitted.
+interface InferenceRequest {
+  // Adapter directory name, as listed in AdapterInfo.id.
+  adapter_id: string
+  prompt: string
+  max_new_tokens?: number
+  temperature?: number
+  top_p?: number
+  repetition_penalty?: number
+  do_sample?: boolean
+}
+
+// Reply of POST /api/v1/inference/generate.
+interface InferenceResponse {
+  prompt: string
+  // Full decoded output, prompt included.
+  generated_text: string
+  // Only the text generated after the prompt.
+  generated_only: string
+  tokens_generated: number
+  // The backend field is `str | None`, so a successful reply carries
+  // `error: null` rather than omitting the key.
+  error?: string | null
+}
+
+export type { ModelInfo, ModelDownloadResponse, DatasetInfo, DatasetDownloadResponse, DatasetPreview, DatasetValidation, TrainingJob, TrainingConfig, EvalConfig, EvalResult, DeployConfig, DeployResponse, AdapterInfo, InferenceRequest, InferenceResponse }

+ 1 - 0
frontend/src/components/layout/Layout.tsx

@@ -7,6 +7,7 @@ const NAV_ITEMS = [
   { path: '/training', label: '训练' },
   { path: '/evaluation', label: '评估' },
   { path: '/deployment', label: '部署' },
+  { path: '/inference', label: '推理' },
 ]
 
 export function Layout({ children }: { children: React.ReactNode }) {

+ 160 - 0
frontend/src/pages/Inference.tsx

@@ -0,0 +1,160 @@
+import { useState, useEffect } from 'react'
+import api, { AdapterInfo } from '../api/client'
+
+// Inference page: pick a trained adapter, enter a prompt and generation
+// parameters, run generation through the backend and show the result either
+// in full or as the newly generated text only.
+export function Inference() {
+  const [adapters, setAdapters] = useState<AdapterInfo[]>([])
+  const [adapterId, setAdapterId] = useState('')
+  const [prompt, setPrompt] = useState('')
+  const [maxTokens, setMaxTokens] = useState(256)
+  const [temperature, setTemperature] = useState(0.8)
+  const [topP, setTopP] = useState(0.95)
+  const [repetitionPenalty, setRepetitionPenalty] = useState(1.1)
+  const [doSample, setDoSample] = useState(true)
+  const [generating, setGenerating] = useState(false)
+  // generated_only is kept alongside the full text so the "new text only" view
+  // shows what the server produced, independent of later edits to the prompt box.
+  const [result, setResult] = useState<{ generated_text: string; generated_only: string; tokens_generated: number } | null>(null)
+  const [error, setError] = useState('')
+  const [viewMode, setViewMode] = useState<'full' | 'new'>('new')
+
+  // Load the adapter list once on mount; a failed request shows the empty state.
+  useEffect(() => {
+    api.inference.adapters()
+      .then(setAdapters)
+      .catch(() => setAdapters([]))
+  }, [])
+
+  // Preselect the first adapter once the list arrives and nothing is chosen yet.
+  useEffect(() => {
+    if (adapters.length > 0 && !adapterId) {
+      setAdapterId(adapters[0].id)
+    }
+  }, [adapters, adapterId])
+
+  const handleGenerate = () => {
+    if (!adapterId.trim() || !prompt.trim()) return
+    setGenerating(true)
+    setError('')
+    setResult(null)
+    api.inference.generate({
+      adapter_id: adapterId,
+      prompt,
+      max_new_tokens: maxTokens,
+      temperature,
+      top_p: topP,
+      repetition_penalty: repetitionPenalty,
+      do_sample: doSample,
+    })
+      .then(res => {
+        // The backend reports failures in the body's `error` field.
+        if (res.error) {
+          setError(res.error)
+        } else {
+          setResult({
+            generated_text: res.generated_text,
+            generated_only: res.generated_only,
+            tokens_generated: res.tokens_generated,
+          })
+        }
+      })
+      .catch(err => setError(err.message))
+      .finally(() => setGenerating(false))
+  }
+
+  return (
+    <div>
+      <h1>模型推理</h1>
+
+      {/* Adapter selector */}
+      <div style={{ marginTop: 16, background: '#fff', borderRadius: 8, padding: 20, boxShadow: '0 1px 3px rgba(0,0,0,0.1)' }}>
+        <h2 style={{ margin: '0 0 16px', fontSize: 16 }}>选择 Adapter</h2>
+        {adapters.length === 0 ? (
+          <p style={{ color: '#999', fontSize: 14 }}>暂无可用的 adapter,请先完成训练任务</p>
+        ) : (
+          <select
+            value={adapterId}
+            onChange={e => setAdapterId(e.target.value)}
+            style={{ padding: '6px 12px', borderRadius: 4, border: '1px solid #ccc', width: '100%', maxWidth: 500 }}
+          >
+            {adapters.map(a => (
+              <option key={a.id} value={a.id}>{a.id} (base: {a.base_model})</option>
+            ))}
+          </select>
+        )}
+      </div>
+
+      {/* Prompt input */}
+      <div style={{ marginTop: 16, background: '#fff', borderRadius: 8, padding: 20, boxShadow: '0 1px 3px rgba(0,0,0,0.1)' }}>
+        <h2 style={{ margin: '0 0 16px', fontSize: 16 }}>输入提示词</h2>
+        <textarea
+          value={prompt}
+          onChange={e => setPrompt(e.target.value)}
+          placeholder="输入你的问题或指令..."
+          rows={4}
+          style={{ width: '100%', padding: 12, borderRadius: 4, border: '1px solid #ccc', fontSize: 14, boxSizing: 'border-box', resize: 'vertical' }}
+        />
+
+        {/* Generation params */}
+        <div style={{ marginTop: 12, display: 'grid', gridTemplateColumns: 'repeat(4, 1fr)', gap: 12 }}>
+          <div>
+            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>Max Tokens</label>
+            <input type="number" value={maxTokens} onChange={e => setMaxTokens(Number(e.target.value))} min={1} max={4096} style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }} />
+          </div>
+          <div>
+            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>Temperature</label>
+            <input type="number" value={temperature} onChange={e => setTemperature(Number(e.target.value))} min={0} max={2} step={0.1} style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }} />
+          </div>
+          <div>
+            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>Top P</label>
+            <input type="number" value={topP} onChange={e => setTopP(Number(e.target.value))} min={0} max={1} step={0.05} style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }} />
+          </div>
+          <div>
+            <label style={{ display: 'block', fontSize: 12, color: '#666', marginBottom: 4 }}>Repetition Penalty</label>
+            <input type="number" value={repetitionPenalty} onChange={e => setRepetitionPenalty(Number(e.target.value))} min={1} max={2} step={0.1} style={{ width: '100%', padding: '6px 8px', borderRadius: 4, border: '1px solid #ccc', boxSizing: 'border-box' }} />
+          </div>
+        </div>
+
+        <label style={{ display: 'flex', alignItems: 'center', gap: 6, fontSize: 13, cursor: 'pointer', marginTop: 12 }}>
+          <input type="checkbox" checked={doSample} onChange={e => setDoSample(e.target.checked)} />
+          启用采样 (关闭则为 greedy decoding)
+        </label>
+
+        <button
+          onClick={handleGenerate}
+          disabled={generating || !adapterId}
+          style={{ marginTop: 16, padding: '8px 24px', borderRadius: 4, border: 'none', background: '#e94560', color: '#fff', cursor: 'pointer', opacity: generating || !adapterId ? 0.6 : 1 }}
+        >
+          {generating ? '生成中...' : '生成'}
+        </button>
+      </div>
+
+      {/* Error */}
+      {error && (
+        <div style={{ marginTop: 16, padding: 16, background: '#ffebee', borderRadius: 8, color: '#c62828' }}>
+          <strong>错误:</strong> {error}
+        </div>
+      )}
+
+      {/* Result */}
+      {result && (
+        <div style={{ marginTop: 16, background: '#fff', borderRadius: 8, padding: 20, boxShadow: '0 1px 3px rgba(0,0,0,0.1)' }}>
+          <div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 12 }}>
+            <h2 style={{ margin: 0, fontSize: 16 }}>生成结果</h2>
+            <span style={{ fontSize: 12, color: '#999' }}>{result.tokens_generated} tokens</span>
+          </div>
+
+          {/* View mode toggle */}
+          <div style={{ marginBottom: 12 }}>
+            <button
+              onClick={() => setViewMode('full')}
+              style={{ padding: '4px 12px', borderRadius: 4, border: `1px solid ${viewMode === 'full' ? '#e94560' : '#ccc'}`, background: viewMode === 'full' ? '#e94560' : '#fff', color: viewMode === 'full' ? '#fff' : '#333', cursor: 'pointer', marginRight: 8, fontSize: 13 }}
+            >
+              完整输出
+            </button>
+            <button
+              onClick={() => setViewMode('new')}
+              style={{ padding: '4px 12px', borderRadius: 4, border: `1px solid ${viewMode === 'new' ? '#e94560' : '#ccc'}`, background: viewMode === 'new' ? '#e94560' : '#fff', color: viewMode === 'new' ? '#fff' : '#333', cursor: 'pointer', fontSize: 13 }}
+            >
+              仅新生成部分
+            </button>
+          </div>
+
+          <pre style={{ whiteSpace: 'pre-wrap', wordBreak: 'break-word', background: '#f5f5f5', padding: 16, borderRadius: 4, fontSize: 14, lineHeight: 1.6, maxHeight: 400, overflow: 'auto' }}>
+            {viewMode === 'full' ? result.generated_text : result.generated_only.trim()}
+          </pre>
+        </div>
+      )}
+    </div>
+  )
+}