import inspect
import os
from typing import Optional, Tuple

import torch
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.utils import is_flash_attn_greater_or_equal

from verl.utils.ulysses import (
    gather_heads_scatter_seq,
    gather_seq_scatter_heads,
    get_ulysses_sequence_parallel_world_size,
    validate_ulysses_config,
)

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func

    _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters)
except ImportError:
    flash_attn_varlen_func = None
    # If flash-attn is unavailable, mark sliding windows as unsupported so the capability
    # check in flash_attention_forward below does not raise a NameError.
    _flash_supports_window_size = False


def get_rope_index(
    processor,
    input_ids: torch.Tensor,
    image_grid_thw: Optional[torch.Tensor] = None,
    video_grid_thw: Optional[torch.Tensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Gets the position ids for Qwen2-VL; they should be generated before sharding the sequence.
    The batch dim has been removed, so `input_ids` must be a 1D tensor representing a single example.
    https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L1546
    """
    spatial_merge_size = processor.image_processor.merge_size
    tokens_per_second = 2
    image_token_id = processor.tokenizer.convert_tokens_to_ids("<|image_pad|>")
    video_token_id = processor.tokenizer.convert_tokens_to_ids("<|video_pad|>")
    vision_start_token_id = processor.tokenizer.convert_tokens_to_ids("<|vision_start|>")
    if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        position_ids = torch.ones(3, input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device)
        image_index, video_index = 0, 0
        input_ids = input_ids[attention_mask == 1]
        image_nums, video_nums = 0, 0
        vision_start_indices = torch.argwhere(input_ids == vision_start_token_id)
        vision_tokens = input_ids[vision_start_indices + 1]
        image_nums = (vision_tokens == image_token_id).sum()
        video_nums = (vision_tokens == video_token_id).sum()
        input_tokens = input_ids.tolist()
        llm_pos_ids_list: list = []
        st = 0
        remain_images, remain_videos = image_nums, video_nums
        for _ in range(image_nums + video_nums):
            # Locate the next image or video placeholder; whichever appears first is processed.
            if image_token_id in input_tokens and remain_images > 0:
                ed_image = input_tokens.index(image_token_id, st)
            else:
                ed_image = len(input_tokens) + 1
            if video_token_id in input_tokens and remain_videos > 0:
                ed_video = input_tokens.index(video_token_id, st)
            else:
                ed_video = len(input_tokens) + 1
            if ed_image < ed_video:
                t, h, w = (
                    image_grid_thw[image_index][0],
                    image_grid_thw[image_index][1],
                    image_grid_thw[image_index][2],
                )
                second_per_grid_t = 0
                image_index += 1
                remain_images -= 1
                ed = ed_image
            else:
                t, h, w = (
                    video_grid_thw[video_index][0],
                    video_grid_thw[video_index][1],
                    video_grid_thw[video_index][2],
                )
                second_per_grid_t = second_per_grid_ts[video_index] if second_per_grid_ts is not None else 1.0

                video_index += 1
                remain_videos -= 1
                ed = ed_video

            llm_grid_t, llm_grid_h, llm_grid_w = (
                t.item(),
                h.item() // spatial_merge_size,
                w.item() // spatial_merge_size,
            )
            text_len = ed - st

            # Text tokens before the vision block share the same index on all three axes.
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

            # Vision tokens get separate temporal / height / width indices.
            t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w)
            t_index = (t_index * second_per_grid_t * tokens_per_second).long().flatten()
            h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
            w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
            llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + text_len + st_idx)
            st = ed + llm_grid_t * llm_grid_h * llm_grid_w

        if st < len(input_tokens):
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            text_len = len(input_tokens) - st
            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        position_ids[..., attention_mask == 1] = llm_positions.to(position_ids.device)
    else:
        if attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = position_ids.unsqueeze(0).expand(3, -1).to(input_ids.device)
        else:
            # `input_ids` is 1D here (batch dim removed), so index dim 0 rather than dim 1.
            position_ids = torch.arange(input_ids.size(0), device=input_ids.device).view(1, -1).expand(3, -1)

    return position_ids
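
# Usage sketch (illustrative only, not executed here): `processor` is assumed to be the
# Qwen2-VL AutoProcessor; the variable names and shapes below are hypothetical.
#
#     inputs = processor(text=[prompt], images=[image], return_tensors="pt")
#     position_ids = get_rope_index(
#         processor,
#         input_ids=inputs["input_ids"][0],             # 1D tensor, batch dim removed
#         image_grid_thw=inputs.get("image_grid_thw"),  # (num_images, 3) as (t, h, w)
#         attention_mask=inputs["attention_mask"][0],
#     )
#     # -> (3, seq_len) tensor of (temporal, height, width) mrope indices, computed
#     #    before any Ulysses sequence sharding.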


def prepare_fa2_from_position_ids(
    query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, position_ids: torch.Tensor
):
    """
    Flattens q/k/v to (total_tokens, num_heads, head_dim) and derives cumulative sequence
    lengths from `position_ids` so packed (variable-length) sequences can be passed to
    flash_attn_varlen_func. Each position-id reset to 0 marks the start of a new sequence.
    """
    query = query.view(-1, query.size(-2), query.size(-1))
    key = key.view(-1, key.size(-2), key.size(-1))
    value = value.view(-1, value.size(-2), value.size(-1))
    position_ids = position_ids.flatten()
    indices_q = torch.arange(position_ids.size(0), device=position_ids.device, dtype=torch.int32)
    cu_seqlens = torch.cat(
        (
            indices_q[position_ids == 0],
            torch.tensor(position_ids.size(), device=position_ids.device, dtype=torch.int32),
        )
    )
    max_length = cu_seqlens.diff().max()
    return (query, key, value, indices_q, (cu_seqlens, cu_seqlens), (max_length, max_length))
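
# Worked example (sketch): with position_ids = [0, 1, 2, 0, 1] (two packed sequences of
# lengths 3 and 2), indices_q[position_ids == 0] is [0, 3], so cu_seqlens becomes
# [0, 3, 5] and max_length is 3, which is the (cu_seqlens_q, max_seqlen_q) layout that
# flash_attn_varlen_func expects.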


def flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
    is_causal: bool = True,
    position_ids: Optional[torch.Tensor] = None,
    sliding_window: Optional[int] = None,
    use_top_left_mask: bool = False,
    deterministic: Optional[bool] = None,
    **kwargs,
):
    """
    Patches flash attention forward to handle 3D position ids in mrope: (3, batch_size, seq_length).
    """
    causal = is_causal if not use_top_left_mask else is_causal and query_length != 1

    use_sliding_windows = (
        _flash_supports_window_size and sliding_window is not None and key_states.shape[1] > sliding_window
    )
    flash_kwargs = {"window_size": (sliding_window, sliding_window)} if use_sliding_windows else {}

    if is_flash_attn_greater_or_equal("2.4.1"):
        if deterministic is None:
            deterministic = os.environ.get("FLASH_ATTENTION_DETERMINISTIC", "0") == "1"
        flash_kwargs["deterministic"] = deterministic

    if position_ids is not None and query_length != 1 and not (torch.diff(position_ids[0], dim=-1) >= 0).all():
        batch_size = query_states.size(0)
        query_states, key_states, value_states, _, cu_seq_lens, max_seq_lens = prepare_fa2_from_position_ids(
            query_states, key_states, value_states, position_ids[0]
        )
        cu_seqlens_q, cu_seqlens_k = cu_seq_lens
        max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
        attn_output = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_in_batch_q,
            max_seqlen_k=max_seqlen_in_batch_k,
            dropout_p=kwargs.pop("dropout", 0.0),
            softmax_scale=kwargs.pop("softmax_scale", None),
            causal=causal,
            **flash_kwargs,
        )
        attn_output = attn_output.view(batch_size, -1, attn_output.size(-2), attn_output.size(-1))
    else:
        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            query_length,
            is_causal=is_causal,
            sliding_window=sliding_window,
            use_top_left_mask=use_top_left_mask,
            deterministic=deterministic,
            **kwargs,
        )

    return attn_output
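
# Note (sketch): the varlen branch above is taken only when the temporal row of the mrope
# position ids is non-monotonic, e.g. position_ids[0] = [[0, 1, 2, 0, 1]] has a negative
# diff, which signals that several samples were packed into a single row.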


def ulysses_flash_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    **kwargs,
) -> Tuple[torch.Tensor, None, None]:
    from transformers.models.qwen2_vl.modeling_qwen2_vl import apply_multimodal_rotary_pos_emb, repeat_kv

    bsz, q_len, _ = hidden_states.size()
    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()

    if ulysses_sp_size > 1:
        validate_ulysses_config(self.num_heads, ulysses_sp_size)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)

        full_q_len = query_states.size(2)
    else:
        full_q_len = q_len

    if position_embeddings is None:
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings

    query_states, key_states = apply_multimodal_rotary_pos_emb(
        query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
    )
    dropout_rate = 0.0 if not self.training else self.attention_dropout

    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    if (
        self.config.use_sliding_window
        and getattr(self.config, "sliding_window", None) is not None
        and self.layer_idx >= self.config.max_window_layers
    ):
        sliding_window = self.config.sliding_window
    else:
        sliding_window = None

    attn_output = flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        full_q_len,
        dropout=dropout_rate,
        sliding_window=sliding_window,
        is_causal=self.is_causal,
        use_top_left_mask=self._flash_attn_uses_top_left_mask,
        position_ids=position_ids,
    )
    if ulysses_sp_size > 1:
        attn_output = gather_heads_scatter_seq(attn_output, head_dim=2, seq_dim=1)

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, None, None
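
# Patching sketch (assumption: the installed transformers release still exposes
# Qwen2VLFlashAttention2; newer releases may route attention differently, so adjust the
# target class accordingly):
#
#     from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2
#
#     Qwen2VLFlashAttention2.forward = ulysses_flash_attn_forward
#
# After patching, each rank holds seq_len / sp_size tokens; the forward gathers the full
# sequence across the Ulysses group for attention and scatters it back afterwards.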