Просмотр исходного кода

修复await字段缺失,修复adapter问题

lxylxy123321 2 дней назад
Родитель
Сommit
a267ec9e0c

+ 1 - 1
backend/app/services/dataset_service.py

@@ -91,7 +91,7 @@ async def download_dataset(req: DatasetDownloadRequest) -> DatasetDownloadRespon
 
     # 注册并启动
     background_task_manager.register_task(task_id, "dataset_download", {"dataset_id": req.dataset_id})
-    background_task_manager.run(
+    await background_task_manager.run(
         task_id, "dataset_download", _execute_dataset_download(task_id, req)
     )
 

+ 1 - 1
backend/app/services/deploy_service.py

@@ -32,7 +32,7 @@ async def export_adapter(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
 
     # 注册并启动
     background_task_manager.register_task(task_id, "deployment", {"job_id": job_id})
-    background_task_manager.run(
+    await background_task_manager.run(
         task_id, "deployment", _execute_export(task_id, job_id, merge_with_base, export_format)
     )
 

+ 1 - 1
backend/app/services/eval_service.py

@@ -30,7 +30,7 @@ async def run_evaluation(job_id: str, config: dict[str, Any]) -> dict[str, Any]:
 
     # 注册并启动
     background_task_manager.register_task(eval_id, "evaluation", {"job_id": job_id})
-    background_task_manager.run(
+    await background_task_manager.run(
         eval_id, "evaluation", _execute_evaluation(eval_id, job_id, config)
     )
 

+ 33 - 13
backend/app/services/inference_service.py

@@ -131,25 +131,45 @@ def _get_base_model_id(adapter_path: str) -> str | None:
 
 
 async def get_available_adapters() -> list[dict[str, Any]]:
-    """列出所有已训练的 adapter。"""
+    """列出所有已训练完成的 adapter(仅显示 status=completed 的任务)。"""
+    from app.core.db import async_session, TrainingJobModel
+    from sqlalchemy import select
+
+    # 查询数据库中训练完成的任务
+    async with async_session() as session:
+        result = await session.execute(
+            select(TrainingJobModel).where(TrainingJobModel.status == "completed")
+        )
+        completed_jobs = {job.id: job for job in result.scalars().all()}
+
+    if not completed_jobs:
+        return []
+
     adapters_dir = settings.adapters_dir
     if not adapters_dir.exists():
         return []
 
     result = []
-    for d in sorted(adapters_dir.iterdir()):
-        if not d.is_dir():
+    for job_id, job in sorted(completed_jobs.items(), key=lambda x: x[1].created_at, reverse=True):
+        adapter_dir = adapters_dir / job_id
+        if not adapter_dir.is_dir():
             continue
-        adapter_config = d / "adapter_config.json"
-        if adapter_config.exists():
-            with open(adapter_config) as f:
-                cfg = json.load(f)
-            result.append({
-                "id": d.name,
-                "path": str(d),
-                "base_model": cfg.get("base_model_name_or_path", "unknown"),
-                "peft_type": cfg.get("peft_type", "unknown"),
-            })
+        adapter_config = adapter_dir / "adapter_config.json"
+        if not adapter_config.exists():
+            continue
+
+        with open(adapter_config) as f:
+            cfg = json.load(f)
+        result.append({
+            "id": job_id,
+            "path": str(adapter_dir),
+            "base_model": cfg.get("base_model_name_or_path", "unknown"),
+            "peft_type": cfg.get("peft_type", "unknown"),
+            "model_id": job.model_id,
+            "peft_method": job.peft_method,
+            "task_type": job.task_type,
+            "created_at": job.created_at.isoformat() if job.created_at else None,
+        })
     return result
 
 

+ 1 - 1
backend/app/services/model_service.py

@@ -68,7 +68,7 @@ async def download_model(model_id: str, use_modelscope: bool = False) -> dict[st
 
     # 注册并启动
     background_task_manager.register_task(task_id, "model_download", {"model_id": model_id})
-    background_task_manager.run(
+    await background_task_manager.run(
         task_id, "model_download", _execute_model_download(task_id, model_id, use_modelscope)
     )
 

+ 4 - 0
frontend/src/api/client.ts

@@ -357,6 +357,10 @@ interface AdapterInfo {
   path: string
   base_model: string
   peft_type: string
+  model_id?: string
+  peft_method?: string
+  task_type?: string
+  created_at?: string
 }
 
 interface InferenceRequest {

+ 4 - 2
frontend/src/pages/Inference.tsx

@@ -130,7 +130,7 @@ export function Inference() {
             background: '#f0fdfa', borderRadius: 8, border: '1px solid #ccfbf1',
           }}>
             <div style={{ marginBottom: 6 }}><MessageSquare size={24} color="#94a3b8" strokeWidth={1.5} /></div>
-            暂无可用的 adapter,请先完成训练任务
+            暂无可用的 adapter,请先完成训练任务并确保状态为「已完成」
           </div>
         ) : (
           <select
@@ -145,7 +145,9 @@ export function Inference() {
             onBlur={e => { e.currentTarget.style.borderColor = '#cbd5e1' }}
           >
             {adapters.map(a => (
-              <option key={a.id} value={a.id}>{a.id} — {a.base_model} ({a.peft_type})</option>
+              <option key={a.id} value={a.id}>
+                {a.model_id || a.base_model} | {(a.peft_method || a.peft_type).toUpperCase()} | {a.id.slice(0, 8)}...
+              </option>
             ))}
           </select>
         )}