from typing import Callable, Optional, Tuple

import torch
from transformers.cache_utils import Cache
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
from transformers.utils import logging

from verl.utils.ulysses import (
    gather_heads_scatter_seq,
    gather_seq_scatter_heads,
    get_ulysses_sequence_parallel_world_size,
    validate_ulysses_config,
)

logger = logging.get_logger(__name__)


def qwen2_flash_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Cache] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    cache_position: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
):
    """
    Adapted from transformers 4.47.1 to support Ulysses sequence parallelism.

    NOTE: This function is only tested on transformers versions between 4.45.0 and 4.47.1.
    """
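    # NOTE (illustrative): this forward is intended to replace the HF attention forward via
    # monkey patching, e.g. `Qwen2FlashAttention2.forward = qwen2_flash_attn_forward` on
    # transformers <= 4.47.x; the actual patching is typically done by a separate helper.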
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()

    if ulysses_sp_size > 1:
        validate_ulysses_config(self.num_heads, ulysses_sp_size)

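        # Ulysses all-to-all: gather the full sequence and scatter the attention heads,
        # so each SP rank attends over the complete sequence with num_heads / sp_size heads:
        # (bsz, n_heads, q_len, head_dim) -> (bsz, n_heads / sp_size, q_len * sp_size, head_dim)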
        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)

    full_q_len = query_states.size(2)

    if position_embeddings is None:
        logger.warning_once(
            "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
            "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
            "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
            "removed and `position_embeddings` will be mandatory."
        )
        cos, sin = self.rotary_emb(value_states, position_ids)
    else:
        cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

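    # Repeat k/v heads if n_kv_heads < n_heads (grouped-query attention).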
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)
    dropout_rate = 0.0 if not self.training else self.attention_dropout

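    # PEFT setups often keep layer norms in float32 for stability, which can silently
    # upcast the hidden states; cast q/k/v back to the compute dtype expected by flash
    # attention (fp16/bf16) before the kernel call.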
    input_dtype = query_states.dtype
    if input_dtype == torch.float32:
        if torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        elif hasattr(self.config, "_pre_quantization_dtype"):
            target_dtype = self.config._pre_quantization_dtype
        else:
            target_dtype = self.q_proj.weight.dtype

        logger.warning_once(
            "The input hidden states seem to have been silently cast to float32; this might be because you have "
            f"upcast embedding or layer norm layers to float32. We will cast the input back to {target_dtype}."
        )

        query_states = query_states.to(target_dtype)
        key_states = key_states.to(target_dtype)
        value_states = value_states.to(target_dtype)

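    # Flash attention expects the layout (bsz, seq_len, num_heads, head_dim),
    # so transpose back from (bsz, num_heads, seq_len, head_dim).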
    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    if (
        self.config.use_sliding_window
        and getattr(self.config, "sliding_window", None) is not None
        and self.layer_idx >= self.config.max_window_layers
    ):
        sliding_window = self.config.sliding_window
    else:
        sliding_window = None

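    # Attention is computed over the full (gathered) sequence length.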
    attn_output = _flash_attention_forward(
        query_states,
        key_states,
        value_states,
        attention_mask,
        full_q_len,
        position_ids=position_ids,
        dropout=dropout_rate,
        sliding_window=sliding_window,
        is_causal=self.is_causal,
        use_top_left_mask=self._flash_attn_uses_top_left_mask,
    )

    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()

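    # Reverse the Ulysses all-to-all: gather the heads back and scatter the sequence,
    # restoring this rank's local shard (bsz, q_len, num_heads, head_dim).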
    if ulysses_sp_size > 1:
        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
    attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


def qwen2_attn_forward(
    self,
    hidden_states: torch.Tensor,
    position_embeddings: Tuple[torch.Tensor, torch.Tensor],
    attention_mask: Optional[torch.Tensor],
    past_key_value: Optional[Cache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """
    Adapted from transformers 4.49.0 to support Ulysses sequence parallelism for transformers >= 4.48.0.

    NOTE: This function has been tested only on transformers versions between 4.48.0 and 4.50.0.
    """
    from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

    bsz, q_len, _ = hidden_states.shape
    hidden_shape = (bsz, q_len, -1, self.head_dim)

    query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()

    if ulysses_sp_size > 1:
        validate_ulysses_config(self.config.num_attention_heads, ulysses_sp_size)

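        # Same Ulysses all-to-all as in qwen2_flash_attn_forward: gather the full
        # sequence and scatter the attention heads across SP ranks.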
        query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1)
        key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1)
        value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1)

    full_q_len = query_states.size(2)

    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

    if past_key_value is not None:
        cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
        key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

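    # Use sliding-window attention only when enabled in the config and only for layers
    # with index >= max_window_layers; otherwise attend over the full context.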
    sliding_window = None
    if (
        self.config.use_sliding_window
        and getattr(self.config, "sliding_window", None) is not None
        and self.layer_idx >= self.config.max_window_layers
    ):
        sliding_window = self.config.sliding_window

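    # Resolve the attention backend configured on the model (eager, sdpa, flash
    # attention, ...) via the transformers >= 4.48 attention-interface registry.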
    from transformers.models.qwen2.modeling_qwen2 import eager_attention_forward

    attention_interface: Callable = eager_attention_forward
    if self.config._attn_implementation != "eager":
        if self.config._attn_implementation == "sdpa" and kwargs.get("output_attentions", False):
            logger.warning_once(
                "`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. "
                "Falling back to eager attention. This warning can be removed using the argument "
                '`attn_implementation="eager"` when loading the model.'
            )
        else:
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

    attn_output, attn_weights = attention_interface(
        self,
        query_states,
        key_states,
        value_states,
        attention_mask,
        dropout=0.0 if not self.training else self.attention_dropout,
        scaling=self.scaling,
        sliding_window=sliding_window,
        **kwargs,
    )

    attn_output = attn_output.reshape(bsz, full_q_len, -1, self.head_dim).contiguous()

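    # Reverse the Ulysses all-to-all: gather the heads back and scatter the sequence,
    # restoring this rank's local shard before the output projection.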
    if ulysses_sp_size > 1:
        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)
    attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
    attn_output = self.o_proj(attn_output)
    return attn_output, attn_weights