# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Apply monkey-patch function to models
"""

import importlib.metadata
import sys
from functools import lru_cache
from typing import Optional

import torch
from packaging import version
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from transformers.modeling_utils import PreTrainedModel

from verl.utils.ulysses import (
    gather_heads_scatter_seq,
    gather_seq_scatter_heads,
    get_ulysses_sequence_parallel_group,
    get_ulysses_sequence_parallel_world_size,
)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=2, repeats=n_rep). The hidden states go from (batch,
    seqlen, num_key_value_heads, head_dim) to (batch, seqlen, num_attention_heads, head_dim)
    """
    batch, slen, num_key_value_heads, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, :, None, :].expand(batch, slen, num_key_value_heads, n_rep, head_dim)
    return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep, head_dim)
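
# Illustrative check (assumed shapes, not part of the original module): with a KV tensor of shape
# (batch=2, seqlen=128, num_key_value_heads=4, head_dim=64) and n_rep=2, repeat_kv returns a
# (2, 128, 8, 64) tensor that matches torch.repeat_interleave along the head dimension:
#
#   kv = torch.randn(2, 128, 4, 64)
#   assert torch.equal(repeat_kv(kv, 2), torch.repeat_interleave(kv, repeats=2, dim=2))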


def _ulysses_flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    *args,
    position_ids: Optional[torch.Tensor] = None,
    **kwargs,
):
    """Insert all-to-all before and after flash attention.

    DeepSpeed-Ulysses: https://arxiv.org/pdf/2309.14509

    Args:
        query_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads, head_dim)
        key_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads_k, head_dim)
        value_states (torch.Tensor): (batch_size, seqlen/sp_size, nheads_k, head_dim)
        position_ids (torch.Tensor, optional): (batch_size, seqlen/sp_size)

    Returns:
        torch.Tensor: (batch_size, seqlen/sp_size, nheads, head_dim)
    """
    ulysses_sp_size = get_ulysses_sequence_parallel_world_size()

    ########## AlltoAll for Ulysses ##########
    if ulysses_sp_size > 1:
        assert position_ids is not None, "position_ids is required for Ulysses sequence parallelism"

        # NOTE: repeat kv heads so they are divisible by the sequence parallel size. Instead of repeating
        # nheads_q//nheads_k times, we repeat sp_size//nheads_k times, since flash_attention supports MQA/GQA.
        # For example:
        # - nheads_k=4, sp=8, repeats=2
        # - nheads_k=8, sp=8, repeats=1
        # - nheads_k=16, sp=8, repeats=1
        repeats = max(ulysses_sp_size // key_states.size(2), 1)
        key_states = repeat_kv(key_states, repeats)
        value_states = repeat_kv(value_states, repeats)

        # (bsz, seq_len/n, n_head, head_dim) -> (bsz, seq_len, n_head/n, head_dim)
        query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
        key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
        value_states = gather_seq_scatter_heads(value_states, seq_dim=1, head_dim=2)

        # TODO: all_gather position_ids because `prepare_fa2_from_position_ids` needs it; this all_gather could be
        # eliminated by passing cu_seq_lens_q, cu_seq_lens_k, max_length_k, max_length_q explicitly.
        # https://github.com/huggingface/transformers/pull/33932
        # (bsz, seq_len/n) -> (bsz, seq_len)
        position_ids_list = [torch.empty_like(position_ids) for _ in range(ulysses_sp_size)]
        torch.distributed.all_gather(position_ids_list, position_ids, group=get_ulysses_sequence_parallel_group())
        position_ids = torch.concat(position_ids_list, dim=-1)

    # (bsz, seq_len, n_head/n, head_dim)
    attn_output = _flash_attention_forward(
        query_states, key_states, value_states, *args, position_ids=position_ids, **kwargs
    )

    ########## AlltoAll for Ulysses ##########
    if ulysses_sp_size > 1:
        # (bsz, seq_len, n_head/n, head_dim) -> (bsz, seq_len/n, n_head, head_dim)
        attn_output = gather_heads_scatter_seq(attn_output, seq_dim=1, head_dim=2)

    return attn_output
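
# Shape walkthrough under assumed settings (illustrative only: sp_size=4, full seq_len=2048,
# nheads=32, nheads_k=8, so repeats = max(4 // 8, 1) = 1):
#   per-rank input   : query (bsz, 512, 32, head_dim), key/value (bsz, 512, 8, head_dim)
#   after all-to-all : query (bsz, 2048, 8, head_dim), key/value (bsz, 2048, 2, head_dim)
#   after attention and the reverse all-to-all: attn_output (bsz, 512, 32, head_dim)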


def apply_monkey_patch(model: PreTrainedModel, ulysses_sp_size: int):
    """Replace _flash_attention_forward with _ulysses_flash_attention_forward"""
    module = sys.modules[model.__module__]
    num_attention_heads, num_key_value_heads = model.config.num_attention_heads, model.config.num_key_value_heads
    assert num_attention_heads % ulysses_sp_size == 0, (
        f"num_attention_heads {num_attention_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}"
    )
    assert num_key_value_heads % ulysses_sp_size == 0 or ulysses_sp_size % num_key_value_heads == 0, (
        f"num_key_value_heads {num_key_value_heads} must be divisible by ulysses_sp_size {ulysses_sp_size} "
        f"or vice versa. When ulysses_sp_size % num_key_value_heads == 0, kv heads are repeated to ensure correctness."
    )

    # TODO: VLM models only, unify monkey patch with LLM models.
    if model.config.model_type in ("qwen2_vl", "qwen2_5_vl"):  # patch remove-padding for qwen2vl mrope
        from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLFlashAttention2
        from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2

        from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward

        Qwen2VLFlashAttention2.forward = ulysses_flash_attn_forward
        Qwen2_5_VLFlashAttention2.forward = ulysses_flash_attn_forward
        print("Monkey patch FlashAttention2.forward in Qwen2VL")
        return

    # transformers<=4.47.1
    if hasattr(module, "_flash_attention_forward"):
        module._flash_attention_forward = _ulysses_flash_attention_forward
        print(f"Monkey patch _flash_attention_forward in {model.__module__}")
    else:
        # transformers>=4.48.0
        from transformers.integrations import flash_attention

        flash_attention._flash_attention_forward = _ulysses_flash_attention_forward
        print(f"Monkey patch _flash_attention_forward in {flash_attention.__name__}")


@lru_cache
def is_transformers_version_in_range(min_version: str, max_version: str) -> bool:
    try:
        # Get the installed version of the transformers library
        transformers_version = importlib.metadata.version("transformers")
    except importlib.metadata.PackageNotFoundError as e:
        raise ModuleNotFoundError("The `transformers` package is not installed.") from e
    # Check if the version is within the specified range
    return version.parse(min_version) <= version.parse(transformers_version) <= version.parse(max_version)
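
# Example (illustrative): with transformers 4.47.1 installed, is_transformers_version_in_range("4.45.0", "4.47.1")
# returns True; both bounds are inclusive.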