hungchiayu committed
Commit c17e96b · verified · 1 Parent(s): 11e55e3

Upload modelling_expertv2.py with huggingface_hub

Files changed (1):
1. modelling_expertv2.py (+913 -0)
modelling_expertv2.py ADDED
@@ -0,0 +1,913 @@
import torch
from torch import nn
import copy


def get_intermediate_size(hidden_dim, ffn_dim_multiplier=4, multiple_of=256):
    hidden_dim = int(2 * hidden_dim / 3)
    hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
    return hidden_dim
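
# For example, get_intermediate_size(768) computes int(2 * 768 / 3) = 512, then 4 * 512 = 2048,
# which is already a multiple of 256, so the resulting intermediate size is 2048.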

import torch.nn.functional as F  # noqa: N812
from typing import Optional, Callable, Dict, Any
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLAttention, apply_multimodal_rotary_pos_emb, eager_attention_forward, repeat_kv
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLTextConfig
from transformers import Qwen2_5_VLTextModel, Qwen2_5_VLForConditionalGeneration
from transformers.cache_utils import Cache
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.processing_utils import Unpack
from transformers import AutoProcessor
from transformers.utils import logging
from einops import rearrange, repeat
from qwen_vl_utils import process_vision_info
import PIL
import json
import math
import numpy as np
from huggingface_hub import hf_hub_download

logger = logging.get_logger(__name__)  # used by Qwen2_5_VLAExpert.forward


def create_sinusoidal_pos_embedding(
    time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
):
    """Computes sine-cosine positional embedding vectors for scalar positions."""
    if dimension % 2 != 0:
        raise ValueError(f"dimension ({dimension}) must be divisible by 2")

    if time.ndim != 1:
        raise ValueError("The time tensor is expected to be of shape `(batch_size, )`.")

    dtype = torch.float32
    fraction = torch.linspace(0.0, 1.0, dimension // 2, dtype=dtype, device=device)
    period = min_period * (max_period / min_period) ** fraction

    # Compute the outer product
    scaling_factor = 1.0 / period * 2 * math.pi
    sin_input = scaling_factor[None, :] * time[:, None]
    pos_emb = torch.cat([torch.sin(sin_input), torch.cos(sin_input)], dim=1)
    return pos_emb

def apply_rope(x, positions, max_wavelength=10_000):
    """
    Applies RoPE positions [B, L] to x [B, L, H, D].
    """
    d_half = x.shape[-1] // 2
    device = x.device
    dtype = x.dtype
    x = x.to(torch.float32)

    freq_exponents = (2.0 / x.shape[-1]) * torch.arange(d_half, dtype=torch.float32, device=device)
    timescale = max_wavelength**freq_exponents
    radians = positions[..., None].to(torch.float32) / timescale[None, None, :].to(torch.float32)

    radians = radians[..., None, :]

    sin = torch.sin(radians)  # .to(dtype=dtype)
    cos = torch.cos(radians)  # .to(dtype=dtype)

    x1, x2 = x.split(d_half, dim=-1)
    res = torch.empty_like(x)
    res[..., :d_half] = x1 * cos - x2 * sin
    res[..., d_half:] = x2 * cos + x1 * sin

    return res.to(dtype)
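
# Shape example: for x of shape (B, L, H, D) = (1, 5, 6, 128) and positions of shape (1, 5),
# apply_rope returns a tensor of shape (1, 5, 6, 128); a token at position 0 is left unchanged
# because sin(0) = 0 and cos(0) = 1.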

def make_att_2d_masks(pad_masks, att_masks):
    """Copied from big_vision.

    Tokens can attend to valid input tokens which have a cumulative mask_ar
    smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
    set up several types of attention, for example:

      [[1 1 1 1 1 1]]: pure causal attention.

      [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
          themselves and the last 3 tokens have a causal attention. The first
          entry could also be a 1 without changing behaviour.

      [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
          block can attend all previous blocks and all tokens on the same block.

    Args:
      pad_masks: bool[B, N] true if it's part of the input, false if padding.
      att_masks: int32[B, N] mask that's 1 where previous tokens cannot depend on
          it and 0 where it shares the same attention mask as the previous token.
    """
    if att_masks.ndim != 2:
        raise ValueError(att_masks.ndim)
    if pad_masks.ndim != 2:
        raise ValueError(pad_masks.ndim)

    cumsum = torch.cumsum(att_masks, dim=1)
    att_2d_masks = cumsum[:, None, :] <= cumsum[:, :, None]
    pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
    att_2d_masks = att_2d_masks & pad_2d_masks
    return att_2d_masks
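
# Worked example: with pad_masks = [[1, 1, 1, 1, 1]] and att_masks = [[1, 1, 1, 0, 0]]
# (a 3-token causal prefix followed by a 2-token bidirectional block, the pattern used for the
# action chunk below), cumsum = [1, 2, 3, 3, 3] and the resulting 2D mask is
#   [[1, 0, 0, 0, 0],
#    [1, 1, 0, 0, 0],
#    [1, 1, 1, 1, 1],
#    [1, 1, 1, 1, 1],
#    [1, 1, 1, 1, 1]]
# i.e. the prefix stays causal while the last two tokens attend to everything, including each other.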

class Qwen2_5_VLMoTAttention(Qwen2_5_VLAttention):
    """
    Attention layer for the action expert.

    It replaces the multimodal rotary embedding with plain 1-D RoPE over the action chunk and,
    when `use_cache` is True, prepends the frozen key/value states of the VLM prefix (passed via
    `past_key_value`) so the expert can attend to the VLM context.
    """

    def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: Optional[int] = None):
        super().__init__(config, layer_idx)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # unused here, kept for BC
        fill_kv_cache=True,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # cos, sin = position_embeddings

        ## The action chunk is a 1-D time series, so multimodal RoPE is not needed; use standard RoPE instead.
        # query_states, key_states = apply_multimodal_rotary_pos_emb(
        #     query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        # )
        query_states = rearrange(query_states, 'b h s d -> b s h d')
        query_states = apply_rope(query_states, position_ids)
        query_states = rearrange(query_states, 'b s h d -> b h s d')

        key_states = rearrange(key_states, 'b h s d -> b s h d')
        key_states = apply_rope(key_states, position_ids)
        key_states = rearrange(key_states, 'b s h d -> b h s d')

        if use_cache:
            past_key_state = past_key_value[self.layer_idx][0]
            past_value_state = past_key_value[self.layer_idx][1]

            key_states = torch.cat([past_key_state, key_states], dim=2)
            # print(key_states.dtype)
            value_states = torch.cat(
                [past_value_state, value_states], dim=2
            )
            key_states = key_states.to(dtype=query_states.dtype)
            value_states = value_states.to(dtype=query_states.dtype)
            # print("New K shape", key_states.shape)
            # print("New V shape", value_states.shape)

        # if past_key_value is not None and not fill_kv_cache:  ## Only update the KV cache if fill_kv_cache is False
        #     cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
        #     key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
        # print("New query shape", query_states.shape)

        # attention_mask = torch.ones()
        ## TODO: verify that `is_causal` does not default to True here; an implicit causal mask
        ## would conflict with the explicit 2-D mask passed in.
        # print(position_ids)
        # print(attention_mask.shape)

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            position_ids=position_ids,  # pass positions for FA2
            **kwargs,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

from transformers.modeling_outputs import BaseModelOutputWithPast


class Qwen2_5_VLAExpert(Qwen2_5_VLTextModel):
    """Decoder stack for the action expert; each layer attends into the pre-computed VLM key/value cache."""

    def __init__(self, config):
        super().__init__(config)

    def forward(
        self,
        expert_attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        vlm_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            raise ValueError("You must specify inputs_embeds")
        # torch.jit.trace() doesn't support cache objects in the output
        if vlm_key_values is None:
            raise ValueError("You must specify vlm_key_values")

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        # position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=expert_attention_mask,
                position_ids=position_ids,
                past_key_value=vlm_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=None,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v for v in [hidden_states, vlm_key_values, all_hidden_states, all_self_attns] if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=vlm_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

import tensorflow as tf
import dlimp as dl
import PIL.Image as Image


def resize_image(image1):
    # image1 = ds_combined[0]['observation.images.scene']
    # image1 = image1.reshape(480, 640, 3)
    image1 = tf.cast(image1 * 255, dtype=tf.uint8)
    image1 = image1.numpy().transpose(1, 2, 0)
    image1 = dl.transforms.resize_image(image1, size=(224, 224))

    image1 = Image.fromarray(image1.numpy())
    return image1

class VLAWithExpert(nn.Module):

    _ACTION_TOKEN_MIN = 151665
    _ACTION_TOKEN_MAX = 153712

    def __init__(self, config=None, device=None):
        super().__init__()

        self.vlm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            "declare-lab/nora-long",
            torch_dtype=torch.bfloat16,
            attn_implementation="sdpa",
        )
        if config is not None:
            self.config = config
        else:
            self.config = {'max_action_dim': 7, "max_state_dim": 8}

        print("Loading expert model...")

        self.lm_expert_config = copy.deepcopy(self.vlm.config.text_config)

        # lm_expert_config = copy.deepcopy(model.config.text_config)
        self.processor = AutoProcessor.from_pretrained(
            "declare-lab/nora", trust_remote_code=True
        )
        self.fast_tokenizer = AutoProcessor.from_pretrained(
            "physical-intelligence/fast", trust_remote_code=True
        )
        self.fast_tokenizer.action_dim = 7
        self.fast_tokenizer.time_horizon = 5
        hidden_size = self.lm_expert_config.hidden_size
        expert_width_multiplier = 0.375
        self.lm_expert_config.hidden_size = int(hidden_size * expert_width_multiplier)  # 0.375 x the VLM hidden size
        self.lm_expert_config.intermediate_size = get_intermediate_size(int(hidden_size * expert_width_multiplier))
        self.lm_expert_config.num_hidden_layers = self.vlm.config.num_hidden_layers
        self.lm_expert_config.num_attention_heads = 6

        self.action_expert = Qwen2_5_VLAExpert._from_config(self.lm_expert_config, torch_dtype=torch.bfloat16)
        self.action_chunk_length = 5

        self.device = self.vlm.device
        # Replace the action expert's attention layers
        self._replace_action_expert_attention()
        self.action_expert.embed_tokens = None
        self.vlm_kv_cache = None

        # self.state_proj = nn.Linear(
        #     self.config['max_state_dim'], hidden_size
        # )
        self.action_in_proj = nn.Linear(self.config['max_action_dim'], self.lm_expert_config.hidden_size)
        self.action_out_proj = nn.Linear(self.lm_expert_config.hidden_size, self.config['max_action_dim'])
        self.action_time_mlp_in = nn.Linear(
            self.lm_expert_config.hidden_size * 2, self.lm_expert_config.hidden_size
        )
        self.action_time_mlp_out = nn.Linear(
            self.lm_expert_config.hidden_size, self.lm_expert_config.hidden_size
        )
        self.state_emb = nn.Linear(self.config['max_action_dim'], self.lm_expert_config.hidden_size)

        print("*** Loading normalization stats from HF Hub ***")
        norm_stats_path = hf_hub_download(repo_id='declare-lab/nora', filename="norm_stats.json")
        with open(norm_stats_path, "r") as f:
            self.norm_stats = json.load(f)

        libero_stats = hf_hub_download(repo_id='moojink/openvla-7b-oft-finetuned-libero-spatial-object-goal-10', filename="dataset_statistics.json")
        with open(libero_stats, "r") as f:
            self.norm_stats.update(json.load(f))
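
    # Sizing note (assumption): if the nora-long backbone keeps Qwen2.5-VL-3B's text hidden size
    # of 2048, the expert config above works out to hidden_size = int(2048 * 0.375) = 768,
    # intermediate_size = get_intermediate_size(768) = 2048, and 6 attention heads.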

    def sample_noise(self, shape, device, dtype=torch.float32):
        noise = torch.normal(
            mean=0.0,
            std=1.0,
            size=shape,
            dtype=dtype,
            device=device,
        )
        return noise

    def sample_time(self, bsize, device, dtype=torch.float32):
        beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
        time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=dtype)
        time = time_beta * 0.999 + 0.001
        return time
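
    # sample_time draws t from Beta(1.5, 1.0), which is biased towards larger (noisier) t,
    # and rescales it to the interval (0.001, 1.0) so that t = 0 is never sampled exactly.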

    def _replace_action_expert_attention(self):
        """
        Iterate through the expert's layers and replace the default
        Qwen2_5_VLAttention with the custom Qwen2_5_VLMoTAttention.
        """
        for i, layer in enumerate(self.action_expert.layers):
            layer.self_attn = Qwen2_5_VLMoTAttention(
                config=self.action_expert.config,
                layer_idx=i
            ).to(self.action_expert.dtype)
            layer.self_attn.to(self.action_expert.device)

    def denoise_step(
        self,
        x_t: torch.Tensor,
        timestep: torch.Tensor,
        states,
        vlm_kv_cache: tuple,
        full_2d_attn_mask: torch.Tensor):
        """
        Applies one denoising step to the noisy action `x_t` at a given `timestep`,
        conditioned on the VLM's output cache.

        This function is derived from the main `forward` pass, encapsulating the
        logic for a single step in the diffusion sampling process.

        Args:
            x_t (torch.Tensor): The noisy action tensor from the previous step.
                Shape: (batch_size, action_chunk_length, action_dim).
            timestep (torch.Tensor): The current timestep for each sample in the batch.
                Shape: (batch_size,).
            states: Proprioceptive state used as additional conditioning, or None.
            vlm_kv_cache (tuple): The pre-computed key-value cache from the VLM,
                used as conditioning.
            full_2d_attn_mask (torch.Tensor): 2D attention mask over the combined
                VLM-prefix + action sequence.
                Shape: (batch_size, total_seq_len, total_seq_len).

        Returns:
            torch.Tensor: The predicted velocity `v_t`.
                Shape: (batch_size, action_chunk_length, action_dim).
        """
        device = x_t.device
        bsz = x_t.shape[0]

        # 1. Embed the noisy action `x_t`
        x_t = x_t.to(dtype=self.vlm.dtype)

        action_input_embeds = self.action_in_proj(x_t)

        # 2. Create sinusoidal time embeddings from the current timestep
        time_emb = create_sinusoidal_pos_embedding(
            timestep,
            self.lm_expert_config.hidden_size,
            4e-3,  # same min/max period as the training forward pass
            4.0,
            device=device,
        )
        time_emb = time_emb.type(dtype=x_t.dtype)
        # Expand time embedding to match the action embedding dimensions
        time_emb = time_emb[:, None, :].expand_as(action_input_embeds)

        # 3. Combine action and time embeddings and process through MLPs
        action_time_emb = torch.cat([action_input_embeds, time_emb], dim=2)
        action_time_emb = self.action_time_mlp_in(action_time_emb)
        action_time_emb = F.silu(action_time_emb)  # swish activation
        action_time_emb = self.action_time_mlp_out(action_time_emb)
        if states is not None:
            states_embed = self.state_emb(states)
            # print(states_embed.shape, action_input_embeds.shape)
            states_embed = states_embed.unsqueeze(1).expand_as(action_input_embeds)
            action_time_emb += states_embed

        # 4. Construct the attention mask for the action expert.
        # The expert needs to attend to the VLM context and its own action inputs.
        # The expert's queries originate from the action sequence, so we slice the mask accordingly.
        expert_attention_mask = full_2d_attn_mask[:, -self.action_chunk_length:, :]

        # 5. Prepare position_ids for the expert.
        # Note: as in the training forward pass, position_ids for the expert restart from 0.
        position_ids = torch.arange(self.action_chunk_length, device=device)

        # 6. Call the action expert with the prepared inputs and VLM cache.
        expert_output = self.action_expert(
            inputs_embeds=action_time_emb,
            expert_attention_mask=expert_attention_mask.unsqueeze(1).bool(),  # Add head dim
            position_ids=position_ids,
            vlm_key_values=vlm_kv_cache,
            use_cache=True,  # As in the training forward pass
        )

        # 7. Project the expert's output to get the final velocity prediction.
        velocity = self.action_out_proj(expert_output.last_hidden_state)

        return velocity

    def sample_fast_tokens(self, image, image2=None, instruction=None, states=None, unnormalize=False, do_sample=False):
        device = self.vlm.device
        states = states.to(device)
        # print(type(image))
        image = resize_image(image)  ## IMPORTANT: ensure the image resizing method is consistent with pretraining
        # if not isinstance(image, PIL.Image.Image):
        #     image = PIL.Image.fromarray(image)
        # Construct messages in the expected chat format. Note that nora expects images of size 224 x 224.

        if image2 is not None:
            image2 = resize_image(image2)
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image,
                            "resized_height": 224,
                            "resized_width": 224,
                        },
                        {
                            "type": "image",
                            "image": image2,
                            "resized_height": 224,
                            "resized_width": 224,
                        },
                        {"type": "text", "text": instruction},
                    ],
                }
            ]
        else:
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image,
                            "resized_height": 224,
                            "resized_width": 224,
                        },
                        {"type": "text", "text": instruction},
                    ],
                }
            ]
        # Apply chat template to get the text input for the model
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Process vision information (depends on the qwen_vl_utils process_vision_info helper)
        image_inputs, video_inputs = process_vision_info(messages)

        # Prepare inputs for the model using the main processor
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Move inputs to the model device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        generated_ids = self.vlm.generate(**inputs, do_sample=True, temperature=1.0)

        # --- Extract and decode the action ---
        # Find the indices of generated tokens that fall inside the action-token range.
        start_idx = (self._ACTION_TOKEN_MIN <= generated_ids[0]) & (generated_ids[0] <= self._ACTION_TOKEN_MAX)
        start_idx = torch.where(start_idx)[0]

        if len(start_idx) > 0:
            start_index = start_idx[0].item()
        else:
            start_index = None  # or -1 to indicate not found

        # Decode the action tokens using the FAST tokenizer.
        # The token IDs need to be mapped back to the range expected by the FAST tokenizer decoder.
        output_action = self.fast_tokenizer.decode([generated_ids[0][start_idx] - self._ACTION_TOKEN_MIN])
        return output_action

    @torch.no_grad()
    def sample_actions(self, image, image2=None, instruction=None, num_steps: int = 25, states=None, unnorm_key='libero_10', unnormalize=True):
        """
        Generates actions by running the full diffusion sampling process.

        This function first computes the VLM's key-value cache to use as a
        conditioning context. It then uses an iterative Euler-method-based
        sampler, calling `denoise_step` at each timestep to refine a noise
        tensor into a final action.

        Args:
            image: Primary camera observation (resized to 224 x 224 internally).
            image2: Optional second camera observation.
            instruction (str): Natural-language task instruction.
            num_steps (int): Number of Euler integration steps.
            states: Proprioceptive state used as additional conditioning.
            unnorm_key (str): Key into the normalization statistics used to un-normalize actions.
            unnormalize (bool): If False, the normalized actions are returned directly.

        Returns:
            np.ndarray: The final, denoised action tensor.
                Shape: (batch_size, action_chunk_length, action_dim).
        """
        # vlm_inputs = self.prepare_inputs_for_generation(image, instruction)

        device = self.vlm.device
        states = states.to(device)
        image = resize_image(image)  ## IMPORTANT: ensure the image resizing method is consistent with pretraining
        # Construct messages in the expected chat format. Note that nora expects images of size 224 x 224.
        if image2 is not None:
            image2 = resize_image(image2)
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image,
                            "resized_height": 224,
                            "resized_width": 224,
                        },
                        {
                            "type": "image",
                            "image": image2,
                            "resized_height": 224,
                            "resized_width": 224,
                        },
                        {"type": "text", "text": instruction},
                    ],
                }
            ]
        else:
            messages = [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "image": image,
                            "resized_height": 224,
                            "resized_width": 224,
                        },
                        {"type": "text", "text": instruction},
                    ],
                }
            ]
        # Apply chat template to get the text input for the model
        text = self.processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Process vision information (depends on the qwen_vl_utils process_vision_info helper)
        image_inputs, video_inputs = process_vision_info(messages)

        # Prepare inputs for the model using the main processor
        inputs = self.processor(
            text=[text],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )

        # Move inputs to the model device
        inputs = {k: v.to(device) for k, v in inputs.items()}

        bsz = inputs['input_ids'].shape[0]

        # 1. Pre-compute the VLM cache. This context is the conditioning for the
        #    entire denoising process and only needs to be computed once.
        if self.vlm_kv_cache is None:
            vlm_outputs = self.vlm(**inputs)
            vlm_kv_cache = vlm_outputs.past_key_values
            self.vlm_kv_cache = vlm_kv_cache

        # The VLM's attention mask is its padding mask for the expert.
        vlm_pad_mask = inputs['attention_mask'].clone()

        # 2. Initialize the noisy action tensor `x_t`.
        actions_shape = (bsz, self.action_chunk_length, self.config['max_action_dim'])
        x_t = self.sample_noise(actions_shape, device=device, dtype=self.vlm.dtype)

        # 3. Set up the time steps for the Euler solver.
        #    We will step from t=1 down to t=0.
        # num_steps = self.config.num_steps
        dt = -1.0 / num_steps
        dt_tensor = torch.tensor(dt, dtype=self.vlm.dtype, device=device)
        time = torch.tensor(1.0, dtype=self.vlm.dtype, device=device)
        states = states.to(self.vlm.dtype)

        # 4. Iteratively denoise using the Euler method.
        #    The loop continues as long as time is greater than or equal to zero.
        action_pad_mask = torch.ones(bsz, self.action_chunk_length, device=device).bool()

        # An all-zero attention mask for the action part allows for full bidirectional attention
        # within the action chunk, as in the training forward pass.
        action_attn_mask = torch.zeros(bsz, self.action_chunk_length, device=device).bool()

        # Concatenate VLM (prefix) and action masks.
        # The VLM's attention mask is its padding mask.
        concat_pad_mask = torch.cat([vlm_pad_mask, action_pad_mask], dim=1)
        concat_attn_mask = torch.cat([vlm_pad_mask, action_attn_mask], dim=1)

        # Create the full 2D attention mask for the combined sequence.
        full_2d_attn_mask = make_att_2d_masks(concat_pad_mask, concat_attn_mask)
        while time >= -dt / 2:  # Loop until t=0
            with torch.no_grad():
                # Expand the current time to match the batch size.
                expanded_time = time.expand(bsz)

                # Call denoise_step to predict the velocity v_t, given the current noisy action,
                # the timestep, and the pre-computed VLM cache and attention mask.
                v_t = self.denoise_step(
                    x_t=x_t,
                    timestep=expanded_time,
                    states=states,
                    vlm_kv_cache=self.vlm_kv_cache,
                    full_2d_attn_mask=full_2d_attn_mask,
                )

            # 5. Apply the Euler integration step to update the action tensor.
            #    This moves the action slightly along the direction of the predicted velocity.
            x_t += dt * v_t
            time += dt
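
        # With the default num_steps = 25, this loop performs 25 Euler updates, carrying x_t
        # from t = 1 (pure noise) down to t = 0, where x_t approximates the normalized action chunk.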

        # 6. Return the final denoised action.
        normalized_action = x_t.cpu().float().numpy()
        # self.vlm_kv_cache = None
        if unnormalize is False:
            return normalized_action

        action_stats = self._get_action_stats(unnorm_key)

        mask = action_stats.get("mask", np.ones_like(action_stats["q01"], dtype=bool))
        action_high, action_low = np.array(action_stats["q99"]), np.array(action_stats["q01"])

        actions = np.where(
            mask,
            0.5 * (normalized_action + 1) * (action_high - action_low) + action_low,
            normalized_action,
        )

        return actions

    def _get_action_stats(self, unnorm_key: str) -> Dict[str, Any]:
        if unnorm_key not in self.norm_stats:
            raise KeyError(
                f"The `unnorm_key` '{unnorm_key}' is not in the set of available dataset statistics. "
                f"Please choose from: {list(self.norm_stats.keys())}"
            )
        return self.norm_stats[unnorm_key]["action"]

    def forward(self, vlm_inputs, actions, alpha=10.0, use_state=False, states=None, **kwargs):
        """
        Training forward pass: run the VLM once to obtain its key/value cache (and its own loss),
        then train the action expert with a flow-matching objective conditioned on that cache.
        """
        ## Precompute the VLM cache with only the VLM inputs/attention mask;
        ## the Qwen2.5-VL model builds its own attention mask internally.
        device = self.vlm.device

        vlm_outputs = self.vlm(
            **vlm_inputs,
            use_cache=True
        )
        vlm_kv_cache = vlm_outputs.past_key_values

        ## Construct the attention mask for the action expert.
        ## The action expert should be able to attend to the VLM inputs and its own action inputs
        ## (prefix + bidirectional attention).
        bsz = vlm_inputs['input_ids'].shape[0]
        vlm_pad_mask = vlm_inputs['expert_attention'].clone()
        vlm_attn_mask = vlm_inputs['attention_mask'].clone()

        actions = actions.to(self.vlm.dtype)
        noise = self.sample_noise(actions.shape, actions.device, dtype=actions.dtype)

        time = self.sample_time(actions.shape[0], actions.device, dtype=actions.dtype)

        time_expanded = time[:, None, None]

        x_t = time_expanded * noise + (1 - time_expanded) * actions
        u_t = noise - actions
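        # Flow-matching construction: x_t = t * noise + (1 - t) * actions interpolates between the
        # clean action chunk (t = 0) and pure noise (t = 1), so the target velocity is
        # d(x_t)/dt = noise - actions = u_t, which the expert is trained to regress below.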
        # x_t = x_t.to(self.vlm.dtype)
        action_input_embeds = self.action_in_proj(x_t)  ## Embed the noisy action

        time_emb = create_sinusoidal_pos_embedding(
            time,
            self.lm_expert_config.hidden_size,
            4e-3,
            4.0,
            device=device,
        )

        time_emb = time_emb.type(dtype=actions.dtype)

        time_emb = time_emb[:, None, :].expand_as(action_input_embeds)

        action_time_emb = torch.cat([action_input_embeds, time_emb], dim=2)  ## concat on the hidden-size dim

        action_time_emb = self.action_time_mlp_in(action_time_emb)  ## linear layer projecting back to the hidden size
        action_time_emb = F.silu(action_time_emb)  # swish == silu
        action_time_emb = self.action_time_mlp_out(action_time_emb)

        if use_state:
            states_embed = self.state_emb(states)
            states_embed = states_embed.unsqueeze(1).expand_as(action_input_embeds)
            action_time_emb += states_embed

        action_pad_mask = torch.ones(bsz, self.action_chunk_length, device=device).bool()
        action_attn_mask = torch.zeros(bsz, self.action_chunk_length, device=device).bool()

        concat_action_mask = torch.cat([vlm_pad_mask, action_pad_mask], dim=1)
        concat_attn_mask = torch.cat([vlm_attn_mask, action_attn_mask], dim=1)

        attn = make_att_2d_masks(concat_action_mask, concat_attn_mask)
        expert_attention_mask = attn[:, -self.action_chunk_length:, :]

        position_ids = torch.arange(self.action_chunk_length, device=device)
        expert_output = self.action_expert(
            inputs_embeds=action_time_emb,
            expert_attention_mask=expert_attention_mask.unsqueeze(1).bool(),
            position_ids=position_ids,
            vlm_key_values=vlm_kv_cache,
            use_cache=True,
        )

        action_out = self.action_out_proj(expert_output.last_hidden_state)
        expert_loss = alpha * F.mse_loss(action_out, u_t, reduction='mean')

        loss = expert_loss + vlm_outputs.loss

        return {'expert_loss': expert_loss, 'combined_loss': loss, 'vlm_loss': vlm_outputs.loss}
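
A minimal inference sketch for this module, for readers of the commit. The observation format, the 7-dimensional state vector, and the `libero_10` statistics key are assumptions based on the code above, and the action expert and projection heads are freshly initialized by `__init__`, so trained weights would have to be loaded into the module separately.

    import numpy as np
    import torch
    from modelling_expertv2 import VLAWithExpert

    model = VLAWithExpert()                     # loads declare-lab/nora-long plus a fresh action expert
    model = model.to("cuda")

    image = np.random.rand(3, 480, 640).astype("float32")  # CHW image in [0, 1], as resize_image expects (placeholder data)
    state = torch.zeros(1, 7)                   # proprioceptive state, one row per batch element (assumed shape)

    actions = model.sample_actions(
        image,
        instruction="pick up the black bowl",   # hypothetical instruction
        num_steps=25,
        states=state,
        unnorm_key="libero_10",
    )
    print(actions.shape)                        # (1, 5, 7): batch, action_chunk_length, max_action_dim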