Kaynağa Gözat

修复语法报错

lxylxy123321 2 gün önce
ebeveyn
işleme
28ee2591b2
1 değiştirilmiş dosya ile 77 ekleme ve 27 silme
  1. 77 27
      backend/app/engines/remote_train.py

+ 77 - 27
backend/app/engines/remote_train.py

@@ -218,21 +218,19 @@ def _patch_fla_shared_memory():
     反向传播时 chunk kernel 的 block size 为 64,需要约 106KB 共享内存,
     但沐曦/部分 NVIDIA GPU 硬件上限仅 64KB(65536 字节),导致 OutOfResources。
 
-    修复方式:在 fla 模块首次导入前,全面降低所有 block size 相关的值:
-    1. blockdim64 → blockdim32(kernel 函数名后缀)
-    2. 所有 = 64 的赋值/参数 → = 32(覆盖 BT/BK/BV/chunk_size 等变量名)
-    3. tl.constexpr 值为 128/256 的也降为 64
+    修复方式:精确替换 fla 库中控制 block size 的常量和 kernel 名称,
+    避免误改普通代码中的数字字面量。
     """
     try:
         import shutil
-        import site
 
-        fla_base = None
-        # 优先检查 conda 环境路径
+        # 定位 fla 包
         conda_path = '/opt/conda/lib/python3.10/site-packages/fla'
         if os.path.isdir(conda_path):
             fla_base = conda_path
         else:
+            import site
+            fla_base = None
             for sp in site.getsitepackages() + [site.getusersitepackages() if hasattr(site, 'getusersitepackages') else '']:
                 candidate = os.path.join(sp, 'fla')
                 if os.path.isdir(candidate):
@@ -245,11 +243,51 @@ def _patch_fla_shared_memory():
 
         _remote_log(f"fla package found at: {fla_base}")
 
+        # 检查 fla 源码是否被旧版补丁损坏(语法错误)
+        chunk_py = os.path.join(fla_base, 'ops', 'gated_delta_rule', 'chunk.py')
+        source_corrupted = False
+        if os.path.exists(chunk_py):
+            try:
+                with open(chunk_py, 'r') as f:
+                    compile(f.read(), chunk_py, 'exec')
+            except SyntaxError as e:
+                source_corrupted = True
+                _remote_log(f"fla source corrupted (SyntaxError: {e}), will reinstall...")
+
         # 幂等检查
         marker_path = os.path.join(fla_base, '_PATCHED_SM32')
-        if os.path.exists(marker_path):
-            _remote_log("fla shared memory patch already applied (marker found), skipping")
-            return
+        if os.path.exists(marker_path) and not source_corrupted:
+            # 检查标记版本:v1 是旧版补丁(用激进正则,已污染源码),需要重装后重新打补丁
+            with open(marker_path) as mf:
+                marker_content = mf.read()
+            if 'v2' in marker_content:
+                _remote_log("fla shared memory patch v2 already applied, skipping")
+                to_remove = [k for k in sys.modules if k.startswith('fla')]
+                for k in to_remove:
+                    del sys.modules[k]
+                return
+            else:
+                source_corrupted = True
+                _remote_log("Old patch v1 detected, will reinstall fla...")
+
+        if source_corrupted:
+            _remote_log("Reinstalling fla to restore clean source...")
+            import subprocess
+            # 尝试多个可能的包名
+            for pkg_name in ['fla', 'flash-linear-attention']:
+                result = subprocess.run(
+                    [sys.executable, '-m', 'pip', 'install', '--force-reinstall', '--no-deps', pkg_name],
+                    capture_output=True, text=True, timeout=120,
+                )
+                if result.returncode == 0:
+                    _remote_log(f"fla reinstalled successfully via '{pkg_name}'")
+                    break
+                else:
+                    _remote_log(f"pip install '{pkg_name}' failed: {result.stderr[:200]}")
+            # 清理旧标记
+            if os.path.exists(marker_path):
+                os.remove(marker_path)
+            _remote_log("Reapplying patch v2...")
 
         patched_files = []
 
@@ -264,26 +302,38 @@ def _patch_fla_shared_memory():
                     c = original
                     changes = []
 
-                    # 1. blockdim64 → blockdim32(kernel 函数名后缀)
+                    # 1. kernel 函数名后缀: blockdim64 → blockdim32
                     if 'blockdim64' in c:
                         c = c.replace('blockdim64', 'blockdim32')
                         changes.append('blockdim64->blockdim32')
 
-                    # 2. = 64 赋值/参数 → = 32(覆盖 BT=64, BK=64, BV=64, chunk_size=64 等)
-                    def _r64(m):
-                        return f'{m.group(1)}= 32'
-                    new_c = re.sub(r'([=:])\s*64\b(?!\d)', _r64, c)
+                    # 2. 精确匹配 fla 中常见的 block size 变量赋值
+                    #    BT = 64, BK = 64, BV = 64, chunk_size = 64, BLOCK_SIZE = 64 等
+                    #    用 \b 匹配完整变量名,避免误改其他代码
+                    for var in ['BT', 'BK', 'BV', 'chunk_size', 'BLOCK_SIZE',
+                                'BLOCK_M', 'BLOCK_N', 'BLOCK_K', 'BLOCK_V',
+                                'block_size', 'block_m', 'block_n', 'block_k', 'block_v']:
+                        pattern = rf'\b{var}\s*=\s*64\b'
+                        replacement = f'{var} = 32'
+                        new_c = re.sub(pattern, replacement, c)
+                        if new_c != c:
+                            changes.append(f'{var}=64->32')
+                            c = new_c
+
+                    # 3. Triton autotune 装饰器中的 configs 参数值
+                    #    例如: configs=[..., 64, ...] 或 tl.constexpr = 64
+                    #    只替换 tl.constexpr = 64 的情况
+                    pattern = r'tl\.constexpr\s*=\s*64\b'
+                    new_c = re.sub(pattern, 'tl.constexpr = 32', c)
                     if new_c != c:
-                        changes.append('=64 -> =32')
+                        changes.append('tl.constexpr 64->32')
                         c = new_c
 
-                    # 3. tl.constexpr = 128/256 → = 64(进一步降低大值)
-                    def _r_large(m):
-                        val = int(m.group(1))
-                        return f'tl.constexpr = {val // 2}'
-                    new_c = re.sub(r'tl\.constexpr\s*=\s*(128|256)\b', _r_large, c)
+                    # 4. num_stages 降低(减少流水线阶段,进一步降低共享内存)
+                    pattern = r'num_stages\s*=\s*([3-9]|[1-9]\d+)'
+                    new_c = re.sub(pattern, 'num_stages=1', c)
                     if new_c != c:
-                        changes.append('constexpr 128/256 halved')
+                        changes.append('num_stages->1')
                         c = new_c
 
                     if c != original:
@@ -294,23 +344,23 @@ def _patch_fla_shared_memory():
                     _remote_log(f"  Warning: failed to patch {fpath}: {e}")
                     continue
 
-        # 清理 __pycache__,确保下次 import 读新源码
+        # 清理 __pycache__
         cache_count = 0
         for root, dirs, files in os.walk(fla_base):
             if '__pycache__' in dirs:
                 shutil.rmtree(os.path.join(root, '__pycache__'), ignore_errors=True)
                 cache_count += 1
 
-        # 清除已缓存的 fla 模块,强制重新导入
+        # 清除已缓存的 fla 模块
         to_remove = [k for k in sys.modules if k.startswith('fla')]
         for k in to_remove:
             del sys.modules[k]
 
-        # 写入标记文件,下次运行时跳过(幂等)
+        # 写入标记文件(幂等),包含版本号 v2
         with open(marker_path, 'w') as f:
-            f.write(f"patched at {datetime.now(timezone.utc).isoformat()}\n")
+            f.write(f"v2 patched at {datetime.now(timezone.utc).isoformat()}\n")
 
-        _remote_log(f"fla shared memory patch done: {len(patched_files)} files patched, "
+        _remote_log(f"fla shared memory patch done: {len(patched_files)} files, "
                     f"{cache_count} caches cleared, {len(to_remove)} modules evicted")
         for pf in patched_files:
             _remote_log(f"  patched: {pf}")