lxylxy123321 2 дня назад
Родитель
Коммит
ec86921fdc
4 изменённых файла: 166 добавлено и 50 удалено
  1. 1 1
      backend/.env
  2. 1 1
      backend/.env.docker
  3. 96 42
      frontend/src/api/websocket.ts
  4. 68 6
      frontend/src/pages/Training.tsx

+ 1 - 1
backend/.env

@@ -7,7 +7,7 @@ BACKEND_HOST=0.0.0.0
 BACKEND_PORT=8010
 BACKEND_PORT=8010
 BACKEND_ENV=production
 BACKEND_ENV=production
 BACKEND_LOG_LEVEL=INFO
 BACKEND_LOG_LEVEL=INFO
-BACKEND_CORS_ORIGINS=http://192.168.91.253:5173
+BACKEND_CORS_ORIGINS=http://192.168.91.253:5173,http://192.168.92.151:3000,http://192.168.92.151
 
 
 # 数据库
 # 数据库
 DATABASE_URL=sqlite+aiosqlite:///root/Fine-tuning/backend/data/finetuning.db
 DATABASE_URL=sqlite+aiosqlite:///root/Fine-tuning/backend/data/finetuning.db

+ 1 - 1
backend/.env.docker

@@ -3,7 +3,7 @@ BACKEND_HOST=0.0.0.0
 BACKEND_PORT=8010
 BACKEND_PORT=8010
 BACKEND_ENV=production
 BACKEND_ENV=production
 BACKEND_LOG_LEVEL=INFO
 BACKEND_LOG_LEVEL=INFO
-BACKEND_CORS_ORIGINS=http://localhost:3000
+BACKEND_CORS_ORIGINS=http://localhost:3000,http://192.168.92.151:3000,http://192.168.92.151
 
 
 # PostgreSQL 数据库
 # PostgreSQL 数据库
 DATABASE_URL=postgresql+asyncpg://finetune:finetune123@postgres:5432/finetuning
 DATABASE_URL=postgresql+asyncpg://finetune:finetune123@postgres:5432/finetuning

+ 96 - 42
frontend/src/api/websocket.ts

@@ -1,76 +1,130 @@
+type MessageHandler = (msg: Record<string, unknown>) => void
+
+interface JobConnection {
+  ws: WebSocket
+  reconnectTimer: ReturnType<typeof setTimeout> | null
+  intentionalClose: boolean
+}
+
 class WSManager {
 class WSManager {
-  private ws: WebSocket | null = null
-  private handlers: Map<string, Set<(msg: Record<string, unknown>) => void>> = new Map()
-  private reconnectTimer: ReturnType<typeof setTimeout> | null = null
-  private intentionalClose = false
-
-  connect(baseUrl?: string) {
-    if (this.ws) return
-    this.intentionalClose = false
-    const url = baseUrl || (import.meta.env.VITE_WS_BASE_URL as string) || '/ws'
-    let wsUrl = url.startsWith('ws') ? url : `${window.location.protocol === 'https:' ? 'wss://' : 'ws://'}${window.location.host}${url}`
-    const token = localStorage.getItem('token')
-    if (token) {
-      wsUrl += wsUrl.includes('?') ? '&' : '?'
-      wsUrl += `token=${encodeURIComponent(token)}`
+  private connections: Map<string, JobConnection> = new Map()
+  private handlers: Map<string, Set<MessageHandler>> = new Map()
+
+  /**
+   * 为指定 job 建立 WebSocket 连接(/ws/training/{jobId})。
+   * 如果该 job 已有连接则跳过。
+   */
+  connect(jobId: string) {
+    if (this.connections.has(jobId)) return
+
+    const conn: JobConnection = {
+      ws: null as unknown as WebSocket,
+      reconnectTimer: null,
+      intentionalClose: false,
     }
     }
 
 
+    const protocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'
+    const wsUrl = `${protocol}${window.location.host}/ws/training/${encodeURIComponent(jobId)}`
+    const token = localStorage.getItem('token')
+    const finalUrl = token ? `${wsUrl}?token=${encodeURIComponent(token)}` : wsUrl
+
     try {
     try {
-      this.ws = new WebSocket(wsUrl)
+      conn.ws = new WebSocket(finalUrl)
     } catch {
     } catch {
-      this.scheduleReconnect()
+      this.scheduleReconnect(jobId)
       return
       return
     }
     }
+    this.connections.set(jobId, conn)
 
 
-    this.ws.onopen = () => {
-      console.log('[WS] Connected')
+    conn.ws.onopen = () => {
+      console.log(`[WS] Connected for job ${jobId.slice(0, 8)}...`)
     }
     }
 
 
-    this.ws.onmessage = (event) => {
+    conn.ws.onmessage = (event) => {
       try {
       try {
         const msg = JSON.parse(event.data) as Record<string, unknown>
         const msg = JSON.parse(event.data) as Record<string, unknown>
-        this.handlers.get(msg.job_id as string)?.forEach(h => h(msg))
+        const msgJobId = (msg.job_id as string) || jobId
+        this.handlers.get(msgJobId)?.forEach(h => h(msg))
         this.handlers.get('*')?.forEach(h => h(msg))
         this.handlers.get('*')?.forEach(h => h(msg))
       } catch {
       } catch {
         // ignore non-JSON messages
         // ignore non-JSON messages
       }
       }
     }
     }
 
 
-    this.ws.onclose = () => {
-      this.ws = null
-      if (!this.intentionalClose) {
-        this.scheduleReconnect()
+    conn.ws.onclose = () => {
+      this.connections.delete(jobId)
+      if (!conn.intentionalClose) {
+        this.scheduleReconnect(jobId)
       }
       }
     }
     }
 
 
-    this.ws.onerror = () => {
-      this.ws?.close()
+    conn.ws.onerror = () => {
+      conn.ws?.close()
     }
     }
   }
   }
 
 
-  private scheduleReconnect() {
-    if (this.reconnectTimer) return
-    this.reconnectTimer = setTimeout(() => {
-      this.reconnectTimer = null
-      this.connect()
-    }, 3000)
+  /** 是否已有该 job 的连接或正在重连 */
+  isConnected(jobId: string): boolean {
+    return this.connections.has(jobId)
   }
   }
 
 
-  subscribe(jobId: string, handler: (msg: Record<string, unknown>) => void): () => void {
+  /**
+   * 断开指定 job 的 WebSocket 连接。
+   */
+  disconnect(jobId: string) {
+    const conn = this.connections.get(jobId)
+    if (!conn) return
+    conn.intentionalClose = true
+    if (conn.reconnectTimer) {
+      clearTimeout(conn.reconnectTimer)
+      conn.reconnectTimer = null
+    }
+    if (conn.ws) {
+      conn.ws.close()
+    }
+    this.connections.delete(jobId)
+  }
+
+  /**
+   * 断开所有 WebSocket 连接,清理所有 handler。
+   */
+  disconnectAll() {
+    for (const jobId of [...this.connections.keys()]) {
+      this.disconnect(jobId)
+    }
+    this.handlers.clear()
+  }
+
+  /**
+   * 订阅指定 job 的 WebSocket 消息。
+   * 返回取消订阅的函数。
+   */
+  subscribe(jobId: string, handler: MessageHandler): () => void {
     if (!this.handlers.has(jobId)) this.handlers.set(jobId, new Set())
     if (!this.handlers.has(jobId)) this.handlers.set(jobId, new Set())
     this.handlers.get(jobId)!.add(handler)
     this.handlers.get(jobId)!.add(handler)
     return () => this.handlers.get(jobId)?.delete(handler)
     return () => this.handlers.get(jobId)?.delete(handler)
   }
   }
 
 
-  disconnect() {
-    this.intentionalClose = true
-    if (this.reconnectTimer) {
-      clearTimeout(this.reconnectTimer)
-      this.reconnectTimer = null
-    }
-    if (this.ws) {
-      this.ws.close()
-      this.ws = null
+  private scheduleReconnect(jobId: string) {
+    const conn = this.connections.get(jobId)
+    if (conn?.reconnectTimer) return
+    const timer = setTimeout(() => {
+      if (conn) conn.reconnectTimer = null
+      // 仅在 handler 仍在订阅时才重连(说明页面还关心这个 job)
+      if (this.handlers.get(jobId)?.size) {
+        this.connections.delete(jobId)
+        this.connect(jobId)
+      }
+    }, 3000)
+    if (conn) {
+      conn.reconnectTimer = timer
+    } else {
+      // conn 已删除,用临时对象跟踪 timer
+      this.connections.set(jobId, {
+        ws: null as unknown as WebSocket,
+        reconnectTimer: timer,
+        intentionalClose: false,
+      })
     }
     }
   }
   }
 }
 }

+ 68 - 6
frontend/src/pages/Training.tsx

@@ -348,12 +348,10 @@ export function Training() {
 
 
   useEffect(() => { fetchOptions() }, [fetchOptions])
   useEffect(() => { fetchOptions() }, [fetchOptions])
 
 
-  useEffect(() => {
-    wsManager.connect()
-    return () => wsManager.disconnect()
-  }, [])
-
   const jobsRef = useRef<TrainingJob[]>([])
   const jobsRef = useRef<TrainingJob[]>([])
+  // 跟踪已建立 WS 连接的 job 和对应的取消订阅函数
+  const wsConnectedRef = useRef<Set<string>>(new Set())
+  const wsUnsubsRef = useRef<Map<string, () => void>>(new Map())
 
 
   const fetchJobs = () => {
   const fetchJobs = () => {
     setLoading(true)
     setLoading(true)
@@ -364,20 +362,84 @@ export function Training() {
           setJobs(newJobs)
           setJobs(newJobs)
           jobsRef.current = newJobs
           jobsRef.current = newJobs
         }
         }
+        syncWsConnections(newJobs)
       })
       })
       .catch(() => {
       .catch(() => {
         if (jobsRef.current.length > 0) {
         if (jobsRef.current.length > 0) {
           setJobs([])
           setJobs([])
           jobsRef.current = []
           jobsRef.current = []
         }
         }
+        syncWsConnections([])
       })
       })
       .finally(() => setLoading(false))
       .finally(() => setLoading(false))
   }
   }
 
 
+  // 根据 job 状态同步 WebSocket 连接:training/preprocessing 连接,其他断开
+  const ACTIVE_STATUSES = new Set(['training', 'preprocessing'])
+
+  function syncWsConnections(currentJobs: TrainingJob[]) {
+    const activeIds = new Set(currentJobs.filter(j => ACTIVE_STATUSES.has(j.status)).map(j => j.id))
+
+    // 为新的活跃 job 建立连接并订阅
+    for (const jobId of activeIds) {
+      if (wsConnectedRef.current.has(jobId)) continue
+      wsConnectedRef.current.add(jobId)
+      wsManager.connect(jobId)
+      const unsub = wsManager.subscribe(jobId, (msg) => {
+        if (msg.type === 'progress') {
+          setJobs(prev => prev.map(j =>
+            j.id === jobId
+              ? {
+                  ...j,
+                  current_step: (msg.step as number) ?? j.current_step,
+                  current_epoch: (msg.epoch as number) ?? j.current_epoch,
+                  total_steps: (msg.total_steps as number) || j.total_steps,
+                  loss: (msg.loss as number) ?? j.loss,
+                  progress: j.total_steps
+                    ? (((msg.step as number) ?? j.current_step) / (j.total_steps || 1)) * 100
+                    : j.progress,
+                }
+              : j
+          ))
+        } else if (msg.type === 'completed') {
+          setJobs(prev => prev.map(j =>
+            j.id === jobId ? { ...j, status: 'completed', progress: 100 } : j
+          ))
+          cleanupJobWs(jobId)
+        } else if (msg.type === 'error') {
+          setJobs(prev => prev.map(j =>
+            j.id === jobId
+              ? { ...j, status: 'failed', error_message: (msg.message as string) ?? '训练失败' }
+              : j
+          ))
+          cleanupJobWs(jobId)
+        }
+      })
+      wsUnsubsRef.current.set(jobId, unsub)
+    }
+
+    // 断开已完成/失败的 job 的连接
+    for (const jobId of [...wsConnectedRef.current]) {
+      if (!activeIds.has(jobId)) {
+        cleanupJobWs(jobId)
+      }
+    }
+  }
+
+  function cleanupJobWs(jobId: string) {
+    wsUnsubsRef.current.get(jobId)?.()
+    wsUnsubsRef.current.delete(jobId)
+    wsConnectedRef.current.delete(jobId)
+    wsManager.disconnect(jobId)
+  }
+
   useEffect(() => {
   useEffect(() => {
     fetchJobs()
     fetchJobs()
     const interval = setInterval(fetchJobs, 5000)
     const interval = setInterval(fetchJobs, 5000)
-    return () => clearInterval(interval)
+    return () => {
+      clearInterval(interval)
+      wsManager.disconnectAll()
+    }
   }, [])
   }, [])
 
 
   const handleCreate = () => {
   const handleCreate = () => {