|
@@ -218,21 +218,19 @@ def _patch_fla_shared_memory():
|
|
|
反向传播时 chunk kernel 的 block size 为 64,需要约 106KB 共享内存,
|
|
反向传播时 chunk kernel 的 block size 为 64,需要约 106KB 共享内存,
|
|
|
但沐曦/部分 NVIDIA GPU 硬件上限仅 64KB(65536 字节),导致 OutOfResources。
|
|
但沐曦/部分 NVIDIA GPU 硬件上限仅 64KB(65536 字节),导致 OutOfResources。
|
|
|
|
|
|
|
|
- 修复方式:在 fla 模块首次导入前,全面降低所有 block size 相关的值:
|
|
|
|
|
- 1. blockdim64 → blockdim32(kernel 函数名后缀)
|
|
|
|
|
- 2. 所有 = 64 的赋值/参数 → = 32(覆盖 BT/BK/BV/chunk_size 等变量名)
|
|
|
|
|
- 3. tl.constexpr 值为 128/256 的也降为 64
|
|
|
|
|
|
|
+ 修复方式:精确替换 fla 库中控制 block size 的常量和 kernel 名称,
|
|
|
|
|
+ 避免误改普通代码中的数字字面量。
|
|
|
"""
|
|
"""
|
|
|
try:
|
|
try:
|
|
|
import shutil
|
|
import shutil
|
|
|
- import site
|
|
|
|
|
|
|
|
|
|
- fla_base = None
|
|
|
|
|
- # 优先检查 conda 环境路径
|
|
|
|
|
|
|
+ # 定位 fla 包
|
|
|
conda_path = '/opt/conda/lib/python3.10/site-packages/fla'
|
|
conda_path = '/opt/conda/lib/python3.10/site-packages/fla'
|
|
|
if os.path.isdir(conda_path):
|
|
if os.path.isdir(conda_path):
|
|
|
fla_base = conda_path
|
|
fla_base = conda_path
|
|
|
else:
|
|
else:
|
|
|
|
|
+ import site
|
|
|
|
|
+ fla_base = None
|
|
|
for sp in site.getsitepackages() + [site.getusersitepackages() if hasattr(site, 'getusersitepackages') else '']:
|
|
for sp in site.getsitepackages() + [site.getusersitepackages() if hasattr(site, 'getusersitepackages') else '']:
|
|
|
candidate = os.path.join(sp, 'fla')
|
|
candidate = os.path.join(sp, 'fla')
|
|
|
if os.path.isdir(candidate):
|
|
if os.path.isdir(candidate):
|
|
@@ -245,11 +243,51 @@ def _patch_fla_shared_memory():
|
|
|
|
|
|
|
|
_remote_log(f"fla package found at: {fla_base}")
|
|
_remote_log(f"fla package found at: {fla_base}")
|
|
|
|
|
|
|
|
|
|
+ # 检查 fla 源码是否被旧版补丁损坏(语法错误)
|
|
|
|
|
+ chunk_py = os.path.join(fla_base, 'ops', 'gated_delta_rule', 'chunk.py')
|
|
|
|
|
+ source_corrupted = False
|
|
|
|
|
+ if os.path.exists(chunk_py):
|
|
|
|
|
+ try:
|
|
|
|
|
+ with open(chunk_py, 'r') as f:
|
|
|
|
|
+ compile(f.read(), chunk_py, 'exec')
|
|
|
|
|
+ except SyntaxError as e:
|
|
|
|
|
+ source_corrupted = True
|
|
|
|
|
+ _remote_log(f"fla source corrupted (SyntaxError: {e}), will reinstall...")
|
|
|
|
|
+
|
|
|
# 幂等检查
|
|
# 幂等检查
|
|
|
marker_path = os.path.join(fla_base, '_PATCHED_SM32')
|
|
marker_path = os.path.join(fla_base, '_PATCHED_SM32')
|
|
|
- if os.path.exists(marker_path):
|
|
|
|
|
- _remote_log("fla shared memory patch already applied (marker found), skipping")
|
|
|
|
|
- return
|
|
|
|
|
|
|
+ if os.path.exists(marker_path) and not source_corrupted:
|
|
|
|
|
+ # 检查标记版本:v1 是旧版补丁(用激进正则,已污染源码),需要重装后重新打补丁
|
|
|
|
|
+ with open(marker_path) as mf:
|
|
|
|
|
+ marker_content = mf.read()
|
|
|
|
|
+ if 'v2' in marker_content:
|
|
|
|
|
+ _remote_log("fla shared memory patch v2 already applied, skipping")
|
|
|
|
|
+ to_remove = [k for k in sys.modules if k.startswith('fla')]
|
|
|
|
|
+ for k in to_remove:
|
|
|
|
|
+ del sys.modules[k]
|
|
|
|
|
+ return
|
|
|
|
|
+ else:
|
|
|
|
|
+ source_corrupted = True
|
|
|
|
|
+ _remote_log("Old patch v1 detected, will reinstall fla...")
|
|
|
|
|
+
|
|
|
|
|
+ if source_corrupted:
|
|
|
|
|
+ _remote_log("Reinstalling fla to restore clean source...")
|
|
|
|
|
+ import subprocess
|
|
|
|
|
+ # 尝试多个可能的包名
|
|
|
|
|
+ for pkg_name in ['fla', 'flash-linear-attention']:
|
|
|
|
|
+ result = subprocess.run(
|
|
|
|
|
+ [sys.executable, '-m', 'pip', 'install', '--force-reinstall', '--no-deps', pkg_name],
|
|
|
|
|
+ capture_output=True, text=True, timeout=120,
|
|
|
|
|
+ )
|
|
|
|
|
+ if result.returncode == 0:
|
|
|
|
|
+ _remote_log(f"fla reinstalled successfully via '{pkg_name}'")
|
|
|
|
|
+ break
|
|
|
|
|
+ else:
|
|
|
|
|
+ _remote_log(f"pip install '{pkg_name}' failed: {result.stderr[:200]}")
|
|
|
|
|
+ # 清理旧标记
|
|
|
|
|
+ if os.path.exists(marker_path):
|
|
|
|
|
+ os.remove(marker_path)
|
|
|
|
|
+ _remote_log("Reapplying patch v2...")
|
|
|
|
|
|
|
|
patched_files = []
|
|
patched_files = []
|
|
|
|
|
|
|
@@ -264,26 +302,38 @@ def _patch_fla_shared_memory():
|
|
|
c = original
|
|
c = original
|
|
|
changes = []
|
|
changes = []
|
|
|
|
|
|
|
|
- # 1. blockdim64 → blockdim32(kernel 函数名后缀)
|
|
|
|
|
|
|
+ # 1. kernel 函数名后缀: blockdim64 → blockdim32
|
|
|
if 'blockdim64' in c:
|
|
if 'blockdim64' in c:
|
|
|
c = c.replace('blockdim64', 'blockdim32')
|
|
c = c.replace('blockdim64', 'blockdim32')
|
|
|
changes.append('blockdim64->blockdim32')
|
|
changes.append('blockdim64->blockdim32')
|
|
|
|
|
|
|
|
- # 2. = 64 赋值/参数 → = 32(覆盖 BT=64, BK=64, BV=64, chunk_size=64 等)
|
|
|
|
|
- def _r64(m):
|
|
|
|
|
- return f'{m.group(1)}= 32'
|
|
|
|
|
- new_c = re.sub(r'([=:])\s*64\b(?!\d)', _r64, c)
|
|
|
|
|
|
|
+ # 2. 精确匹配 fla 中常见的 block size 变量赋值
|
|
|
|
|
+ # BT = 64, BK = 64, BV = 64, chunk_size = 64, BLOCK_SIZE = 64 等
|
|
|
|
|
+ # 用 \b 匹配完整变量名,避免误改其他代码
|
|
|
|
|
+ for var in ['BT', 'BK', 'BV', 'chunk_size', 'BLOCK_SIZE',
|
|
|
|
|
+ 'BLOCK_M', 'BLOCK_N', 'BLOCK_K', 'BLOCK_V',
|
|
|
|
|
+ 'block_size', 'block_m', 'block_n', 'block_k', 'block_v']:
|
|
|
|
|
+ pattern = rf'\b{var}\s*=\s*64\b'
|
|
|
|
|
+ replacement = f'{var} = 32'
|
|
|
|
|
+ new_c = re.sub(pattern, replacement, c)
|
|
|
|
|
+ if new_c != c:
|
|
|
|
|
+ changes.append(f'{var}=64->32')
|
|
|
|
|
+ c = new_c
|
|
|
|
|
+
|
|
|
|
|
+ # 3. Triton autotune 装饰器中的 configs 参数值
|
|
|
|
|
+ # 例如: configs=[..., 64, ...] 或 tl.constexpr = 64
|
|
|
|
|
+ # 只替换 tl.constexpr = 64 的情况
|
|
|
|
|
+ pattern = r'tl\.constexpr\s*=\s*64\b'
|
|
|
|
|
+ new_c = re.sub(pattern, 'tl.constexpr = 32', c)
|
|
|
if new_c != c:
|
|
if new_c != c:
|
|
|
- changes.append('=64 -> =32')
|
|
|
|
|
|
|
+ changes.append('tl.constexpr 64->32')
|
|
|
c = new_c
|
|
c = new_c
|
|
|
|
|
|
|
|
- # 3. tl.constexpr = 128/256 → = 64(进一步降低大值)
|
|
|
|
|
- def _r_large(m):
|
|
|
|
|
- val = int(m.group(1))
|
|
|
|
|
- return f'tl.constexpr = {val // 2}'
|
|
|
|
|
- new_c = re.sub(r'tl\.constexpr\s*=\s*(128|256)\b', _r_large, c)
|
|
|
|
|
|
|
+ # 4. num_stages 降低(减少流水线阶段,进一步降低共享内存)
|
|
|
|
|
+ pattern = r'num_stages\s*=\s*([3-9]|[1-9]\d+)'
|
|
|
|
|
+ new_c = re.sub(pattern, 'num_stages=1', c)
|
|
|
if new_c != c:
|
|
if new_c != c:
|
|
|
- changes.append('constexpr 128/256 halved')
|
|
|
|
|
|
|
+ changes.append('num_stages->1')
|
|
|
c = new_c
|
|
c = new_c
|
|
|
|
|
|
|
|
if c != original:
|
|
if c != original:
|
|
@@ -294,23 +344,23 @@ def _patch_fla_shared_memory():
|
|
|
_remote_log(f" Warning: failed to patch {fpath}: {e}")
|
|
_remote_log(f" Warning: failed to patch {fpath}: {e}")
|
|
|
continue
|
|
continue
|
|
|
|
|
|
|
|
- # 清理 __pycache__,确保下次 import 读新源码
|
|
|
|
|
|
|
+ # 清理 __pycache__
|
|
|
cache_count = 0
|
|
cache_count = 0
|
|
|
for root, dirs, files in os.walk(fla_base):
|
|
for root, dirs, files in os.walk(fla_base):
|
|
|
if '__pycache__' in dirs:
|
|
if '__pycache__' in dirs:
|
|
|
shutil.rmtree(os.path.join(root, '__pycache__'), ignore_errors=True)
|
|
shutil.rmtree(os.path.join(root, '__pycache__'), ignore_errors=True)
|
|
|
cache_count += 1
|
|
cache_count += 1
|
|
|
|
|
|
|
|
- # 清除已缓存的 fla 模块,强制重新导入
|
|
|
|
|
|
|
+ # 清除已缓存的 fla 模块
|
|
|
to_remove = [k for k in sys.modules if k.startswith('fla')]
|
|
to_remove = [k for k in sys.modules if k.startswith('fla')]
|
|
|
for k in to_remove:
|
|
for k in to_remove:
|
|
|
del sys.modules[k]
|
|
del sys.modules[k]
|
|
|
|
|
|
|
|
- # 写入标记文件,下次运行时跳过(幂等)
|
|
|
|
|
|
|
+ # 写入标记文件(幂等),包含版本号 v2
|
|
|
with open(marker_path, 'w') as f:
|
|
with open(marker_path, 'w') as f:
|
|
|
- f.write(f"patched at {datetime.now(timezone.utc).isoformat()}\n")
|
|
|
|
|
|
|
+ f.write(f"v2 patched at {datetime.now(timezone.utc).isoformat()}\n")
|
|
|
|
|
|
|
|
- _remote_log(f"fla shared memory patch done: {len(patched_files)} files patched, "
|
|
|
|
|
|
|
+ _remote_log(f"fla shared memory patch done: {len(patched_files)} files, "
|
|
|
f"{cache_count} caches cleared, {len(to_remove)} modules evicted")
|
|
f"{cache_count} caches cleared, {len(to_remove)} modules evicted")
|
|
|
for pf in patched_files:
|
|
for pf in patched_files:
|
|
|
_remote_log(f" patched: {pf}")
|
|
_remote_log(f" patched: {pf}")
|