فهرست منبع

优化记录显示,使用单卡推理

lxylxy123321 1 روز پیش
والد
کامیت
5aebcc61f8

+ 6 - 2
backend/app/core/inference_worker.py

@@ -104,8 +104,12 @@ class InferenceWorker:
             self.tokenizer.pad_token = self.tokenizer.eos_token
 
         print(f"[worker] Loading model from: {model_path}", flush=True)
-        # device_map="auto" 自动将模型层分散到所有可见 GPU(由 CUDA_VISIBLE_DEVICES 控制)
-        device_map = "auto" if torch.cuda.is_available() else "cpu"
+        # 单卡加载到 cuda:0(CUDA_VISIBLE_DEVICES 已限制可见 GPU)
+        # 不使用 device_map="auto" 避免多卡时 tied weights 分到不同 GPU 导致报错
+        if torch.cuda.is_available():
+            device_map = {"": 1}
+        else:
+            device_map = "cpu"
         self.model = AutoModelForCausalLM.from_pretrained(
             model_path, torch_dtype=torch.float16, device_map=device_map,
         )

+ 33 - 2
backend/app/engines/text_engine.py

@@ -190,10 +190,21 @@ class TextEngine(BaseEngine):
         world_size = int(os.environ.get("WORLD_SIZE", "1"))
         is_ddp = world_size > 1
 
-        dataset = self._tokenize_dataset(dataset_path, max_seq_length)
+        # SFT 需要预先 tokenize;DPO/PPO 各自处理数据
+        if task_type == "sft":
+            dataset = self._tokenize_dataset(dataset_path, max_seq_length)
+        elif task_type == "dpo":
+            dataset = self._load_dataset_dpo(dataset_path)
+        else:
+            dataset = None  # PPO 在后面单独处理
 
         # 计算总步数(DDP 模式下 Trainer 自动按 world_size 分发数据)
-        dataset_len = len(dataset)
+        if dataset is not None:
+            dataset_len = len(dataset)
+        else:
+            # PPO: 从文件行数估算
+            with open(dataset_path, "r", encoding="utf-8") as f:
+                dataset_len = sum(1 for line in f if line.strip())
         effective_batch = batch_size * gradient_accumulation * world_size
         max_steps = max(1, (dataset_len * epochs) // effective_batch)
 
@@ -529,6 +540,26 @@ class TextEngine(BaseEngine):
         )
         return tokenized_dataset
 
+    def _load_dataset_dpo(self, dataset_path: str):
+        """加载 DPO 数据集,保留 prompt/chosen/rejected 原始文本,由 DPOTrainer 内部 tokenize。"""
+        from datasets import Dataset as HFDataset
+
+        data = []
+        with open(dataset_path, "r", encoding="utf-8") as f:
+            for line in f:
+                line = line.strip()
+                if line:
+                    item = json.loads(line)
+                    prompt = item.get("prompt", item.get("instruction", item.get("input", "")))
+                    chosen = item.get("chosen", item.get("positive", ""))
+                    rejected = item.get("rejected", item.get("negative", ""))
+                    data.append({
+                        "prompt": str(prompt) if prompt else "",
+                        "chosen": str(chosen) if chosen else "",
+                        "rejected": str(rejected) if rejected else "",
+                    })
+        return HFDataset.from_list(data)
+
 
 try:
     from transformers import TrainerCallback as _TrainerCallbackBase

+ 8 - 5
backend/app/preprocessors/__init__.py

@@ -109,11 +109,14 @@ def apply_raw_template(item: dict) -> dict:
 
 def apply_dpo_template(item: dict) -> dict:
     """DPO 模板: prompt + chosen + rejected。"""
-    return {
-        "prompt": item.get("prompt", item.get("input", item.get("question", item.get("query", "")))),
-        "chosen": item.get("chosen", item.get("positive", item.get("answer", ""))),
-        "rejected": item.get("rejected", item.get("negative", "")),
-    }
+    prompt = item.get("prompt", item.get("instruction", item.get("input", item.get("question", item.get("query", "")))))
+    chosen = item.get("chosen", item.get("positive", item.get("answer", "")))
+    rejected = item.get("rejected", item.get("negative", ""))
+    # 确保所有值为字符串
+    prompt = str(prompt) if prompt is not None else ""
+    chosen = str(chosen) if chosen is not None else ""
+    rejected = str(rejected) if rejected is not None else ""
+    return {"prompt": prompt, "chosen": chosen, "rejected": rejected}
 
 
 TEMPLATE_MAP = {

+ 5 - 0
frontend/public/samples/dpo_sample.jsonl

@@ -0,0 +1,5 @@
+{"instruction": "什么是机器学习?", "chosen": "机器学习是人工智能的一个分支,它让计算机能够从数据中自动学习规律,而无需显式编程。常见的类型包括监督学习、无监督学习和强化学习。", "rejected": "机器学习就是让机器自己思考。"}
+{"instruction": "Python 中 list 和 tuple 有什么区别?", "chosen": "list 是可变的,可以增删改元素,用方括号 [] 表示;tuple 是不可变的,创建后不能修改,用圆括号 () 表示。tuple 因为不可变所以可以作为字典的 key。", "rejected": "list 比 tuple 快。"}
+{"instruction": "请解释什么是过拟合。", "chosen": "过拟合是指模型在训练数据上表现很好,但在未见过的测试数据上表现很差。就像学生死记硬背了答案但没有真正理解,换一道题就不会了。常见的解决方法包括增加数据量、正则化、dropout 和早停。", "rejected": "过拟合就是模型太好了,不需要担心。"}
+{"instruction": "什么是梯度下降?", "chosen": "梯度下降是一种优化算法,用于最小化损失函数。它的核心思想是:沿着损失函数梯度的反方向更新参数,就像下山时朝着最陡的方向走,逐步找到最低点。学习率决定了每一步走多远。", "rejected": "梯度下降就是让数字变小。"}
+{"instruction": "解释一下 Transformer 的注意力机制。", "chosen": "注意力机制让模型在处理每个词时,能够关注到输入序列中所有其他词的信息。通过计算 Query、Key、Value 三个向量的相似度来决定关注的权重,多头注意力则从不同角度捕捉语义关系,这是 GPT、BERT 等模型的核心。", "rejected": "注意力机制就是让模型集中精神。"}

+ 8 - 0
frontend/public/samples/ppo_sample.jsonl

@@ -0,0 +1,8 @@
+{"instruction": "请用一句话介绍地球。", "response": "地球是太阳系中从太阳数起的第三颗行星,是目前已知唯一存在生命的天体。"}
+{"instruction": "什么是 API?", "response": "API(应用程序接口)是软件系统之间交互的桥梁,定义了一组规则和协议,让不同的程序可以互相通信和调用功能。"}
+{"instruction": "解释一下什么是 Docker。", "response": "Docker 是一种容器化技术,可以将应用程序及其所有依赖打包成一个轻量级、可移植的容器,确保在任何环境中都能一致运行。"}
+{"instruction": "什么是迁移学习?", "response": "迁移学习是将在一个任务上学到的知识应用到另一个相关任务的机器学习方法,比如用 ImageNet 预训练的模型来做图像分类微调,可以大幅减少所需的数据和训练时间。"}
+{"instruction": "请说明 JSON 和 YAML 的区别。", "response": "JSON 使用花括号和方括号,语法严格,广泛用于 API 数据交换;YAML 使用缩进表示层级,可读性更强,常用于配置文件。YAML 支持注释,JSON 不支持。"}
+{"instruction": "什么是微服务架构?", "response": "微服务架构是将应用拆分为一组小型、独立部署的服务,每个服务负责一个特定的业务功能,通过 API 互相通信。相比单体架构,它更灵活、可扩展,但也增加了运维复杂度。"}
+{"instruction": "解释一下什么是 Git 的分支。", "response": "Git 分支是代码仓库中一条独立的开发线,允许你在不影响主线的情况下进行功能开发或实验。创建分支相当于复制了当前代码状态,可以在上面自由修改,完成后通过 merge 合并回主分支。"}
+{"instruction": "HTTP 状态码 404 是什么意思?", "response": "HTTP 404 表示服务器无法找到请求的资源,通常是因为 URL 路径错误或资源已被删除。这是最常见的客户端错误状态码之一。"}

+ 93 - 0
frontend/src/components/Pagination.tsx

@@ -0,0 +1,93 @@
+interface PaginationProps {
+  page: number
+  totalPages: number
+  total?: number
+  onChange: (page: number) => void
+}
+
+export function Pagination({ page, totalPages, total, onChange }: PaginationProps) {
+  if (totalPages <= 1) return null
+
+  // 计算页码按钮范围(最多显示 5 个)
+  const maxButtons = 5
+  let start = Math.max(1, page - Math.floor(maxButtons / 2))
+  let end = Math.min(totalPages, start + maxButtons - 1)
+  if (end - start + 1 < maxButtons) {
+    start = Math.max(1, end - maxButtons + 1)
+  }
+
+  const pages = Array.from({ length: end - start + 1 }, (_, i) => start + i)
+
+  const btnStyle = (active: boolean, disabled = false): React.CSSProperties => ({
+    padding: '6px 12px',
+    borderRadius: 6,
+    border: `1px solid ${active ? '#14b8a6' : '#cbd5e1'}`,
+    background: active ? '#14b8a6' : '#fff',
+    color: active ? '#fff' : disabled ? '#cbd5e1' : '#64748b',
+    cursor: disabled ? 'not-allowed' : 'pointer',
+    fontSize: 13,
+    fontWeight: active ? 600 : 400,
+    opacity: disabled ? 0.5 : 1,
+    transition: 'all 0.15s ease',
+  })
+
+  return (
+    <div style={{
+      display: 'flex',
+      justifyContent: 'space-between',
+      alignItems: 'center',
+      marginTop: 16,
+      padding: '12px 0',
+    }}>
+      {/* 左侧:统计信息 */}
+      <div style={{ fontSize: 13, color: '#64748b' }}>
+        第 {page} / {totalPages} 页
+        {total !== undefined && `,共 ${total} 条`}
+      </div>
+
+      {/* 右侧:页码按钮 */}
+      <div style={{ display: 'flex', gap: 4, alignItems: 'center' }}>
+        {/* 上一页 */}
+        <button
+          onClick={() => onChange(page - 1)}
+          disabled={page <= 1}
+          style={btnStyle(false, page <= 1)}
+        >
+          上一页
+        </button>
+
+        {/* 页码按钮 */}
+        {start > 1 && (
+          <>
+            <button onClick={() => onChange(1)} style={btnStyle(false)}>1</button>
+            {start > 2 && <span style={{ padding: '0 4px', color: '#94a3b8' }}>...</span>}
+          </>
+        )}
+        {pages.map(p => (
+          <button
+            key={p}
+            onClick={() => onChange(p)}
+            style={btnStyle(p === page)}
+          >
+            {p}
+          </button>
+        ))}
+        {end < totalPages && (
+          <>
+            {end < totalPages - 1 && <span style={{ padding: '0 4px', color: '#94a3b8' }}>...</span>}
+            <button onClick={() => onChange(totalPages)} style={btnStyle(false)}>{totalPages}</button>
+          </>
+        )}
+
+        {/* 下一页 */}
+        <button
+          onClick={() => onChange(page + 1)}
+          disabled={page >= totalPages}
+          style={btnStyle(false, page >= totalPages)}
+        >
+          下一页
+        </button>
+      </div>
+    </div>
+  )
+}

+ 21 - 1
frontend/src/pages/Deployment.tsx

@@ -1,6 +1,9 @@
 import { useState, useEffect, useRef, useCallback } from 'react'
 import api, { DeployResponse, DeployedServiceInfo, TrainingJob, DatasetInfo, ApiKeyInfo, ApiKeyCreateResponse } from '../api/client'
 import { jobLabel } from '../utils/jobLabel'
+import { Pagination } from '../components/Pagination'
+
+const SERVICES_PER_PAGE = 5
 
 type Tab = 'serve' | 'export'
 
@@ -9,6 +12,7 @@ export function Deployment() {
   const [jobs, setJobs] = useState<TrainingJob[]>([])
   const [datasets, setDatasets] = useState<DatasetInfo[]>([])
   const [services, setServices] = useState<DeployedServiceInfo[]>([])
+  const [servicePage, setServicePage] = useState(1)
   const [loadingServices, setLoadingServices] = useState(false)
 
   // 导出状态
@@ -71,6 +75,14 @@ export function Deployment() {
     }
   }, [loadServices, loadApiKeys])
 
+  // 分页计算与自动校正
+  const totalServicePages = Math.max(1, Math.ceil(services.length / SERVICES_PER_PAGE))
+  useEffect(() => {
+    if (servicePage > totalServicePages) setServicePage(totalServicePages)
+  }, [servicePage, totalServicePages])
+
+  const pagedServices = services.slice((servicePage - 1) * SERVICES_PER_PAGE, servicePage * SERVICES_PER_PAGE)
+
   // 轮询部署任务状态
   const startPolling = useCallback((taskId: string, mode: Tab) => {
     if (pollingRef.current) clearInterval(pollingRef.current)
@@ -476,11 +488,19 @@ export function Deployment() {
         {services.length === 0 ? (
           <p style={{ color: '#94a3b8', fontSize: 13, margin: 0 }}>暂无已部署的服务</p>
         ) : (
+          <>
           <div style={{ display: 'flex', flexDirection: 'column', gap: 12 }}>
-            {services.map(svc => (
+            {pagedServices.map(svc => (
               <ServiceCard key={svc.task_id} service={svc} onStop={() => handleStop(svc.task_id)} onRestart={() => handleRestart(svc.task_id)} loading={loadingTaskIds.has(svc.task_id)} />
             ))}
           </div>
+          <Pagination
+            page={servicePage}
+            totalPages={totalServicePages}
+            total={services.length}
+            onChange={setServicePage}
+          />
+          </>
         )}
       </div>
 

+ 23 - 1
frontend/src/pages/Training.tsx

@@ -2,6 +2,9 @@ import { useState, useEffect, useRef, useCallback, memo } from 'react'
 import api, { TrainingJob, ModelInfo, DatasetInfo } from '../api/client'
 import { wsManager } from '../api/websocket'
 import { Train } from 'lucide-react'
+import { Pagination } from '../components/Pagination'
+
+const JOBS_PER_PAGE = 10
 
 const MODEL_TYPES = [
   { value: 'text', label: '文本 (LLaMA/Qwen)' },
@@ -334,6 +337,7 @@ export function Training() {
   const [numGpus, setNumGpus] = useState(1)
 
   const [jobs, setJobs] = useState<TrainingJob[]>([])
+  const [jobPage, setJobPage] = useState(1)
   const [loading, setLoading] = useState(false)
   const [submitting, setSubmitting] = useState(false)
   const [createError, setCreateError] = useState('')
@@ -449,6 +453,15 @@ export function Training() {
     }
   }, [])
 
+  // 页码自动校正
+  const totalJobPages = Math.max(1, Math.ceil(jobs.length / JOBS_PER_PAGE))
+  useEffect(() => {
+    if (jobPage > totalJobPages) setJobPage(totalJobPages)
+  }, [jobPage, totalJobPages])
+
+  // 当前页的 jobs
+  const pagedJobs = jobs.slice((jobPage - 1) * JOBS_PER_PAGE, jobPage * JOBS_PER_PAGE)
+
   const handleCreate = () => {
     if (!modelId.trim() || !datasetId.trim()) return
     setSubmitting(true)
@@ -477,6 +490,7 @@ export function Training() {
         setModelId('')
         setDatasetId('')
         setJobs(prev => prev.filter(j => j.id !== tempId))
+        setJobPage(1)
         fetchJobs()
         fetchOptions()
       })
@@ -636,6 +650,7 @@ export function Training() {
         )}
 
         {!loading && jobs.length > 0 && (
+          <>
           <div style={{
             background: '#fff', borderRadius: 10, overflow: 'hidden',
             boxShadow: '0 1px 3px rgba(0,0,0,0.06)', border: '1px solid rgba(0,0,0,0.04)',
@@ -654,12 +669,19 @@ export function Training() {
                 </tr>
               </thead>
               <tbody>
-                {jobs.map(j => (
+                {pagedJobs.map(j => (
                   <JobRow key={j.id} j={j} onCancel={handleCancel} datasets={datasets} />
                 ))}
               </tbody>
             </table>
           </div>
+          <Pagination
+            page={jobPage}
+            totalPages={totalJobPages}
+            total={jobs.length}
+            onChange={setJobPage}
+          />
+          </>
         )}
       </div>
     </div>