Răsfoiți Sursa

新增样本中心样本源

lxylxy123321 5 zile în urmă
părinte
comite
d1809811a0

+ 5 - 0
backend/.env

@@ -53,3 +53,8 @@ SSO_LOGOUT_REDIRECT_URL=http://192.168.92.61:9200/login
 JWT_SECRET_KEY=change-me-in-production-use-a-long-random-string
 JWT_ACCESS_EXPIRE_MINUTES=20
 JWT_REFRESH_EXPIRE_HOURS=24
+
+# --- Sample center ---
+# NOTE(review): the app id / secret below look like real credentials committed
+# to the repository - confirm; if so, rotate them and keep placeholders here.
+SAMPLE_CENTER_BASE_URL=http://192.168.92.61
+SAMPLE_CENTER_APP_ID=WviiGL8KQE20tQhmhQPQhhJ5QpFK51F6
+SAMPLE_CENTER_APP_SECRET=9WXP88hEHJiHRSiUdmx7ip5oQPzY0bnJNsEswQoO4sk6juCplyJTcnAiZsv7e3lJ

+ 51 - 0
backend/app/api/sample_center.py

@@ -0,0 +1,51 @@
+"""样本中心 API 路由。"""
+
+from fastapi import APIRouter, Query, HTTPException
+
+from app.schemas.sample_center import (
+    KnowledgeBaseListResponse,
+    KnowledgeBaseDetailResponse,
+    KbImportResponse,
+)
+from app.services import sample_center_service
+
+# Router holding the sample-center endpoints declared in this module.
+router = APIRouter()
+
+
+@router.get("/knowledge-bases", response_model=KnowledgeBaseListResponse)
+async def get_knowledge_bases(
+    page: int = Query(default=1, ge=1),
+    page_size: int = Query(default=20, ge=1, le=100),
+):
+    """获取样本中心知识库列表。"""
+    try:
+        data = await sample_center_service.list_knowledge_bases(page, page_size)
+        return KnowledgeBaseListResponse(**data)
+    except ValueError as e:
+        raise HTTPException(status_code=503, detail=str(e))
+    except Exception as e:
+        raise HTTPException(status_code=502, detail=f"样本中心请求失败: {str(e)}")
+
+
+@router.get("/knowledge-bases/{kb_id}", response_model=KnowledgeBaseDetailResponse)
+async def get_knowledge_base_detail(kb_id: str):
+    """获取知识库详情。"""
+    try:
+        data = await sample_center_service.get_knowledge_base_detail(kb_id)
+        return KnowledgeBaseDetailResponse(**data)
+    except ValueError as e:
+        raise HTTPException(status_code=503, detail=str(e))
+    except Exception as e:
+        raise HTTPException(status_code=502, detail=f"样本中心请求失败: {str(e)}")
+
+
+@router.post("/knowledge-bases/{kb_id}/import", response_model=KbImportResponse)
+async def import_from_knowledge_base(kb_id: str, kb_name: str = ""):
+    """从知识库导入数据到训练数据集。"""
+    try:
+        data = await sample_center_service.import_kb_to_dataset(kb_id, kb_name)
+        return KbImportResponse(**data)
+    except ValueError as e:
+        raise HTTPException(status_code=503, detail=str(e))
+    except Exception as e:
+        raise HTTPException(status_code=502, detail=f"样本中心请求失败: {str(e)}")

+ 5 - 0
backend/app/config.py

@@ -111,6 +111,11 @@ class Settings(BaseSettings):
     compute_node_remote_env: str = "production"
     compute_node_ssh_timeout: int = 300  # SSH 命令超时(秒)
 
+    # --- Sample center ---
+    sample_center_base_url: str = ""  # Sample center API base URL, e.g. https://sample.example.com
+    sample_center_app_id: str = ""  # Sample center application ID
+    sample_center_app_secret: str = ""  # Sample center application secret
+
     # --- SSO 统一认证 ---
     sso_base_url: str = "http://192.168.92.61:8200"
     sso_client_id: str = "hmDeOtXZVbeo2AZ-x58yPssZLg4Tcb1W"

+ 48 - 0
backend/app/schemas/sample_center.py

@@ -0,0 +1,48 @@
+from pydantic import BaseModel
+
+
+class KnowledgeBaseItem(BaseModel):
+    # One knowledge base entry inside KnowledgeBaseListResponse.items.
+    id: str
+    name: str
+    parent_table: str
+    child_table: str
+    document_count: int
+    status: int  # the dataset page renders 1 as enabled, any other value as disabled
+    created_at: str
+    created_by: str
+    metadata_schema: list[dict] = []  # field descriptors; typed as MetadataSchemaField[] on the frontend
+
+
+class KnowledgeBaseListResponse(BaseModel):
+    # Response body of GET /knowledge-bases: one page of knowledge bases.
+    total: int  # presumably the item count across all pages (the UI derives the page count from it) - confirm
+    page: int
+    page_size: int
+    items: list[KnowledgeBaseItem]
+
+
+class KnowledgeBaseDetailResponse(BaseModel):
+    # Response body of GET /knowledge-bases/{kb_id}. Same fields as
+    # KnowledgeBaseItem plus the optional description and updated_at.
+    id: str
+    name: str
+    description: str = ""
+    parent_table: str
+    child_table: str
+    document_count: int
+    status: int
+    created_at: str
+    updated_at: str = ""
+    created_by: str
+    metadata_schema: list[dict] = []
+
+
+class ImportTaskResponse(BaseModel):
+    # Identifier and status of an import task.
+    # NOTE(review): not referenced by the routes shown in this change - confirm it is needed.
+    task_id: str
+    status: str
+
+
+class KbImportResponse(BaseModel):
+    # Response body of POST /knowledge-bases/{kb_id}/import; mirrors the dict
+    # built by sample_center_service.import_kb_to_dataset.
+    kb_id: str
+    kb_name: str
+    document_count: int
+    metadata_schema: list[dict] = []
+    parent_table: str
+    child_table: str

+ 180 - 0
backend/app/services/sample_center_service.py

@@ -0,0 +1,180 @@
+"""样本中心 API 客户端服务。"""
+
+import httpx
+import time
+from typing import Any
+
+from app.config import get_settings
+from app.core.logging import logger
+
+settings = get_settings()
+
+# Token cache (in memory). get_token() writes the keys access_token,
+# expires_in, expires_at (time.time() based) and token_type.
+_token_cache: dict[str, Any] = {}
+
+
+def _get_base_url() -> str:
+    if not settings.sample_center_base_url:
+        raise ValueError("样本中心地址未配置,请检查 SAMPLE_CENTER_BASE_URL 环境变量")
+    return settings.sample_center_base_url.rstrip("/")
+
+
+def _get_credentials() -> tuple[str, str]:
+    if not settings.sample_center_app_id or not settings.sample_center_app_secret:
+        raise ValueError("样本中心凭证未配置,请检查 SAMPLE_CENTER_APP_ID 和 SAMPLE_CENTER_APP_SECRET")
+    return settings.sample_center_app_id, settings.sample_center_app_secret
+
+
+def _check_token_valid() -> bool:
+    if not _token_cache.get("access_token"):
+        return False
+    expires_at = _token_cache.get("expires_at", 0)
+    return time.time() < expires_at - 300  # 提前 5 分钟过期
+
+
+async def get_token() -> str:
+    if _check_token_valid():
+        return _token_cache["access_token"]
+
+    app_id, app_secret = _get_credentials()
+    base_url = _get_base_url()
+
+    async with httpx.AsyncClient(timeout=30) as client:
+        resp = await client.post(
+            f"{base_url}/api/v1/auth/token",
+            json={"app_id": app_id, "app_secret": app_secret},
+        )
+        resp.raise_for_status()
+        body = resp.json()
+
+    if body.get("code") != "000000":
+        raise RuntimeError(f"获取样本中心 Token 失败: {body.get('message')}")
+
+    data = body["data"]
+    _token_cache["access_token"] = data["access_token"]
+    _token_cache["expires_in"] = data.get("expires_in", 7200)
+    _token_cache["expires_at"] = time.time() + data.get("expires_in", 7200)
+    _token_cache["token_type"] = data.get("token_type", "Bearer")
+
+    return data["access_token"]
+
+
+def _auth_headers() -> dict[str, str]:
+    app_id, _ = _get_credentials()
+    token = _token_cache.get("access_token", "")
+    return {
+        "Authorization": f"Bearer {token}",
+        "X-App-Id": app_id,
+        "Content-Type": "application/json",
+    }
+
+
+async def list_knowledge_bases(page: int = 1, page_size: int = 20) -> dict[str, Any]:
+    """获取知识库列表。"""
+    token = await get_token()
+    base_url = _get_base_url()
+
+    async with httpx.AsyncClient(timeout=30) as client:
+        resp = await client.get(
+            f"{base_url}/api/v1/knowledge-bases",
+            params={"page": page, "page_size": page_size},
+            headers=_auth_headers(),
+        )
+        resp.raise_for_status()
+        body = resp.json()
+
+    if body.get("code") != "000000":
+        raise RuntimeError(f"获取知识库列表失败: {body.get('message')}")
+
+    return body["data"]
+
+
+async def get_knowledge_base_detail(kb_id: str) -> dict[str, Any]:
+    """获取知识库详情。"""
+    token = await get_token()
+    base_url = _get_base_url()
+
+    async with httpx.AsyncClient(timeout=30) as client:
+        resp = await client.get(
+            f"{base_url}/api/v1/knowledge-bases/{kb_id}",
+            headers=_auth_headers(),
+        )
+        resp.raise_for_status()
+        body = resp.json()
+
+    if body.get("code") != "000000":
+        raise RuntimeError(f"获取知识库详情失败: {body.get('message')}")
+
+    return body["data"]
+
+
+async def batch_import(kb_id: str, parents: list[dict], children: list[dict] | None = None,
+                       callback_url: str | None = None) -> dict[str, Any]:
+    """提交批量入库任务。"""
+    import uuid
+
+    token = await get_token()
+    base_url = _get_base_url()
+    task_no = f"IMP{int(time.time())}{uuid.uuid4().hex[:8]}"
+
+    payload: dict[str, Any] = {
+        "task_no": task_no,
+        "parents": parents,
+    }
+    if children:
+        payload["children"] = children
+    if callback_url:
+        payload["callback_url"] = callback_url
+
+    async with httpx.AsyncClient(timeout=60) as client:
+        resp = await client.post(
+            f"{base_url}/api/v1/knowledge-bases/{kb_id}/batch-import",
+            json=payload,
+            headers=_auth_headers(),
+        )
+        resp.raise_for_status()
+        body = resp.json()
+
+    if body.get("code") != "000000":
+        raise RuntimeError(f"批量入库提交失败: {body.get('message')}")
+
+    return body["data"]
+
+
+async def query_import_task(task_id: str) -> dict[str, Any]:
+    """查询批量入库任务状态。"""
+    token = await get_token()
+    base_url = _get_base_url()
+
+    async with httpx.AsyncClient(timeout=30) as client:
+        resp = await client.get(
+            f"{base_url}/api/v1/knowledge-bases/batch-import/{task_id}",
+            headers=_auth_headers(),
+        )
+        resp.raise_for_status()
+        body = resp.json()
+
+    if body.get("code") != "000000":
+        raise RuntimeError(f"查询任务失败: {body.get('message')}")
+
+    return body["data"]
+
+
+async def import_kb_to_dataset(kb_id: str, kb_name: str) -> dict[str, Any]:
+    """从知识库导入数据:查询知识库详情,将数据转为训练格式并保存为数据集。
+
+    由于样本中心的批量入库是异步任务,这里采用直接查询知识库内容的方式。
+    先获取知识库详情,然后根据 metadata_schema 构建训练数据集。
+    """
+    kb_detail = await get_knowledge_base_detail(kb_id)
+
+    # 这里返回知识库信息,前端可据此展示给用户
+    # 实际的数据导入由批量入库 API 完成
+    return {
+        "kb_id": kb_id,
+        "kb_name": kb_name or kb_detail.get("name", ""),
+        "document_count": kb_detail.get("document_count", 0),
+        "metadata_schema": kb_detail.get("metadata_schema", []),
+        "parent_table": kb_detail.get("parent_table", ""),
+        "child_table": kb_detail.get("child_table", ""),
+    }

+ 5 - 0
backend/main.py

@@ -61,6 +61,7 @@ def create_app() -> FastAPI:
     from app.api import deployment as deployment_api
     from app.api import inference as inference_api
     from app.api import auth as auth_api
+    from app.api import sample_center as sample_center_api
     from app.core.auth import get_current_active_user
 
     # 认证路由(无 prefix,端点自带完整路径)
@@ -91,6 +92,10 @@ def create_app() -> FastAPI:
         inference_api.router, prefix="/api/v1/inference", tags=["inference"],
         dependencies=[Depends(get_current_active_user)],
     )
+    app.include_router(
+        sample_center_api.router, prefix="/api/v1/sample-center", tags=["sample-center"],
+        dependencies=[Depends(get_current_active_user)],
+    )
 
     # WebSocket
     from app.core.websocket import router as ws_router

+ 55 - 1
frontend/src/api/client.ts

@@ -157,6 +157,20 @@ const api = {
       apiFetch(`/api/v1/deployment/${id}/status`).then(r => r.json()) as Promise<DeployResponse>,
   },
 
+  // --- Sample Center ---
+  // Client for the /api/v1/sample-center endpoints. Knowledge base ids are
+  // URI-encoded before being placed in the path, like kb_name in the query.
+  sampleCenter: {
+    // GET one page of knowledge bases.
+    listKnowledgeBases: (page = 1, page_size = 20) =>
+      apiFetch(`/api/v1/sample-center/knowledge-bases?page=${page}&page_size=${page_size}`)
+        .then(r => r.json()) as Promise<KnowledgeBaseListResponse>,
+    // GET the detail record of one knowledge base.
+    getKnowledgeBaseDetail: (kb_id: string) =>
+      apiFetch(`/api/v1/sample-center/knowledge-bases/${encodeURIComponent(kb_id)}`)
+        .then(r => r.json()) as Promise<KnowledgeBaseDetailResponse>,
+    // POST to the import endpoint of one knowledge base.
+    importFromKnowledgeBase: (kb_id: string, kb_name = '') =>
+      apiFetch(`/api/v1/sample-center/knowledge-bases/${encodeURIComponent(kb_id)}/import?kb_name=${encodeURIComponent(kb_name)}`, {
+        method: 'POST',
+      }).then(r => r.json()) as Promise<KbImportResponse>,
+  },
+
   // --- Inference ---
   inference: {
     generate: (req: InferenceRequest) =>
@@ -328,4 +342,44 @@ interface InferenceResponse {
   error?: string
 }
 
-export type { ModelInfo, ModelTestRequest, ModelTestResponse, ModelDownloadResponse, DatasetInfo, DatasetDownloadResponse, DatasetPreview, DatasetValidation, TrainingJob, TrainingConfig, EvalConfig, EvalResult, DeployConfig, DeployResponse, AdapterInfo, InferenceRequest, InferenceResponse }
+/** One entry of a knowledge base's metadata_schema array. */
+interface MetadataSchemaField {
+  // Shown in the field column of the knowledge base table; presumably the Chinese display name - confirm.
+  field_name_cn: string
+  // Presumably the English field name - confirm.
+  field_name_en: string
+  field_type: string
+  description: string
+}
+
+/** One knowledge base as returned in the list response. */
+interface KnowledgeBaseItem {
+  id: string
+  name: string
+  parent_table: string
+  child_table: string
+  document_count: number
+  // The dataset page renders 1 as enabled and any other value as disabled.
+  status: number
+  created_at: string
+  created_by: string
+  metadata_schema: MetadataSchemaField[]
+}
+
+/** Response of sampleCenter.listKnowledgeBases: one page of knowledge bases. */
+interface KnowledgeBaseListResponse {
+  // The dataset page divides this by the page size to get the page count.
+  total: number
+  page: number
+  page_size: number
+  items: KnowledgeBaseItem[]
+}
+
+/** Response of sampleCenter.getKnowledgeBaseDetail: a list item plus description and updated_at. */
+interface KnowledgeBaseDetailResponse extends KnowledgeBaseItem {
+  description: string
+  updated_at: string
+}
+
+/** Response of sampleCenter.importFromKnowledgeBase. */
+interface KbImportResponse {
+  kb_id: string
+  kb_name: string
+  document_count: number
+  metadata_schema: MetadataSchemaField[]
+  parent_table: string
+  child_table: string
+}
+
+export type { ModelInfo, ModelTestRequest, ModelTestResponse, ModelDownloadResponse, DatasetInfo, DatasetDownloadResponse, DatasetPreview, DatasetValidation, TrainingJob, TrainingConfig, EvalConfig, EvalResult, DeployConfig, DeployResponse, AdapterInfo, InferenceRequest, InferenceResponse, KnowledgeBaseItem, KnowledgeBaseListResponse, KnowledgeBaseDetailResponse, KbImportResponse }

+ 199 - 3
frontend/src/pages/Datasets.tsx

@@ -1,6 +1,6 @@
-import { useState, useEffect, useRef, memo } from 'react'
-import api, { DatasetInfo } from '../api/client'
-import { Database, Upload, Loader2 } from 'lucide-react'
+import { useState, useEffect, useRef, memo, useCallback } from 'react'
+import api, { DatasetInfo, KnowledgeBaseItem } from '../api/client'
+import { Database, Upload, Loader2, FolderOpen } from 'lucide-react'
 
 const DatasetRow = memo(function DatasetRow({ d, onPreview, onDelete }: {
   d: DatasetInfo
@@ -56,6 +56,15 @@ export function Datasets() {
   const [previewData, setPreviewData] = useState<{ columns: string[]; rows: { row_index: number; data: Record<string, unknown> }[] } | null>(null)
   const inputRef = useRef<HTMLInputElement>(null)
 
+  // Sample center modal state
+  const [showSampleCenter, setShowSampleCenter] = useState(false)
+  const [kbList, setKbList] = useState<KnowledgeBaseItem[]>([])
+  const [kbLoading, setKbLoading] = useState(false)
+  const [kbImporting, setKbImporting] = useState<string | null>(null)
+  const [kbStatus, setKbStatus] = useState('')
+  const [kbPage, setKbPage] = useState(1)
+  const [kbTotal, setKbTotal] = useState(0)
+
   useEffect(() => {
     fetchDatasets()
   }, [])
@@ -117,6 +126,34 @@ export function Datasets() {
     }
   }
 
+  // Load one page of knowledge bases (page size 20) into the modal state.
+  const fetchKnowledgeBases = useCallback((page = 1) => {
+    setKbLoading(true)
+    api.sampleCenter.listKnowledgeBases(page, 20)
+      .then(res => {
+        setKbList(res.items)
+        setKbTotal(res.total)
+        setKbPage(res.page)
+        // A successful load supersedes an earlier list-load error; any other
+        // status text (e.g. an import result) is left untouched.
+        setKbStatus(prev => (prev.startsWith('获取知识库列表失败') ? '' : prev))
+      })
+      .catch(err => setKbStatus(`获取知识库列表失败: ${err.message}`))
+      .finally(() => setKbLoading(false))
+  }, [])
+
+  // Call the import endpoint for one knowledge base and report the outcome
+  // through kbStatus; kbImporting holds the id while the request is running.
+  const handleImportFromKB = async (kb: KnowledgeBaseItem) => {
+    const { id, name } = kb
+    setKbImporting(id)
+    setKbStatus(`正在导入 "${name}" ...`)
+    try {
+      await api.sampleCenter.importFromKnowledgeBase(id, name)
+      setKbStatus(`"${name}" 导入请求已提交,可在样本中心查看入库进度`)
+      // Refresh the local dataset list.
+      fetchDatasets()
+    } catch (error: unknown) {
+      const reason = error instanceof Error ? error.message : '导入失败'
+      setKbStatus(`导入失败: ${reason}`)
+    } finally {
+      setKbImporting(null)
+    }
+  }
+
   return (
     <div>
       <h1 style={{ margin: 0, fontSize: 22, fontWeight: 700 }}>数据集管理</h1>
@@ -200,6 +237,165 @@ export function Datasets() {
         }}>{dlStatus}</p>}
       </div>
 
+      {/* Sample Center section */}
+      <div style={{
+        marginTop: 24, background: '#fff', borderRadius: 10, padding: 20,
+        boxShadow: '0 1px 3px rgba(0,0,0,0.06)', border: '1px solid rgba(0,0,0,0.04)',
+      }}>
+        <h2 style={{ margin: '0 0 12px', fontSize: 15, fontWeight: 600 }}>样本中心</h2>
+        <p style={{ fontSize: 13, color: '#64748b', margin: '0 0 12px' }}>
+          从样本中心导入知识库数据作为训练数据集
+        </p>
+        <button
+          onClick={() => { setShowSampleCenter(true); fetchKnowledgeBases(1); }}
+          style={{
+            padding: '10px 20px', borderRadius: 8, border: 'none',
+            background: '#8b5cf6', color: '#fff', cursor: 'pointer', fontSize: 14, fontWeight: 600,
+          }}
+        >
+          <FolderOpen size={16} style={{ display: 'inline', verticalAlign: 'middle', marginRight: 4 }} />
+          从样本中心导入
+        </button>
+      </div>
+
+      {/* Sample Center Modal */}
+      {showSampleCenter && (
+        <div
+          style={{
+            position: 'fixed', inset: 0, background: 'rgba(0,0,0,0.4)',
+            display: 'flex', alignItems: 'center', justifyContent: 'center', zIndex: 1000,
+          }}
+          onClick={() => setShowSampleCenter(false)}
+        >
+          <div
+            onClick={e => e.stopPropagation()}
+            style={{
+              background: '#fff', borderRadius: 12, padding: 24, width: '90%', maxWidth: 700,
+              maxHeight: '80vh', overflow: 'auto', boxShadow: '0 20px 60px rgba(0,0,0,0.15)',
+            }}
+          >
+            <div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 16 }}>
+              <h2 style={{ margin: 0, fontSize: 17, fontWeight: 600 }}>样本中心 - 知识库列表</h2>
+              <button
+                onClick={() => setShowSampleCenter(false)}
+                style={{
+                  border: 'none', background: 'transparent', cursor: 'pointer', fontSize: 20,
+                  color: '#64748b', padding: '4px 8px', borderRadius: 4,
+                }}
+              >✕</button>
+            </div>
+
+            <p style={{ fontSize: 13, color: '#64748b', margin: '0 0 16px' }}>
+              选择要导入的知识库,数据将转为训练格式
+            </p>
+
+            {/* KB list */}
+            {kbLoading && (
+              <div style={{ textAlign: 'center', padding: 20, color: '#94a3b8' }}>
+                <Loader2 size={24} style={{ animation: 'lucide-spin 1s linear infinite' }} />
+                <div style={{ marginTop: 8, fontSize: 13 }}>加载中...</div>
+              </div>
+            )}
+
+            {!kbLoading && kbList.length === 0 && (
+              <div style={{ padding: 20, textAlign: 'center', color: '#94a3b8', fontSize: 14 }}>
+                暂无可用的知识库
+              </div>
+            )}
+
+            {!kbLoading && kbList.length > 0 && (
+              <div style={{ border: '1px solid #e2e8f0', borderRadius: 8, overflow: 'hidden' }}>
+                <table style={{ width: '100%', borderCollapse: 'collapse', fontSize: 13 }}>
+                  <thead>
+                    <tr style={{ background: '#f5f3ff', borderBottom: '2px solid #e2e8f0', textAlign: 'left' }}>
+                      <th style={{ padding: '10px 12px', fontSize: 12, color: '#64748b', fontWeight: 600 }}>名称</th>
+                      <th style={{ padding: '10px 12px', fontSize: 12, color: '#64748b', fontWeight: 600 }}>文档数</th>
+                      <th style={{ padding: '10px 12px', fontSize: 12, color: '#64748b', fontWeight: 600 }}>状态</th>
+                      <th style={{ padding: '10px 12px', fontSize: 12, color: '#64748b', fontWeight: 600 }}>字段</th>
+                      <th style={{ padding: '10px 12px', fontSize: 12, color: '#64748b', fontWeight: 600 }}>操作</th>
+                    </tr>
+                  </thead>
+                  <tbody>
+                    {kbList.map(kb => (
+                      <tr key={kb.id} style={{ borderBottom: '1px solid #f1f5f9' }}>
+                        <td style={{ padding: '10px 12px', fontWeight: 500 }}>{kb.name}</td>
+                        <td style={{ padding: '10px 12px', fontSize: 13 }}>{kb.document_count}</td>
+                        <td style={{ padding: '10px 12px', fontSize: 13 }}>
+                          <span style={{
+                            display: 'inline-block', padding: '2px 8px', borderRadius: 4, fontSize: 12,
+                            background: kb.status === 1 ? '#dcfce7' : '#f1f5f9',
+                            color: kb.status === 1 ? '#16a34a' : '#64748b',
+                          }}>
+                            {kb.status === 1 ? '启用' : '禁用'}
+                          </span>
+                        </td>
+                        <td style={{ padding: '10px 12px', fontSize: 12, color: '#64748b', maxWidth: 200 }}>
+                          {kb.metadata_schema.slice(0, 3).map(f => f.field_name_cn).join('、')}
+                          {kb.metadata_schema.length > 3 ? '...' : ''}
+                        </td>
+                        <td style={{ padding: '10px 12px' }}>
+                          <button
+                            onClick={() => handleImportFromKB(kb)}
+                            disabled={kbImporting === kb.id}
+                            style={{
+                              padding: '4px 12px', color: '#8b5cf6', border: '1px solid #8b5cf6',
+                              borderRadius: 6, background: kbImporting === kb.id ? '#f5f3ff' : 'transparent',
+                              cursor: kbImporting === kb.id ? 'not-allowed' : 'pointer',
+                              fontSize: 12, fontWeight: 500, opacity: kbImporting === kb.id ? 0.7 : 1,
+                            }}
+                          >
+                            {kbImporting === kb.id ? (
+                              <><Loader2 size={12} style={{ animation: 'lucide-spin 1s linear infinite', display: 'inline', verticalAlign: 'middle', marginRight: 4 }} />导入中</>
+                            ) : '导入'}
+                          </button>
+                        </td>
+                      </tr>
+                    ))}
+                  </tbody>
+                </table>
+              </div>
+            )}
+
+            {/* Pagination */}
+            {!kbLoading && kbTotal > 20 && (
+              <div style={{ display: 'flex', justifyContent: 'center', gap: 8, marginTop: 16, alignItems: 'center' }}>
+                <button
+                  disabled={kbPage <= 1}
+                  onClick={() => fetchKnowledgeBases(kbPage - 1)}
+                  style={{
+                    padding: '4px 12px', borderRadius: 6, border: '1px solid #cbd5e1',
+                    background: '#fff', cursor: kbPage <= 1 ? 'not-allowed' : 'pointer',
+                    opacity: kbPage <= 1 ? 0.5 : 1, fontSize: 13,
+                  }}
+                >上一页</button>
+                <span style={{ fontSize: 13, color: '#64748b' }}>
+                  第 {kbPage} 页 / 共 {Math.ceil(kbTotal / 20)} 页
+                </span>
+                <button
+                  disabled={kbPage * 20 >= kbTotal}
+                  onClick={() => fetchKnowledgeBases(kbPage + 1)}
+                  style={{
+                    padding: '4px 12px', borderRadius: 6, border: '1px solid #cbd5e1',
+                    background: '#fff', cursor: kbPage * 20 >= kbTotal ? 'not-allowed' : 'pointer',
+                    opacity: kbPage * 20 >= kbTotal ? 0.5 : 1, fontSize: 13,
+                  }}
+                >下一页</button>
+              </div>
+            )}
+
+            {/* Status */}
+            {kbStatus && (
+              <p style={{
+                marginTop: 12, padding: '8px 12px', borderRadius: 6, fontSize: 13,
+                background: kbStatus.includes('失败') ? '#fff1f2' : '#f0fdf4',
+                color: kbStatus.includes('失败') ? '#e11d48' : '#16a34a',
+                border: `1px solid ${kbStatus.includes('失败') ? '#fecdd3' : '#bbf7d0'}`,
+              }}>{kbStatus}</p>
+            )}
+          </div>
+        </div>
+      )}
+
       {/* Dataset list */}
       <div style={{ marginTop: 24 }}>
         <div style={{ display: 'flex', justifyContent: 'space-between', alignItems: 'center', marginBottom: 12 }}>

+ 4 - 3
frontend/src/pages/Training.tsx

@@ -43,10 +43,11 @@ const EPOCH_PRESETS = [
 const BATCH_SIZE_PRESETS = [
   { value: 1, label: '1 (显存受限)' },
   { value: 2, label: '2' },
-  { value: 4, label: '4 (推荐)' },
+  { value: 4, label: '4' },
   { value: 8, label: '8' },
-  { value: 16, label: '16' },
+  { value: 16, label: '16 (推荐)' },
   { value: 32, label: '32' },
+  { value: 64, label: '64' },
 ]
 
 const LR_PRESETS = [
@@ -316,7 +317,7 @@ export function Training() {
   const [taskType, setTaskType] = useState('sft')
   const [template, setTemplate] = useState('auto')
   const [epochs, setEpochs] = useState(3)
-  const [batchSize, setBatchSize] = useState(4)
+  const [batchSize, setBatchSize] = useState(16)
   const [lr, setLr] = useState('2e-4')
   const [loraR, setLoraR] = useState(16)
   const [seqLen, setSeqLen] = useState(2048)

+ 21 - 51
result.txt

@@ -1,51 +1,21 @@
-(base) [root@localhost ~]# docker exec finetune-trainer tail -n 50 /tmp/train_a26f4344-2575-48ca-b62a-8eca3c28fb05.log
-    self._run_epoch(
-  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1737, in _run_epoch
-    tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
-  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1909, in training_step
-    loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
-  File "/opt/conda/lib/python3.10/site-packages/transformers/trainer.py", line 1981, in compute_loss
-    outputs = model(**inputs)
-  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
-    return self._call_impl(*args, **kwargs)
-  File "/opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1784, in _call_impl
-    return forward_call(*args, **kwargs)
-  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 195, in forward
-    return self.gather(outputs, self.output_device)
-  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/data_parallel.py", line 218, in gather
-    return gather(outputs, output_device, dim=self.dim)
-  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py", line 134, in gather
-    res = gather_map(outputs)
-  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py", line 126, in gather_map
-    return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
-  File "<string>", line 8, in __init__
-  File "/opt/conda/lib/python3.10/site-packages/transformers/utils/generic.py", line 451, in __post_init__
-    for idx, element in enumerate(iterator):
-  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py", line 126, in <genexpr>
-    return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
-  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/scatter_gather.py", line 120, in gather_map
-    return Gather.apply(target_device, dim, *outputs)
-  File "/opt/conda/lib/python3.10/site-packages/torch/autograd/function.py", line 576, in apply
-    return super().apply(*args, **kwargs)  # type: ignore[misc]
-  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/_functions.py", line 80, in forward
-    return comm.gather(inputs, ctx.dim, ctx.target_device)
-  File "/opt/conda/lib/python3.10/site-packages/torch/nn/parallel/comm.py", line 253, in gather
-    return torch._C._gather(tensors, dim, destination)
-torch.OutOfMemoryError: CUDA out of memory. Tried to allocate 1.49 GiB. GPU 0 has a total capacity of 63.78 GiB of which 500.44 MiB is free. Of the allocated memory 2.26 GiB is allocated by PyTorch, and 46.79 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
-{'loss': '2.906', 'grad_norm': '1.203', 'learning_rate': '2.799e-06', 'epoch': '0.002334'}
-{'loss': '2.906', 'grad_norm': '1.112', 'learning_rate': '5.91e-06', 'epoch': '0.004669'}
-{'loss': '2.889', 'grad_norm': '1.232', 'learning_rate': '9.02e-06', 'epoch': '0.007003'}
-{'loss': '2.821', 'grad_norm': '1.221', 'learning_rate': '1.213e-05', 'epoch': '0.009338'}
-{'loss': '2.826', 'grad_norm': '1.137', 'learning_rate': '1.524e-05', 'epoch': '0.01167'}
-{'loss': '2.758', 'grad_norm': '1.051', 'learning_rate': '1.835e-05', 'epoch': '0.01401'}
-{'loss': '2.674', 'grad_norm': '1.227', 'learning_rate': '2.146e-05', 'epoch': '0.01634'}
-{'loss': '2.635', 'grad_norm': '1.046', 'learning_rate': '2.457e-05', 'epoch': '0.01868'}
-{'loss': '2.595', 'grad_norm': '1.148', 'learning_rate': '2.768e-05', 'epoch': '0.02101'}
-{'loss': '2.541', 'grad_norm': '1.157', 'learning_rate': '3.079e-05', 'epoch': '0.02334'}
-{'loss': '2.509', 'grad_norm': '1.101', 'learning_rate': '3.39e-05', 'epoch': '0.02568'}
-{'loss': '2.523', 'grad_norm': '1.242', 'learning_rate': '3.701e-05', 'epoch': '0.02801'}
-{'loss': '2.475', 'grad_norm': '1.377', 'learning_rate': '4.012e-05', 'epoch': '0.03035'}
-{'loss': '2.461', 'grad_norm': '1.403', 'learning_rate': '4.323e-05', 'epoch': '0.03268'}
-{'loss': '2.419', 'grad_norm': '1.246', 'learning_rate': '4.635e-05', 'epoch': '0.03502'}
-{'loss': '2.427', 'grad_norm': '1.41', 'learning_rate': '4.946e-05', 'epoch': '0.03735'}
-  1%|▏         | 166/12852 [18:46<23:54:59,  6.79s/it]
+(base) [root@localhost ~]# docker exec finetune-trainer bash -c 'tail -n 20 $(ls -t /tmp/train_*.log | head -1)'
+[remote_train]   Preprocessing done, output: /root/Fine-tuning/backend/data/processed/fb18c7a8-e275-4014-b6a3-dea08f3f7adb_processed.jsonl
+[remote_train] Step 2: Loading model: Qwen/Qwen1.5-0.5B...
+[remote_train]   Quantization: None
+Loading weights: 100%|██████████| 291/291 [00:04<00:00, 59.89it/s] 
+[remote_train]   Model loaded successfully
+[remote_train] Step 3: Building PEFT config...
+[remote_train]   PEFT config built
+[remote_train] Step 4: Starting training...
+Map: 100%|██████████| 274147/274147 [00:15<00:00, 18259.13 examples/s]
+/opt/conda/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:1348: UserWarning: Model has `tie_word_embeddings=True` and a tied layer is part of the adapter, but `ensure_weight_tying` is not set to True. This can lead to complications, for example when merging the adapter or converting your model to formats other than safetensors. Check the discussion here: https://github.com/huggingface/peft/issues/2777
+  warnings.warn(msg)
+[transformers] warmup_ratio is deprecated and will be removed in v5.2. Use `warmup_steps` instead.
+/opt/conda/lib/python3.10/site-packages/torchvision/datapoints/__init__.py:12: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().
+  warnings.warn(_BETA_TRANSFORMS_WARNING)
+/opt/conda/lib/python3.10/site-packages/torchvision/transforms/v2/__init__.py:54: UserWarning: The torchvision.datapoints and torchvision.transforms.v2 namespaces are still Beta. While we do not expect major breaking changes, some APIs may still change according to user feedback. Please submit any feedback you may have in this issue: https://github.com/pytorch/vision/issues/6753, and you can also check out https://github.com/pytorch/vision/issues/7319 to learn more about the APIs that we suspect might involve future changes. You can silence this warning by calling torchvision.disable_beta_transforms_warning().
+  warnings.warn(_BETA_TRANSFORMS_WARNING)
+trainable params: 5,593,088 || all params: 469,580,800 || trainable%: 1.1911
+  0%|          | 0/4284 [00:00<?, ?it/s]/opt/conda/lib/python3.10/site-packages/transformers/models/qwen2/modeling_qwen2.py:108: UserWarning: Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context... (Triggered internally at /workspace/framework/mcPytorch/aten/src/ATen/cuda/CublasHandlePool.cpp:183.)
+  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+ 19%|█▉        | 812/4284 [21:54<1:34:25,  1.63s/it](base) [root@localhost ~]#