hqfang committed on
Commit
559eb2e
·
verified ·
1 Parent(s): e4c5078

Remove unused action expert modules

Browse files
config.json CHANGED
@@ -3,7 +3,6 @@
3
  "action_expert_config": {
4
  "attn_dropout": 0.0,
5
  "causal_attn": false,
6
- "compile": "blocks",
7
  "context_layer_norm": true,
8
  "dropout": 0.0,
9
  "ffn_multiple_of": 256,
@@ -15,7 +14,6 @@
15
  "qk_norm": true,
16
  "qk_norm_eps": 1e-06,
17
  "rope": true,
18
- "rope_on_cross_attention": true,
19
  "timestep_embed_dim": 256
20
  },
21
  "action_expert_depth_gate": false,
 
3
  "action_expert_config": {
4
  "attn_dropout": 0.0,
5
  "causal_attn": false,
 
6
  "context_layer_norm": true,
7
  "dropout": 0.0,
8
  "ffn_multiple_of": 256,
 
14
  "qk_norm": true,
15
  "qk_norm_eps": 1e-06,
16
  "rope": true,
 
17
  "timestep_embed_dim": 256
18
  },
19
  "action_expert_depth_gate": false,
model-00004-of-00005.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:df35c84b32b3460ad7bcca8ecbd3aefbe4d8caa281a030391d3586973965f340
3
- size 4998106920
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a81faa0f56099dd27590c1088e73b0a84e9fad71a322a90b89eb31dfd283d278
3
+ size 4877619536
model-00005-of-00005.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:f89916b6e59a3924f93e0de2fe7d6113eb843f82cf96fb464370864642470c46
3
- size 2334605176
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b2eee6db4ad12f8b78fc3b0143aa4bd2510f477cdb2e736c355c41d26850afe
3
+ size 2282630240
model.safetensors.index.json CHANGED
@@ -1,14 +1,12 @@
1
  {
2
  "metadata": {
3
  "total_parameters": 5485309424,
4
- "total_size": 21941237952
5
  },
6
  "weight_map": {
7
  "lm_head.weight": "model-00005-of-00005.safetensors",
8
  "model.action_expert.action_embed.bias": "model-00004-of-00005.safetensors",
9
  "model.action_expert.action_embed.weight": "model-00004-of-00005.safetensors",
10
- "model.action_expert.blocks.0.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
11
- "model.action_expert.blocks.0.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
12
  "model.action_expert.blocks.0.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
13
  "model.action_expert.blocks.0.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
14
  "model.action_expert.blocks.0.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -25,8 +23,6 @@
25
  "model.action_expert.blocks.0.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
26
  "model.action_expert.blocks.0.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
27
  "model.action_expert.blocks.0.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
28
- "model.action_expert.blocks.1.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
29
- "model.action_expert.blocks.1.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
30
  "model.action_expert.blocks.1.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
31
  "model.action_expert.blocks.1.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
32
  "model.action_expert.blocks.1.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -43,8 +39,6 @@
43
  "model.action_expert.blocks.1.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
44
  "model.action_expert.blocks.1.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
45
  "model.action_expert.blocks.1.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
46
- "model.action_expert.blocks.10.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
47
- "model.action_expert.blocks.10.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
48
  "model.action_expert.blocks.10.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
49
  "model.action_expert.blocks.10.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
50
  "model.action_expert.blocks.10.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -61,8 +55,6 @@
61
  "model.action_expert.blocks.10.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
62
  "model.action_expert.blocks.10.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
63
  "model.action_expert.blocks.10.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
64
- "model.action_expert.blocks.11.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
65
- "model.action_expert.blocks.11.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
66
  "model.action_expert.blocks.11.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
67
  "model.action_expert.blocks.11.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
68
  "model.action_expert.blocks.11.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -79,8 +71,6 @@
79
  "model.action_expert.blocks.11.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
80
  "model.action_expert.blocks.11.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
81
  "model.action_expert.blocks.11.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
82
- "model.action_expert.blocks.12.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
83
- "model.action_expert.blocks.12.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
84
  "model.action_expert.blocks.12.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
85
  "model.action_expert.blocks.12.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
86
  "model.action_expert.blocks.12.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -97,8 +87,6 @@
97
  "model.action_expert.blocks.12.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
98
  "model.action_expert.blocks.12.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
99
  "model.action_expert.blocks.12.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
100
- "model.action_expert.blocks.13.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
101
- "model.action_expert.blocks.13.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
102
  "model.action_expert.blocks.13.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
103
  "model.action_expert.blocks.13.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
104
  "model.action_expert.blocks.13.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -115,8 +103,6 @@
115
  "model.action_expert.blocks.13.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
116
  "model.action_expert.blocks.13.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
117
  "model.action_expert.blocks.13.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
118
- "model.action_expert.blocks.14.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
119
- "model.action_expert.blocks.14.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
120
  "model.action_expert.blocks.14.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
121
  "model.action_expert.blocks.14.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
122
  "model.action_expert.blocks.14.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -133,8 +119,6 @@
133
  "model.action_expert.blocks.14.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
134
  "model.action_expert.blocks.14.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
135
  "model.action_expert.blocks.14.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
136
- "model.action_expert.blocks.15.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
137
- "model.action_expert.blocks.15.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
138
  "model.action_expert.blocks.15.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
139
  "model.action_expert.blocks.15.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
140
  "model.action_expert.blocks.15.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -151,8 +135,6 @@
151
  "model.action_expert.blocks.15.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
152
  "model.action_expert.blocks.15.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
153
  "model.action_expert.blocks.15.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
154
- "model.action_expert.blocks.16.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
155
- "model.action_expert.blocks.16.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
156
  "model.action_expert.blocks.16.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
157
  "model.action_expert.blocks.16.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
158
  "model.action_expert.blocks.16.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -169,8 +151,6 @@
169
  "model.action_expert.blocks.16.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
170
  "model.action_expert.blocks.16.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
171
  "model.action_expert.blocks.16.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
172
- "model.action_expert.blocks.17.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
173
- "model.action_expert.blocks.17.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
174
  "model.action_expert.blocks.17.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
175
  "model.action_expert.blocks.17.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
176
  "model.action_expert.blocks.17.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -187,8 +167,6 @@
187
  "model.action_expert.blocks.17.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
188
  "model.action_expert.blocks.17.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
189
  "model.action_expert.blocks.17.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
190
- "model.action_expert.blocks.18.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
191
- "model.action_expert.blocks.18.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
192
  "model.action_expert.blocks.18.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
193
  "model.action_expert.blocks.18.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
194
  "model.action_expert.blocks.18.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -205,8 +183,6 @@
205
  "model.action_expert.blocks.18.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
206
  "model.action_expert.blocks.18.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
207
  "model.action_expert.blocks.18.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
208
- "model.action_expert.blocks.19.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
209
- "model.action_expert.blocks.19.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
210
  "model.action_expert.blocks.19.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
211
  "model.action_expert.blocks.19.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
212
  "model.action_expert.blocks.19.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -223,8 +199,6 @@
223
  "model.action_expert.blocks.19.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
224
  "model.action_expert.blocks.19.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
225
  "model.action_expert.blocks.19.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
226
- "model.action_expert.blocks.2.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
227
- "model.action_expert.blocks.2.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
228
  "model.action_expert.blocks.2.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
229
  "model.action_expert.blocks.2.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
230
  "model.action_expert.blocks.2.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -241,8 +215,6 @@
241
  "model.action_expert.blocks.2.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
242
  "model.action_expert.blocks.2.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
243
  "model.action_expert.blocks.2.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
244
- "model.action_expert.blocks.20.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
245
- "model.action_expert.blocks.20.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
246
  "model.action_expert.blocks.20.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
247
  "model.action_expert.blocks.20.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
248
  "model.action_expert.blocks.20.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -259,8 +231,6 @@
259
  "model.action_expert.blocks.20.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
260
  "model.action_expert.blocks.20.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
261
  "model.action_expert.blocks.20.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
262
- "model.action_expert.blocks.21.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
263
- "model.action_expert.blocks.21.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
264
  "model.action_expert.blocks.21.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
265
  "model.action_expert.blocks.21.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
266
  "model.action_expert.blocks.21.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -277,8 +247,6 @@
277
  "model.action_expert.blocks.21.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
278
  "model.action_expert.blocks.21.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
279
  "model.action_expert.blocks.21.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
280
- "model.action_expert.blocks.22.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
281
- "model.action_expert.blocks.22.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
282
  "model.action_expert.blocks.22.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
283
  "model.action_expert.blocks.22.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
284
  "model.action_expert.blocks.22.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -295,8 +263,6 @@
295
  "model.action_expert.blocks.22.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
296
  "model.action_expert.blocks.22.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
297
  "model.action_expert.blocks.22.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
298
- "model.action_expert.blocks.23.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
299
- "model.action_expert.blocks.23.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
300
  "model.action_expert.blocks.23.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
301
  "model.action_expert.blocks.23.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
302
  "model.action_expert.blocks.23.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -313,8 +279,6 @@
313
  "model.action_expert.blocks.23.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
314
  "model.action_expert.blocks.23.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
315
  "model.action_expert.blocks.23.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
316
- "model.action_expert.blocks.24.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
317
- "model.action_expert.blocks.24.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
318
  "model.action_expert.blocks.24.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
319
  "model.action_expert.blocks.24.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
320
  "model.action_expert.blocks.24.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -331,8 +295,6 @@
331
  "model.action_expert.blocks.24.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
332
  "model.action_expert.blocks.24.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
333
  "model.action_expert.blocks.24.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
334
- "model.action_expert.blocks.25.cross_attn.kv_proj.bias": "model-00005-of-00005.safetensors",
335
- "model.action_expert.blocks.25.cross_attn.kv_proj.weight": "model-00005-of-00005.safetensors",
336
  "model.action_expert.blocks.25.cross_attn.out_proj.bias": "model-00005-of-00005.safetensors",
337
  "model.action_expert.blocks.25.cross_attn.out_proj.weight": "model-00005-of-00005.safetensors",
338
  "model.action_expert.blocks.25.cross_attn.q_proj.bias": "model-00005-of-00005.safetensors",
@@ -349,8 +311,6 @@
349
  "model.action_expert.blocks.25.self_attn.out_proj.weight": "model-00005-of-00005.safetensors",
350
  "model.action_expert.blocks.25.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
351
  "model.action_expert.blocks.25.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
352
- "model.action_expert.blocks.26.cross_attn.kv_proj.bias": "model-00005-of-00005.safetensors",
353
- "model.action_expert.blocks.26.cross_attn.kv_proj.weight": "model-00005-of-00005.safetensors",
354
  "model.action_expert.blocks.26.cross_attn.out_proj.bias": "model-00005-of-00005.safetensors",
355
  "model.action_expert.blocks.26.cross_attn.out_proj.weight": "model-00005-of-00005.safetensors",
356
  "model.action_expert.blocks.26.cross_attn.q_proj.bias": "model-00005-of-00005.safetensors",
@@ -367,8 +327,6 @@
367
  "model.action_expert.blocks.26.self_attn.out_proj.weight": "model-00005-of-00005.safetensors",
368
  "model.action_expert.blocks.26.self_attn.qkv.bias": "model-00005-of-00005.safetensors",
369
  "model.action_expert.blocks.26.self_attn.qkv.weight": "model-00005-of-00005.safetensors",
370
- "model.action_expert.blocks.27.cross_attn.kv_proj.bias": "model-00005-of-00005.safetensors",
371
- "model.action_expert.blocks.27.cross_attn.kv_proj.weight": "model-00005-of-00005.safetensors",
372
  "model.action_expert.blocks.27.cross_attn.out_proj.bias": "model-00005-of-00005.safetensors",
373
  "model.action_expert.blocks.27.cross_attn.out_proj.weight": "model-00005-of-00005.safetensors",
374
  "model.action_expert.blocks.27.cross_attn.q_proj.bias": "model-00005-of-00005.safetensors",
@@ -385,8 +343,6 @@
385
  "model.action_expert.blocks.27.self_attn.out_proj.weight": "model-00005-of-00005.safetensors",
386
  "model.action_expert.blocks.27.self_attn.qkv.bias": "model-00005-of-00005.safetensors",
387
  "model.action_expert.blocks.27.self_attn.qkv.weight": "model-00005-of-00005.safetensors",
388
- "model.action_expert.blocks.28.cross_attn.kv_proj.bias": "model-00005-of-00005.safetensors",
389
- "model.action_expert.blocks.28.cross_attn.kv_proj.weight": "model-00005-of-00005.safetensors",
390
  "model.action_expert.blocks.28.cross_attn.out_proj.bias": "model-00005-of-00005.safetensors",
391
  "model.action_expert.blocks.28.cross_attn.out_proj.weight": "model-00005-of-00005.safetensors",
392
  "model.action_expert.blocks.28.cross_attn.q_proj.bias": "model-00005-of-00005.safetensors",
@@ -403,8 +359,6 @@
403
  "model.action_expert.blocks.28.self_attn.out_proj.weight": "model-00005-of-00005.safetensors",
404
  "model.action_expert.blocks.28.self_attn.qkv.bias": "model-00005-of-00005.safetensors",
405
  "model.action_expert.blocks.28.self_attn.qkv.weight": "model-00005-of-00005.safetensors",
406
- "model.action_expert.blocks.29.cross_attn.kv_proj.bias": "model-00005-of-00005.safetensors",
407
- "model.action_expert.blocks.29.cross_attn.kv_proj.weight": "model-00005-of-00005.safetensors",
408
  "model.action_expert.blocks.29.cross_attn.out_proj.bias": "model-00005-of-00005.safetensors",
409
  "model.action_expert.blocks.29.cross_attn.out_proj.weight": "model-00005-of-00005.safetensors",
410
  "model.action_expert.blocks.29.cross_attn.q_proj.bias": "model-00005-of-00005.safetensors",
@@ -421,8 +375,6 @@
421
  "model.action_expert.blocks.29.self_attn.out_proj.weight": "model-00005-of-00005.safetensors",
422
  "model.action_expert.blocks.29.self_attn.qkv.bias": "model-00005-of-00005.safetensors",
423
  "model.action_expert.blocks.29.self_attn.qkv.weight": "model-00005-of-00005.safetensors",
424
- "model.action_expert.blocks.3.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
425
- "model.action_expert.blocks.3.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
426
  "model.action_expert.blocks.3.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
427
  "model.action_expert.blocks.3.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
428
  "model.action_expert.blocks.3.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -439,8 +391,6 @@
439
  "model.action_expert.blocks.3.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
440
  "model.action_expert.blocks.3.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
441
  "model.action_expert.blocks.3.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
442
- "model.action_expert.blocks.30.cross_attn.kv_proj.bias": "model-00005-of-00005.safetensors",
443
- "model.action_expert.blocks.30.cross_attn.kv_proj.weight": "model-00005-of-00005.safetensors",
444
  "model.action_expert.blocks.30.cross_attn.out_proj.bias": "model-00005-of-00005.safetensors",
445
  "model.action_expert.blocks.30.cross_attn.out_proj.weight": "model-00005-of-00005.safetensors",
446
  "model.action_expert.blocks.30.cross_attn.q_proj.bias": "model-00005-of-00005.safetensors",
@@ -457,8 +407,6 @@
457
  "model.action_expert.blocks.30.self_attn.out_proj.weight": "model-00005-of-00005.safetensors",
458
  "model.action_expert.blocks.30.self_attn.qkv.bias": "model-00005-of-00005.safetensors",
459
  "model.action_expert.blocks.30.self_attn.qkv.weight": "model-00005-of-00005.safetensors",
460
- "model.action_expert.blocks.31.cross_attn.kv_proj.bias": "model-00005-of-00005.safetensors",
461
- "model.action_expert.blocks.31.cross_attn.kv_proj.weight": "model-00005-of-00005.safetensors",
462
  "model.action_expert.blocks.31.cross_attn.out_proj.bias": "model-00005-of-00005.safetensors",
463
  "model.action_expert.blocks.31.cross_attn.out_proj.weight": "model-00005-of-00005.safetensors",
464
  "model.action_expert.blocks.31.cross_attn.q_proj.bias": "model-00005-of-00005.safetensors",
@@ -475,8 +423,6 @@
475
  "model.action_expert.blocks.31.self_attn.out_proj.weight": "model-00005-of-00005.safetensors",
476
  "model.action_expert.blocks.31.self_attn.qkv.bias": "model-00005-of-00005.safetensors",
477
  "model.action_expert.blocks.31.self_attn.qkv.weight": "model-00005-of-00005.safetensors",
478
- "model.action_expert.blocks.32.cross_attn.kv_proj.bias": "model-00005-of-00005.safetensors",
479
- "model.action_expert.blocks.32.cross_attn.kv_proj.weight": "model-00005-of-00005.safetensors",
480
  "model.action_expert.blocks.32.cross_attn.out_proj.bias": "model-00005-of-00005.safetensors",
481
  "model.action_expert.blocks.32.cross_attn.out_proj.weight": "model-00005-of-00005.safetensors",
482
  "model.action_expert.blocks.32.cross_attn.q_proj.bias": "model-00005-of-00005.safetensors",
@@ -493,8 +439,6 @@
493
  "model.action_expert.blocks.32.self_attn.out_proj.weight": "model-00005-of-00005.safetensors",
494
  "model.action_expert.blocks.32.self_attn.qkv.bias": "model-00005-of-00005.safetensors",
495
  "model.action_expert.blocks.32.self_attn.qkv.weight": "model-00005-of-00005.safetensors",
496
- "model.action_expert.blocks.33.cross_attn.kv_proj.bias": "model-00005-of-00005.safetensors",
497
- "model.action_expert.blocks.33.cross_attn.kv_proj.weight": "model-00005-of-00005.safetensors",
498
  "model.action_expert.blocks.33.cross_attn.out_proj.bias": "model-00005-of-00005.safetensors",
499
  "model.action_expert.blocks.33.cross_attn.out_proj.weight": "model-00005-of-00005.safetensors",
500
  "model.action_expert.blocks.33.cross_attn.q_proj.bias": "model-00005-of-00005.safetensors",
@@ -511,8 +455,6 @@
511
  "model.action_expert.blocks.33.self_attn.out_proj.weight": "model-00005-of-00005.safetensors",
512
  "model.action_expert.blocks.33.self_attn.qkv.bias": "model-00005-of-00005.safetensors",
513
  "model.action_expert.blocks.33.self_attn.qkv.weight": "model-00005-of-00005.safetensors",
514
- "model.action_expert.blocks.34.cross_attn.kv_proj.bias": "model-00005-of-00005.safetensors",
515
- "model.action_expert.blocks.34.cross_attn.kv_proj.weight": "model-00005-of-00005.safetensors",
516
  "model.action_expert.blocks.34.cross_attn.out_proj.bias": "model-00005-of-00005.safetensors",
517
  "model.action_expert.blocks.34.cross_attn.out_proj.weight": "model-00005-of-00005.safetensors",
518
  "model.action_expert.blocks.34.cross_attn.q_proj.bias": "model-00005-of-00005.safetensors",
@@ -529,8 +471,6 @@
529
  "model.action_expert.blocks.34.self_attn.out_proj.weight": "model-00005-of-00005.safetensors",
530
  "model.action_expert.blocks.34.self_attn.qkv.bias": "model-00005-of-00005.safetensors",
531
  "model.action_expert.blocks.34.self_attn.qkv.weight": "model-00005-of-00005.safetensors",
532
- "model.action_expert.blocks.35.cross_attn.kv_proj.bias": "model-00005-of-00005.safetensors",
533
- "model.action_expert.blocks.35.cross_attn.kv_proj.weight": "model-00005-of-00005.safetensors",
534
  "model.action_expert.blocks.35.cross_attn.out_proj.bias": "model-00005-of-00005.safetensors",
535
  "model.action_expert.blocks.35.cross_attn.out_proj.weight": "model-00005-of-00005.safetensors",
536
  "model.action_expert.blocks.35.cross_attn.q_proj.bias": "model-00005-of-00005.safetensors",
@@ -547,8 +487,6 @@
547
  "model.action_expert.blocks.35.self_attn.out_proj.weight": "model-00005-of-00005.safetensors",
548
  "model.action_expert.blocks.35.self_attn.qkv.bias": "model-00005-of-00005.safetensors",
549
  "model.action_expert.blocks.35.self_attn.qkv.weight": "model-00005-of-00005.safetensors",
550
- "model.action_expert.blocks.4.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
551
- "model.action_expert.blocks.4.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
552
  "model.action_expert.blocks.4.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
553
  "model.action_expert.blocks.4.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
554
  "model.action_expert.blocks.4.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -565,8 +503,6 @@
565
  "model.action_expert.blocks.4.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
566
  "model.action_expert.blocks.4.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
567
  "model.action_expert.blocks.4.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
568
- "model.action_expert.blocks.5.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
569
- "model.action_expert.blocks.5.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
570
  "model.action_expert.blocks.5.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
571
  "model.action_expert.blocks.5.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
572
  "model.action_expert.blocks.5.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -583,8 +519,6 @@
583
  "model.action_expert.blocks.5.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
584
  "model.action_expert.blocks.5.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
585
  "model.action_expert.blocks.5.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
586
- "model.action_expert.blocks.6.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
587
- "model.action_expert.blocks.6.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
588
  "model.action_expert.blocks.6.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
589
  "model.action_expert.blocks.6.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
590
  "model.action_expert.blocks.6.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -601,8 +535,6 @@
601
  "model.action_expert.blocks.6.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
602
  "model.action_expert.blocks.6.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
603
  "model.action_expert.blocks.6.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
604
- "model.action_expert.blocks.7.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
605
- "model.action_expert.blocks.7.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
606
  "model.action_expert.blocks.7.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
607
  "model.action_expert.blocks.7.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
608
  "model.action_expert.blocks.7.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -619,8 +551,6 @@
619
  "model.action_expert.blocks.7.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
620
  "model.action_expert.blocks.7.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
621
  "model.action_expert.blocks.7.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
622
- "model.action_expert.blocks.8.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
623
- "model.action_expert.blocks.8.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
624
  "model.action_expert.blocks.8.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
625
  "model.action_expert.blocks.8.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
626
  "model.action_expert.blocks.8.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -637,8 +567,6 @@
637
  "model.action_expert.blocks.8.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
638
  "model.action_expert.blocks.8.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
639
  "model.action_expert.blocks.8.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
640
- "model.action_expert.blocks.9.cross_attn.kv_proj.bias": "model-00004-of-00005.safetensors",
641
- "model.action_expert.blocks.9.cross_attn.kv_proj.weight": "model-00004-of-00005.safetensors",
642
  "model.action_expert.blocks.9.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
643
  "model.action_expert.blocks.9.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
644
  "model.action_expert.blocks.9.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
@@ -661,8 +589,6 @@
661
  "model.action_expert.final_layer.linear.weight": "model-00005-of-00005.safetensors",
662
  "model.action_expert.final_layer.modulation.linear.bias": "model-00005-of-00005.safetensors",
663
  "model.action_expert.final_layer.modulation.linear.weight": "model-00005-of-00005.safetensors",
664
- "model.action_expert.state_encoder.bias": "model-00004-of-00005.safetensors",
665
- "model.action_expert.state_encoder.weight": "model-00004-of-00005.safetensors",
666
  "model.action_expert.time_embed.1.bias": "model-00004-of-00005.safetensors",
667
  "model.action_expert.time_embed.1.weight": "model-00004-of-00005.safetensors",
668
  "model.action_expert.time_embed.3.bias": "model-00004-of-00005.safetensors",
 
1
  {
2
  "metadata": {
3
  "total_parameters": 5485309424,
4
+ "total_size": 21768785088
5
  },
6
  "weight_map": {
7
  "lm_head.weight": "model-00005-of-00005.safetensors",
8
  "model.action_expert.action_embed.bias": "model-00004-of-00005.safetensors",
9
  "model.action_expert.action_embed.weight": "model-00004-of-00005.safetensors",
 
 
10
  "model.action_expert.blocks.0.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
11
  "model.action_expert.blocks.0.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
12
  "model.action_expert.blocks.0.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
23
  "model.action_expert.blocks.0.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
24
  "model.action_expert.blocks.0.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
25
  "model.action_expert.blocks.0.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
26
  "model.action_expert.blocks.1.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
27
  "model.action_expert.blocks.1.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
28
  "model.action_expert.blocks.1.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
39
  "model.action_expert.blocks.1.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
40
  "model.action_expert.blocks.1.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
41
  "model.action_expert.blocks.1.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
42
  "model.action_expert.blocks.10.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
43
  "model.action_expert.blocks.10.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
44
  "model.action_expert.blocks.10.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
55
  "model.action_expert.blocks.10.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
56
  "model.action_expert.blocks.10.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
57
  "model.action_expert.blocks.10.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
58
  "model.action_expert.blocks.11.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
59
  "model.action_expert.blocks.11.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
60
  "model.action_expert.blocks.11.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
71
  "model.action_expert.blocks.11.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
72
  "model.action_expert.blocks.11.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
73
  "model.action_expert.blocks.11.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
74
  "model.action_expert.blocks.12.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
75
  "model.action_expert.blocks.12.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
76
  "model.action_expert.blocks.12.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
87
  "model.action_expert.blocks.12.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
88
  "model.action_expert.blocks.12.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
89
  "model.action_expert.blocks.12.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
90
  "model.action_expert.blocks.13.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
91
  "model.action_expert.blocks.13.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
92
  "model.action_expert.blocks.13.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
103
  "model.action_expert.blocks.13.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
104
  "model.action_expert.blocks.13.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
105
  "model.action_expert.blocks.13.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
106
  "model.action_expert.blocks.14.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
107
  "model.action_expert.blocks.14.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
108
  "model.action_expert.blocks.14.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
119
  "model.action_expert.blocks.14.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
120
  "model.action_expert.blocks.14.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
121
  "model.action_expert.blocks.14.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
122
  "model.action_expert.blocks.15.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
123
  "model.action_expert.blocks.15.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
124
  "model.action_expert.blocks.15.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
135
  "model.action_expert.blocks.15.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
136
  "model.action_expert.blocks.15.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
137
  "model.action_expert.blocks.15.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
138
  "model.action_expert.blocks.16.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
139
  "model.action_expert.blocks.16.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
140
  "model.action_expert.blocks.16.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
151
  "model.action_expert.blocks.16.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
152
  "model.action_expert.blocks.16.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
153
  "model.action_expert.blocks.16.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
154
  "model.action_expert.blocks.17.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
155
  "model.action_expert.blocks.17.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
156
  "model.action_expert.blocks.17.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
167
  "model.action_expert.blocks.17.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
168
  "model.action_expert.blocks.17.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
169
  "model.action_expert.blocks.17.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
170
  "model.action_expert.blocks.18.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
171
  "model.action_expert.blocks.18.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
172
  "model.action_expert.blocks.18.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
183
  "model.action_expert.blocks.18.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
184
  "model.action_expert.blocks.18.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
185
  "model.action_expert.blocks.18.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
186
  "model.action_expert.blocks.19.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
187
  "model.action_expert.blocks.19.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
188
  "model.action_expert.blocks.19.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
199
  "model.action_expert.blocks.19.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
200
  "model.action_expert.blocks.19.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
201
  "model.action_expert.blocks.19.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
202
  "model.action_expert.blocks.2.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
203
  "model.action_expert.blocks.2.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
204
  "model.action_expert.blocks.2.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
215
  "model.action_expert.blocks.2.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
216
  "model.action_expert.blocks.2.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
217
  "model.action_expert.blocks.2.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
218
  "model.action_expert.blocks.20.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
219
  "model.action_expert.blocks.20.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
220
  "model.action_expert.blocks.20.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
231
  "model.action_expert.blocks.20.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
232
  "model.action_expert.blocks.20.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
233
  "model.action_expert.blocks.20.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
234
  "model.action_expert.blocks.21.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
235
  "model.action_expert.blocks.21.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
236
  "model.action_expert.blocks.21.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
247
  "model.action_expert.blocks.21.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
248
  "model.action_expert.blocks.21.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
249
  "model.action_expert.blocks.21.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
250
  "model.action_expert.blocks.22.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
251
  "model.action_expert.blocks.22.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
252
  "model.action_expert.blocks.22.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
263
  "model.action_expert.blocks.22.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
264
  "model.action_expert.blocks.22.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
265
  "model.action_expert.blocks.22.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
266
  "model.action_expert.blocks.23.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
267
  "model.action_expert.blocks.23.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
268
  "model.action_expert.blocks.23.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
279
  "model.action_expert.blocks.23.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
280
  "model.action_expert.blocks.23.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
281
  "model.action_expert.blocks.23.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
282
  "model.action_expert.blocks.24.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
283
  "model.action_expert.blocks.24.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
284
  "model.action_expert.blocks.24.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
295
  "model.action_expert.blocks.24.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
296
  "model.action_expert.blocks.24.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
297
  "model.action_expert.blocks.24.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
298
  "model.action_expert.blocks.25.cross_attn.out_proj.bias": "model-00005-of-00005.safetensors",
299
  "model.action_expert.blocks.25.cross_attn.out_proj.weight": "model-00005-of-00005.safetensors",
300
  "model.action_expert.blocks.25.cross_attn.q_proj.bias": "model-00005-of-00005.safetensors",
 
311
  "model.action_expert.blocks.25.self_attn.out_proj.weight": "model-00005-of-00005.safetensors",
312
  "model.action_expert.blocks.25.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
313
  "model.action_expert.blocks.25.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
314
  "model.action_expert.blocks.26.cross_attn.out_proj.bias": "model-00005-of-00005.safetensors",
315
  "model.action_expert.blocks.26.cross_attn.out_proj.weight": "model-00005-of-00005.safetensors",
316
  "model.action_expert.blocks.26.cross_attn.q_proj.bias": "model-00005-of-00005.safetensors",
 
327
  "model.action_expert.blocks.26.self_attn.out_proj.weight": "model-00005-of-00005.safetensors",
328
  "model.action_expert.blocks.26.self_attn.qkv.bias": "model-00005-of-00005.safetensors",
329
  "model.action_expert.blocks.26.self_attn.qkv.weight": "model-00005-of-00005.safetensors",
 
 
330
  "model.action_expert.blocks.27.cross_attn.out_proj.bias": "model-00005-of-00005.safetensors",
331
  "model.action_expert.blocks.27.cross_attn.out_proj.weight": "model-00005-of-00005.safetensors",
332
  "model.action_expert.blocks.27.cross_attn.q_proj.bias": "model-00005-of-00005.safetensors",
 
343
  "model.action_expert.blocks.27.self_attn.out_proj.weight": "model-00005-of-00005.safetensors",
344
  "model.action_expert.blocks.27.self_attn.qkv.bias": "model-00005-of-00005.safetensors",
345
  "model.action_expert.blocks.27.self_attn.qkv.weight": "model-00005-of-00005.safetensors",
 
 
346
  "model.action_expert.blocks.28.cross_attn.out_proj.bias": "model-00005-of-00005.safetensors",
347
  "model.action_expert.blocks.28.cross_attn.out_proj.weight": "model-00005-of-00005.safetensors",
348
  "model.action_expert.blocks.28.cross_attn.q_proj.bias": "model-00005-of-00005.safetensors",
 
359
  "model.action_expert.blocks.28.self_attn.out_proj.weight": "model-00005-of-00005.safetensors",
360
  "model.action_expert.blocks.28.self_attn.qkv.bias": "model-00005-of-00005.safetensors",
361
  "model.action_expert.blocks.28.self_attn.qkv.weight": "model-00005-of-00005.safetensors",
 
 
362
  "model.action_expert.blocks.29.cross_attn.out_proj.bias": "model-00005-of-00005.safetensors",
363
  "model.action_expert.blocks.29.cross_attn.out_proj.weight": "model-00005-of-00005.safetensors",
364
  "model.action_expert.blocks.29.cross_attn.q_proj.bias": "model-00005-of-00005.safetensors",
 
375
  "model.action_expert.blocks.29.self_attn.out_proj.weight": "model-00005-of-00005.safetensors",
376
  "model.action_expert.blocks.29.self_attn.qkv.bias": "model-00005-of-00005.safetensors",
377
  "model.action_expert.blocks.29.self_attn.qkv.weight": "model-00005-of-00005.safetensors",
 
 
378
  "model.action_expert.blocks.3.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
379
  "model.action_expert.blocks.3.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
380
  "model.action_expert.blocks.3.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
391
  "model.action_expert.blocks.3.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
392
  "model.action_expert.blocks.3.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
393
  "model.action_expert.blocks.3.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
394
  "model.action_expert.blocks.30.cross_attn.out_proj.bias": "model-00005-of-00005.safetensors",
395
  "model.action_expert.blocks.30.cross_attn.out_proj.weight": "model-00005-of-00005.safetensors",
396
  "model.action_expert.blocks.30.cross_attn.q_proj.bias": "model-00005-of-00005.safetensors",
 
407
  "model.action_expert.blocks.30.self_attn.out_proj.weight": "model-00005-of-00005.safetensors",
408
  "model.action_expert.blocks.30.self_attn.qkv.bias": "model-00005-of-00005.safetensors",
409
  "model.action_expert.blocks.30.self_attn.qkv.weight": "model-00005-of-00005.safetensors",
 
 
410
  "model.action_expert.blocks.31.cross_attn.out_proj.bias": "model-00005-of-00005.safetensors",
411
  "model.action_expert.blocks.31.cross_attn.out_proj.weight": "model-00005-of-00005.safetensors",
412
  "model.action_expert.blocks.31.cross_attn.q_proj.bias": "model-00005-of-00005.safetensors",
 
423
  "model.action_expert.blocks.31.self_attn.out_proj.weight": "model-00005-of-00005.safetensors",
424
  "model.action_expert.blocks.31.self_attn.qkv.bias": "model-00005-of-00005.safetensors",
425
  "model.action_expert.blocks.31.self_attn.qkv.weight": "model-00005-of-00005.safetensors",
 
 
426
  "model.action_expert.blocks.32.cross_attn.out_proj.bias": "model-00005-of-00005.safetensors",
427
  "model.action_expert.blocks.32.cross_attn.out_proj.weight": "model-00005-of-00005.safetensors",
428
  "model.action_expert.blocks.32.cross_attn.q_proj.bias": "model-00005-of-00005.safetensors",
 
439
  "model.action_expert.blocks.32.self_attn.out_proj.weight": "model-00005-of-00005.safetensors",
440
  "model.action_expert.blocks.32.self_attn.qkv.bias": "model-00005-of-00005.safetensors",
441
  "model.action_expert.blocks.32.self_attn.qkv.weight": "model-00005-of-00005.safetensors",
 
 
442
  "model.action_expert.blocks.33.cross_attn.out_proj.bias": "model-00005-of-00005.safetensors",
443
  "model.action_expert.blocks.33.cross_attn.out_proj.weight": "model-00005-of-00005.safetensors",
444
  "model.action_expert.blocks.33.cross_attn.q_proj.bias": "model-00005-of-00005.safetensors",
 
455
  "model.action_expert.blocks.33.self_attn.out_proj.weight": "model-00005-of-00005.safetensors",
456
  "model.action_expert.blocks.33.self_attn.qkv.bias": "model-00005-of-00005.safetensors",
457
  "model.action_expert.blocks.33.self_attn.qkv.weight": "model-00005-of-00005.safetensors",
 
 
458
  "model.action_expert.blocks.34.cross_attn.out_proj.bias": "model-00005-of-00005.safetensors",
459
  "model.action_expert.blocks.34.cross_attn.out_proj.weight": "model-00005-of-00005.safetensors",
460
  "model.action_expert.blocks.34.cross_attn.q_proj.bias": "model-00005-of-00005.safetensors",
 
471
  "model.action_expert.blocks.34.self_attn.out_proj.weight": "model-00005-of-00005.safetensors",
472
  "model.action_expert.blocks.34.self_attn.qkv.bias": "model-00005-of-00005.safetensors",
473
  "model.action_expert.blocks.34.self_attn.qkv.weight": "model-00005-of-00005.safetensors",
 
 
474
  "model.action_expert.blocks.35.cross_attn.out_proj.bias": "model-00005-of-00005.safetensors",
475
  "model.action_expert.blocks.35.cross_attn.out_proj.weight": "model-00005-of-00005.safetensors",
476
  "model.action_expert.blocks.35.cross_attn.q_proj.bias": "model-00005-of-00005.safetensors",
 
487
  "model.action_expert.blocks.35.self_attn.out_proj.weight": "model-00005-of-00005.safetensors",
488
  "model.action_expert.blocks.35.self_attn.qkv.bias": "model-00005-of-00005.safetensors",
489
  "model.action_expert.blocks.35.self_attn.qkv.weight": "model-00005-of-00005.safetensors",
 
 
490
  "model.action_expert.blocks.4.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
491
  "model.action_expert.blocks.4.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
492
  "model.action_expert.blocks.4.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
503
  "model.action_expert.blocks.4.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
504
  "model.action_expert.blocks.4.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
505
  "model.action_expert.blocks.4.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
506
  "model.action_expert.blocks.5.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
507
  "model.action_expert.blocks.5.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
508
  "model.action_expert.blocks.5.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
519
  "model.action_expert.blocks.5.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
520
  "model.action_expert.blocks.5.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
521
  "model.action_expert.blocks.5.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
522
  "model.action_expert.blocks.6.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
523
  "model.action_expert.blocks.6.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
524
  "model.action_expert.blocks.6.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
535
  "model.action_expert.blocks.6.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
536
  "model.action_expert.blocks.6.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
537
  "model.action_expert.blocks.6.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
538
  "model.action_expert.blocks.7.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
539
  "model.action_expert.blocks.7.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
540
  "model.action_expert.blocks.7.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
551
  "model.action_expert.blocks.7.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
552
  "model.action_expert.blocks.7.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
553
  "model.action_expert.blocks.7.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
554
  "model.action_expert.blocks.8.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
555
  "model.action_expert.blocks.8.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
556
  "model.action_expert.blocks.8.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
567
  "model.action_expert.blocks.8.self_attn.out_proj.weight": "model-00004-of-00005.safetensors",
568
  "model.action_expert.blocks.8.self_attn.qkv.bias": "model-00004-of-00005.safetensors",
569
  "model.action_expert.blocks.8.self_attn.qkv.weight": "model-00004-of-00005.safetensors",
 
 
570
  "model.action_expert.blocks.9.cross_attn.out_proj.bias": "model-00004-of-00005.safetensors",
571
  "model.action_expert.blocks.9.cross_attn.out_proj.weight": "model-00004-of-00005.safetensors",
572
  "model.action_expert.blocks.9.cross_attn.q_proj.bias": "model-00004-of-00005.safetensors",
 
589
  "model.action_expert.final_layer.linear.weight": "model-00005-of-00005.safetensors",
590
  "model.action_expert.final_layer.modulation.linear.bias": "model-00005-of-00005.safetensors",
591
  "model.action_expert.final_layer.modulation.linear.weight": "model-00005-of-00005.safetensors",
 
 
592
  "model.action_expert.time_embed.1.bias": "model-00004-of-00005.safetensors",
593
  "model.action_expert.time_embed.1.weight": "model-00004-of-00005.safetensors",
594
  "model.action_expert.time_embed.3.bias": "model-00004-of-00005.safetensors",
modeling_molmoact2.py CHANGED
@@ -315,17 +315,9 @@ class ActionExpertCrossAttention(nn.Module):
315
  ActionExpertRMSNorm(self.head_dim, eps=qk_norm_eps) if qk_norm else None
316
  )
317
  self.q_proj = nn.Linear(hidden_size, hidden_size)
318
- self.kv_proj = nn.Linear(hidden_size, hidden_size * 2)
319
  self.out_proj = nn.Linear(hidden_size, hidden_size)
320
  self.out_drop = nn.Dropout(proj_dropout)
321
 
322
- def _apply_qk_norm(
323
- self, q: torch.Tensor, k: torch.Tensor
324
- ) -> Tuple[torch.Tensor, torch.Tensor]:
325
- if self.q_norm is None or self.k_norm is None:
326
- return q, k
327
- return self.q_norm(q), self.k_norm(k)
328
-
329
  def _as_heads(self, x: torch.Tensor) -> torch.Tensor:
330
  if x.dim() == 4:
331
  if x.shape[2] == self.num_heads:
@@ -361,38 +353,18 @@ class ActionExpertCrossAttention(nn.Module):
361
  self,
362
  x: torch.Tensor,
363
  *,
364
- kv: Optional[torch.Tensor] = None,
365
- kv_k: Optional[torch.Tensor] = None,
366
- kv_v: Optional[torch.Tensor] = None,
367
  attn_mask: Optional[torch.Tensor] = None,
368
  ) -> torch.Tensor:
369
- if (kv_k is None) != (kv_v is None):
370
- raise ValueError("kv_k and kv_v must both be provided or both be None.")
371
- if kv is not None and kv_k is not None:
372
- raise ValueError("Provide either kv or kv_k/kv_v, not both.")
373
  bsz, tgt_len, _ = x.shape
374
  q = self.q_proj(x).view(bsz, tgt_len, self.num_heads, self.head_dim)
375
- if kv_k is not None and kv_v is not None:
376
- k = self._as_heads(kv_k)
377
- v = self._as_heads(kv_v)
378
- k_pre_normed = True
379
- else:
380
- if kv is None:
381
- raise ValueError("cross-attention requires kv or kv_k/kv_v.")
382
- src_len = kv.shape[1]
383
- kv_proj = self.kv_proj(kv).view(
384
- bsz, src_len, 2, self.num_heads, self.head_dim
385
- )
386
- k = kv_proj[:, :, 0]
387
- v = kv_proj[:, :, 1]
388
- k_pre_normed = False
389
  q = q.transpose(1, 2)
390
  k = k.transpose(1, 2)
391
- if k_pre_normed:
392
- if self.q_norm is not None:
393
- q = self.q_norm(q)
394
- else:
395
- q, k = self._apply_qk_norm(q, k)
396
  q = q.transpose(1, 2)
397
  k = k.transpose(1, 2)
398
  out = self._attention(q, k, v, attn_mask=attn_mask)
@@ -592,12 +564,6 @@ class ActionExpert(nn.Module):
592
  self.action_embed = nn.Linear(
593
  config.max_action_dim, config.hidden_size, device=device
594
  )
595
- self.state_encoder = nn.Linear(
596
- config.hidden_size, config.hidden_size, device=device
597
- )
598
- self.state_norm = ActionExpertRMSNorm(
599
- config.hidden_size, eps=1e-6, device=device
600
- )
601
  self.context_k_proj = nn.Linear(
602
  self.llm_kv_dim, config.hidden_size, bias=False, device=device
603
  )
@@ -629,10 +595,6 @@ class ActionExpert(nn.Module):
629
  for _ in range(config.num_layers)
630
  ]
631
  )
632
- for block in self.blocks:
633
- block.cross_attn.kv_proj.weight.requires_grad = False
634
- if block.cross_attn.kv_proj.bias is not None:
635
- block.cross_attn.kv_proj.bias.requires_grad = False
636
  self.final_layer = ActionExpertFinalLayer(
637
  config.hidden_size, config.max_action_dim
638
  )
@@ -643,8 +605,6 @@ class ActionExpert(nn.Module):
643
  if isinstance(module, nn.Linear):
644
  _init_linear(module)
645
  _init_linear(self.action_embed)
646
- _init_linear(self.state_encoder)
647
- self.state_norm.reset_parameters()
648
  _init_linear(self.context_k_proj)
649
  _init_linear(self.context_v_proj)
650
  if isinstance(self.context_norm, ActionExpertRMSNorm):
@@ -654,7 +614,6 @@ class ActionExpert(nn.Module):
654
  _init_linear(block.self_attn.qkv)
655
  _init_linear(block.self_attn.out_proj, scale=residual_scale)
656
  _init_linear(block.cross_attn.q_proj)
657
- _init_linear(block.cross_attn.kv_proj)
658
  _init_linear(block.cross_attn.out_proj, scale=residual_scale)
659
  _init_linear(block.mlp.up_proj)
660
  _init_linear(block.mlp.gate_proj)
@@ -680,19 +639,6 @@ class ActionExpert(nn.Module):
680
  x.shape[0], x.shape[1], self.config.num_heads, self.action_head_dim
681
  )
682
 
683
- def _encode_states(self, states: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
684
- if states is None:
685
- return None
686
- if states.dim() == 2:
687
- states = states.unsqueeze(1)
688
- if states.shape[-1] != self.hidden_size:
689
- feat_dim = states.shape[-1]
690
- if feat_dim < self.hidden_size:
691
- states = F.pad(states, (0, self.hidden_size - feat_dim))
692
- else:
693
- states = states[..., : self.hidden_size]
694
- return self.state_norm(self.state_encoder(states))
695
-
696
  def _time_conditioning(self, timesteps: torch.Tensor) -> torch.Tensor:
697
  conditioning = self.time_embed[0](timesteps)
698
  first_linear = self.time_embed[1]
@@ -709,7 +655,6 @@ class ActionExpert(nn.Module):
709
  def _prepare_kv_context(
710
  self,
711
  encoder_kv_states: Sequence[Tuple[torch.Tensor, torch.Tensor]],
712
- encoded_states: Optional[torch.Tensor],
713
  ) -> Sequence[Tuple[torch.Tensor, torch.Tensor]]:
714
  if len(encoder_kv_states) != len(self.blocks):
715
  raise ValueError(
@@ -717,17 +662,9 @@ class ActionExpert(nn.Module):
717
  f"got {len(encoder_kv_states)}."
718
  )
719
  kv_contexts = []
720
- state_heads = (
721
- self._reshape_hidden_to_heads(encoded_states)
722
- if encoded_states is not None
723
- else None
724
- )
725
  for block, (k_in, v_in) in zip(self.blocks, encoder_kv_states):
726
  k_ctx = self._project_kv_tensor(k_in, self.context_k_proj)
727
  v_ctx = self._project_kv_tensor(v_in, self.context_v_proj)
728
- if state_heads is not None:
729
- k_ctx = torch.cat([k_ctx, state_heads], dim=1)
730
- v_ctx = torch.cat([v_ctx, state_heads], dim=1)
731
  k_norm = block.cross_attn.k_norm
732
  if k_norm is not None:
733
  k_ctx = k_norm(k_ctx.transpose(1, 2)).transpose(1, 2)
@@ -737,24 +674,12 @@ class ActionExpert(nn.Module):
737
  @staticmethod
738
  def _build_cross_attention_mask(
739
  encoder_attention_mask: Optional[torch.Tensor],
740
- encoded_states: Optional[torch.Tensor],
741
  batch_size: int,
742
  dtype: torch.dtype,
743
  ) -> Optional[torch.Tensor]:
744
- state_seq_len = 0 if encoded_states is None else encoded_states.shape[1]
745
  if encoder_attention_mask is None:
746
  return None
747
  mask = encoder_attention_mask[:, None, None, :].to(dtype=dtype)
748
- if state_seq_len > 0:
749
- ones = torch.ones(
750
- batch_size,
751
- 1,
752
- 1,
753
- state_seq_len,
754
- device=mask.device,
755
- dtype=mask.dtype,
756
- )
757
- mask = torch.cat([mask, ones], dim=-1)
758
  return (1.0 - mask) * torch.finfo(dtype).min
759
 
760
  def _build_self_attention_mask(
@@ -792,7 +717,11 @@ class ActionExpert(nn.Module):
792
  device: torch.device,
793
  dtype: torch.dtype,
794
  ) -> ActionExpertContext:
795
- encoded_states = self._encode_states(state_embeddings)
 
 
 
 
796
  valid_action = None
797
  if action_attention_mask is not None:
798
  valid_action = action_attention_mask.to(
@@ -805,10 +734,9 @@ class ActionExpert(nn.Module):
805
  device=device,
806
  dtype=dtype,
807
  )
808
- kv_contexts = self._prepare_kv_context(encoder_kv_states, encoded_states)
809
  cross_mask = self._build_cross_attention_mask(
810
  encoder_attention_mask,
811
- encoded_states,
812
  batch_size,
813
  dtype,
814
  )
 
315
  ActionExpertRMSNorm(self.head_dim, eps=qk_norm_eps) if qk_norm else None
316
  )
317
  self.q_proj = nn.Linear(hidden_size, hidden_size)
 
318
  self.out_proj = nn.Linear(hidden_size, hidden_size)
319
  self.out_drop = nn.Dropout(proj_dropout)
320
 
 
 
 
 
 
 
 
321
  def _as_heads(self, x: torch.Tensor) -> torch.Tensor:
322
  if x.dim() == 4:
323
  if x.shape[2] == self.num_heads:
 
353
  self,
354
  x: torch.Tensor,
355
  *,
356
+ kv_k: torch.Tensor,
357
+ kv_v: torch.Tensor,
 
358
  attn_mask: Optional[torch.Tensor] = None,
359
  ) -> torch.Tensor:
 
 
 
 
360
  bsz, tgt_len, _ = x.shape
361
  q = self.q_proj(x).view(bsz, tgt_len, self.num_heads, self.head_dim)
362
+ k = self._as_heads(kv_k)
363
+ v = self._as_heads(kv_v)
 
 
 
 
 
 
 
 
 
 
 
 
364
  q = q.transpose(1, 2)
365
  k = k.transpose(1, 2)
366
+ if self.q_norm is not None:
367
+ q = self.q_norm(q)
 
 
 
368
  q = q.transpose(1, 2)
369
  k = k.transpose(1, 2)
370
  out = self._attention(q, k, v, attn_mask=attn_mask)
 
564
  self.action_embed = nn.Linear(
565
  config.max_action_dim, config.hidden_size, device=device
566
  )
 
 
 
 
 
 
567
  self.context_k_proj = nn.Linear(
568
  self.llm_kv_dim, config.hidden_size, bias=False, device=device
569
  )
 
595
  for _ in range(config.num_layers)
596
  ]
597
  )
 
 
 
 
598
  self.final_layer = ActionExpertFinalLayer(
599
  config.hidden_size, config.max_action_dim
600
  )
 
605
  if isinstance(module, nn.Linear):
606
  _init_linear(module)
607
  _init_linear(self.action_embed)
 
 
608
  _init_linear(self.context_k_proj)
609
  _init_linear(self.context_v_proj)
610
  if isinstance(self.context_norm, ActionExpertRMSNorm):
 
614
  _init_linear(block.self_attn.qkv)
615
  _init_linear(block.self_attn.out_proj, scale=residual_scale)
616
  _init_linear(block.cross_attn.q_proj)
 
617
  _init_linear(block.cross_attn.out_proj, scale=residual_scale)
618
  _init_linear(block.mlp.up_proj)
619
  _init_linear(block.mlp.gate_proj)
 
639
  x.shape[0], x.shape[1], self.config.num_heads, self.action_head_dim
640
  )
641
 
 
 
 
 
 
 
 
 
 
 
 
 
 
642
  def _time_conditioning(self, timesteps: torch.Tensor) -> torch.Tensor:
643
  conditioning = self.time_embed[0](timesteps)
644
  first_linear = self.time_embed[1]
 
655
  def _prepare_kv_context(
656
  self,
657
  encoder_kv_states: Sequence[Tuple[torch.Tensor, torch.Tensor]],
 
658
  ) -> Sequence[Tuple[torch.Tensor, torch.Tensor]]:
659
  if len(encoder_kv_states) != len(self.blocks):
660
  raise ValueError(
 
662
  f"got {len(encoder_kv_states)}."
663
  )
664
  kv_contexts = []
 
 
 
 
 
665
  for block, (k_in, v_in) in zip(self.blocks, encoder_kv_states):
666
  k_ctx = self._project_kv_tensor(k_in, self.context_k_proj)
667
  v_ctx = self._project_kv_tensor(v_in, self.context_v_proj)
 
 
 
668
  k_norm = block.cross_attn.k_norm
669
  if k_norm is not None:
670
  k_ctx = k_norm(k_ctx.transpose(1, 2)).transpose(1, 2)
 
674
  @staticmethod
675
  def _build_cross_attention_mask(
676
  encoder_attention_mask: Optional[torch.Tensor],
 
677
  batch_size: int,
678
  dtype: torch.dtype,
679
  ) -> Optional[torch.Tensor]:
 
680
  if encoder_attention_mask is None:
681
  return None
682
  mask = encoder_attention_mask[:, None, None, :].to(dtype=dtype)
 
 
 
 
 
 
 
 
 
 
683
  return (1.0 - mask) * torch.finfo(dtype).min
684
 
685
  def _build_self_attention_mask(
 
717
  device: torch.device,
718
  dtype: torch.dtype,
719
  ) -> ActionExpertContext:
720
+ if state_embeddings is not None:
721
+ raise ValueError(
722
+ "MolmoAct2 HF action expert supports only discrete state tokens. "
723
+ "Continuous state embeddings are not supported."
724
+ )
725
  valid_action = None
726
  if action_attention_mask is not None:
727
  valid_action = action_attention_mask.to(
 
734
  device=device,
735
  dtype=dtype,
736
  )
737
+ kv_contexts = self._prepare_kv_context(encoder_kv_states)
738
  cross_mask = self._build_cross_attention_mask(
739
  encoder_attention_mask,
 
740
  batch_size,
741
  dtype,
742
  )