Upload scripts/patch_unirig_restore_mha.py with huggingface_hub
Browse files
scripts/patch_unirig_restore_mha.py
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Restore unirig_skin.py to use flash_attn MHA (will be served by our shim in run.py)
|
| 2 |
+
# Our earlier patch replaced flash_attn MHA with nn.MultiheadAttention which breaks
|
| 3 |
+
# checkpoint loading because weight names don't match
|
| 4 |
+
|
| 5 |
+
path = '/root/UniRig/src/model/unirig_skin.py'
|
| 6 |
+
with open(path) as f:
|
| 7 |
+
src = f.read()
|
| 8 |
+
|
| 9 |
+
# Restore flash_attn import (remove it if replaced, add it back)
|
| 10 |
+
if 'from flash_attn.modules.mha import MHA' not in src:
|
| 11 |
+
# Add after the last import line before class definitions
|
| 12 |
+
src = src.replace(
|
| 13 |
+
'import torch_scatter\n',
|
| 14 |
+
'import torch_scatter\nfrom flash_attn.modules.mha import MHA\n'
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
# Restore original MHA init (undo nn.MultiheadAttention replacement)
|
| 18 |
+
src = src.replace(
|
| 19 |
+
' self.attention = nn.MultiheadAttention(embed_dim=feat_dim, num_heads=num_heads, batch_first=True)',
|
| 20 |
+
' self.attention = MHA(embed_dim=feat_dim, num_heads=num_heads, cross_attn=True)'
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
# Restore original MHA forward call
|
| 24 |
+
src = src.replace(
|
| 25 |
+
' attn_output, _ = self.attention(q, kv, kv)',
|
| 26 |
+
' attn_output = self.attention(q, x_kv=kv)'
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
with open(path, 'w') as f:
|
| 30 |
+
f.write(src)
|
| 31 |
+
|
| 32 |
+
print('unirig_skin.py restored to use flash_attn MHA (shim will serve it)')
|
| 33 |
+
|
| 34 |
+
# Verify
|
| 35 |
+
import subprocess
|
| 36 |
+
r = subprocess.run(['grep', '-n', 'MHA\|flash_attn\|Wq\|in_proj', path],
|
| 37 |
+
capture_output=True, text=True)
|
| 38 |
+
print(r.stdout)
|