Daankular commited on
Commit
6e5a789
·
verified ·
1 Parent(s): 56aab81

Upload scripts/patch_unirig_flash.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. scripts/patch_unirig_flash.py +39 -0
scripts/patch_unirig_flash.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Patch UniRig to remove flash_attn hard dependency
2
+ # Replaces flash_attn.MHA with nn.MultiheadAttention (equivalent, already commented in source)
3
+
4
+ path = '/root/UniRig/src/model/unirig_skin.py'
5
+ with open(path) as f:
6
+ src = f.read()
7
+
8
+ # Remove flash_attn import
9
+ src = src.replace('from flash_attn.modules.mha import MHA\n', '')
10
+
11
+ # Replace MHA init
12
+ src = src.replace(
13
+ ' self.attention = MHA(embed_dim=feat_dim, num_heads=num_heads, cross_attn=True)',
14
+ ' self.attention = nn.MultiheadAttention(embed_dim=feat_dim, num_heads=num_heads, batch_first=True)'
15
+ )
16
+
17
+ # Replace MHA forward: (q, x_kv=kv) -> nn.MHA style (q, k, v) returning (out, weights)
18
+ src = src.replace(
19
+ ' attn_output = self.attention(q, x_kv=kv)',
20
+ ' attn_output, _ = self.attention(q, kv, kv)'
21
+ )
22
+
23
+ with open(path, 'w') as f:
24
+ f.write(src)
25
+ print('unirig_skin.py patched OK')
26
+
27
+ # Patch PTv3Object.py: flash_attn is try/except guarded but assert forces it on
28
+ path2 = '/root/UniRig/src/model/pointcept/models/PTv3Object.py'
29
+ with open(path2) as f:
30
+ src2 = f.read()
31
+
32
+ src2 = src2.replace(
33
+ 'assert flash_attn is not None, "Make sure flash_attn is installed."',
34
+ 'pass # flash_attn optional, use standard attention fallback'
35
+ )
36
+
37
+ with open(path2, 'w') as f:
38
+ f.write(src2)
39
+ print('PTv3Object.py patched OK')