Daankular commited on
Commit
7b75b7e
·
verified ·
1 Parent(s): 252ade7

Upload scripts/patch_unirig_restore_mha.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/patch_unirig_restore_mha.py +38 -0
scripts/patch_unirig_restore_mha.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Restore unirig_skin.py to use flash_attn MHA (will be served by our shim in run.py)
2
+ # Our earlier patch replaced flash_attn MHA with nn.MultiheadAttention which breaks
3
+ # checkpoint loading because weight names don't match
4
+
5
+ path = '/root/UniRig/src/model/unirig_skin.py'
6
+ with open(path) as f:
7
+ src = f.read()
8
+
9
+ # Restore flash_attn import (remove it if replaced, add it back)
10
+ if 'from flash_attn.modules.mha import MHA' not in src:
11
+ # Add after the last import line before class definitions
12
+ src = src.replace(
13
+ 'import torch_scatter\n',
14
+ 'import torch_scatter\nfrom flash_attn.modules.mha import MHA\n'
15
+ )
16
+
17
+ # Restore original MHA init (undo nn.MultiheadAttention replacement)
18
+ src = src.replace(
19
+ ' self.attention = nn.MultiheadAttention(embed_dim=feat_dim, num_heads=num_heads, batch_first=True)',
20
+ ' self.attention = MHA(embed_dim=feat_dim, num_heads=num_heads, cross_attn=True)'
21
+ )
22
+
23
+ # Restore original MHA forward call
24
+ src = src.replace(
25
+ ' attn_output, _ = self.attention(q, kv, kv)',
26
+ ' attn_output = self.attention(q, x_kv=kv)'
27
+ )
28
+
29
+ with open(path, 'w') as f:
30
+ f.write(src)
31
+
32
+ print('unirig_skin.py restored to use flash_attn MHA (shim will serve it)')
33
+
34
+ # Verify
35
+ import subprocess
36
+ r = subprocess.run(['grep', '-n', 'MHA\|flash_attn\|Wq\|in_proj', path],
37
+ capture_output=True, text=True)
38
+ print(r.stdout)