# Apriel-H1-27_50-15b-Thinker / modeling_apriel_h.py
# Initial commit 3af8776 by nitsanluke.
import copy
import math
from dataclasses import dataclass
from typing import Any, Optional, Union
import torch
import torch.nn.functional as F
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
from .configuration_apriel_h import AprielHConfig
from einops import rearrange, repeat
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from mamba_ssm.ops.triton.selective_state_update import selective_state_update
from torch import nn
from transformers import GenerationMixin
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer, MistralMLP, MistralModel, MistralRMSNorm
from transformers.processing_utils import Unpack
from transformers.utils import LossKwargs, can_return_tuple, logging
from transformers.utils.generic import ModelOutput
logger = logging.get_logger(__name__)
# True only when all three kernels imported above resolved to truthy objects; `Mamba.forward` asserts on it.
is_fast_path_available = all((selective_state_update, causal_conv1d_fn, causal_conv1d_update))
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every head of a 4-D tensor ``n_rep`` times along dim 1.

    Same result as ``torch.repeat_interleave(hidden_states, repeats=n_rep, dim=1)``:
    ``(batch, num_key_value_heads, seqlen, head_dim)`` becomes
    ``(batch, num_key_value_heads * n_rep, seqlen, head_dim)``. When ``n_rep == 1`` the
    input tensor itself is returned.
    """
    # Unpack first so that a non-4-D input is rejected even when n_rep == 1.
    bsz, n_kv_heads, seq_len, dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold it into the heads.
    widened = hidden_states.unsqueeze(2).expand(bsz, n_kv_heads, n_rep, seq_len, dim)
    return widened.reshape(bsz, n_kv_heads * n_rep, seq_len, dim)
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/jamba/modeling_jamba.py
class HybridMambaAttentionDynamicCache(DynamicCache):
    """
    A dynamic cache that can handle both the attention cache (which has a seq_len dimension) and the mamba cache
    (which has a constant shape regardless of seq_len).

    This cache has two sets of lists of tensors: `key_cache` and `value_cache` for attention cache and `conv_states`
    and `ssm_states` for mamba cache. Each of these lists has `num_layers` tensors. The expected shape of each tensor
    depends on the layer type given by `config.hybrid_block_layout`:

    - "m2" (mamba) layers: `conv_states` has shape `(batch_size, d_inner, d_conv)` and `ssm_states` has shape
      `(batch_size, num_C_head, d_inner // num_C_head, d_state)` with `num_C_head = d_inner // d_state`.
    - "m2d" layers: `conv_states` has shape `(batch_size, d_inner + 2 * n_qk_heads * d_state, d_conv)` (allocated
      transposed) and `ssm_states` has shape `(batch_size, n_qk_heads, d_inner // n_qk_heads, d_state)`.
    - Any other layer type: `conv_states` and `ssm_states` are empty tensors of shape `(batch_size, 0)` and the
      layer index is recorded in `transformer_layers`.

    `key_cache` and `value_cache` start as empty `(batch_size, 0)` tensors for every layer. Attention layers fill
    them through `update`, after which they have shape `(batch_size, num_heads, seq_len, head_dim)`.
    """

    def __init__(self, config: AprielHConfig, batch_size, dtype=torch.float16, device=None):
        super().__init__()
        self.dtype = dtype
        self.hybrid_override_pattern = config.hybrid_block_layout
        self.has_previous_state = False  # only used by mamba
        intermediate_size = (
            config.ssm_cfg["d_inner"]
            if config.ssm_cfg["d_inner"] is not None
            else config.ssm_cfg["expand"] * config.hidden_size
        )
        ssm_state_size = config.ssm_cfg["d_state"]
        conv_kernel_size = config.ssm_cfg["d_conv"]
        self.n_qk_heads = config.ssm_cfg["n_qk_heads"]
        self.num_C_head = intermediate_size // ssm_state_size  # mamba2
        assert intermediate_size % self.n_qk_heads == 0, "d_inner must be divisible by n_qk_heads"
        self.head_d = intermediate_size // self.n_qk_heads
        self.conv_states = []
        self.ssm_states = []
        self.transformer_layers = []
        for i in range(config.num_hidden_layers):
            if self.hybrid_override_pattern[i] == "m2d":
                # Mamba layer
                self.conv_states += [
                    torch.zeros(
                        batch_size,
                        conv_kernel_size,
                        intermediate_size + 2 * self.n_qk_heads * ssm_state_size,
                        device=device,
                        dtype=dtype,
                    ).transpose(1, 2)
                ]
                self.ssm_states += [
                    torch.zeros(batch_size, self.n_qk_heads, self.head_d, ssm_state_size, device=device, dtype=dtype)
                ]
            elif self.hybrid_override_pattern[i] == "m2":
                if "repeat_kv_before_conv" in config.ssm_cfg:
                    assert (
                        config.ssm_cfg["repeat_kv_before_conv"] == True
                    ), "Only support repeat_kv_before_conv=True for m2 for now"
                self.conv_states += [
                    torch.zeros(
                        batch_size,
                        intermediate_size,
                        conv_kernel_size,
                        device=device,
                        dtype=dtype,
                    )
                ]
                self.ssm_states += [
                    torch.zeros(
                        batch_size,
                        self.num_C_head,
                        intermediate_size // self.num_C_head,
                        ssm_state_size,
                        device=device,
                        dtype=dtype,
                    )
                ]
            else:
                # Attention or MLP layer
                self.conv_states += [torch.tensor([[]] * batch_size, device=device)]
                self.ssm_states += [torch.tensor([[]] * batch_size, device=device)]
                self.transformer_layers.append(i)
        self.key_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]
        self.value_cache = [torch.tensor([[]] * batch_size, device=device) for _ in range(config.num_hidden_layers)]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Append new key/value states for `layer_idx` along the sequence dim (dim 2) and return the full cache."""
        if self.key_cache[layer_idx].shape[-1] == 0:
            # First write for this layer: replace the empty placeholder.
            self.key_cache[layer_idx] = key_states
            self.value_cache[layer_idx] = value_states
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=2)
        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.conv_states[layer_idx].device
            self.conv_states[layer_idx] = self.conv_states[layer_idx].index_select(0, beam_idx.to(device))
            device = self.ssm_states[layer_idx].device
            self.ssm_states[layer_idx] = self.ssm_states[layer_idx].index_select(0, beam_idx.to(device))

    def to_legacy_cache(self) -> tuple[tuple[torch.Tensor], tuple[torch.Tensor]]:
        raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
        raise NotImplementedError("HybridMambaAttentionDynamicCache does not have a legacy cache equivalent.")

    # Adapted from modeling_mamba2.py (there the states are one stacked tensor; here they are per-layer lists).
    def update_conv_state(
        self, layer_idx: int, new_conv_state: torch.Tensor, cache_init: bool = False
    ) -> torch.Tensor:
        """
        Replace (`cache_init=True`) or shift-and-append (`cache_init=False`) the conv state of `layer_idx`.

        In the append case the window is rolled left by one along the last dim and `new_conv_state[:, 0, :]`
        becomes the newest column. Returns the updated state.
        """
        current = self.conv_states[layer_idx]
        if cache_init:
            self.conv_states[layer_idx] = new_conv_state.to(current.device)
        else:
            rolled = current.roll(shifts=-1, dims=-1)
            rolled[:, :, -1] = new_conv_state[:, 0, :].to(rolled.device)
            self.conv_states[layer_idx] = rolled
        return self.conv_states[layer_idx]

    def update_ssm_state(self, layer_idx: int, new_ssm_state: torch.Tensor):
        """Overwrite the SSM state of `layer_idx` (moved to the existing state's device) and return it."""
        self.ssm_states[layer_idx] = new_ssm_state.to(self.ssm_states[layer_idx].device)
        return self.ssm_states[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if layer_idx not in self.transformer_layers:
            # Mamba layers have no seq_len dimension: take any attention-type layer that already holds a
            # non-empty key cache (identity layers in `transformer_layers` never do).
            layer_idx = next(
                (
                    idx
                    for idx in self.transformer_layers
                    if idx < len(self.key_cache) and self.key_cache[idx].numel()
                ),
                None,
            )
            if layer_idx is None:
                return 0
        if len(self.key_cache) <= layer_idx or not self.key_cache[layer_idx].numel():
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def reset(self):
        """Zero every per-layer conv and SSM state in place."""
        for conv_state in self.conv_states:
            conv_state.zero_()
        for ssm_state in self.ssm_states:
            ssm_state.zero_()
@dataclass
class AprielHybridCausalOutput(ModelOutput):
    """Output container returned by `AprielHForCausalLM.forward` (hybrid attention / Mamba causal LM)."""

    # Language-modeling loss; left as None when no labels are provided.
    loss: Optional[torch.FloatTensor] = None
    # lm_head outputs for the positions kept by `logits_to_keep`.
    logits: Optional[torch.FloatTensor] = None
    # Hidden states of every layer, when requested via `output_hidden_states`.
    all_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    # Final hidden state (optional; only present if the producer fills it).
    last_hidden_state: Optional[torch.FloatTensor] = None
    # Attention weights (optional; only present if the producer fills them).
    attention_weights: Optional[torch.FloatTensor] = None
    # Cache carried between generation steps (presumably a HybridMambaAttentionDynamicCache — verify at call site).
    past_key_values: Optional[Cache] = None
def segsum(x):
    """
    Numerically stable segment sum over the last dimension.

    For an input whose last dimension has size ``T``, returns a ``(..., T, T)`` tensor whose
    ``[i, j]`` entry is ``x[j + 1] + ... + x[i]`` when ``i >= j`` (so the diagonal is zero) and
    ``-inf`` when ``i < j``. Example: ``[1, 2, 3]`` -> ``[[0, -inf, -inf], [2, 0, -inf], [5, 3, 0]]``.
    """
    length = x.size(-1)
    # tiled[..., i, j] = x[..., i]
    tiled = x.unsqueeze(-1).expand(*x.shape, length)
    # Keep only the strictly-lower triangle before accumulating down the rows.
    strictly_lower = torch.tril(torch.ones(length, length, device=x.device, dtype=bool), diagonal=-1)
    partial = torch.cumsum(tiled.masked_fill(~strictly_lower, 0), dim=-2)
    # Everything above the diagonal is "no path": exp(-inf) == 0 downstream.
    lower = torch.tril(torch.ones(length, length, device=x.device, dtype=bool), diagonal=0)
    return partial.masked_fill(~lower, -torch.inf)
def materialize_mixer(A_log, B, C, D):
    """
    Materialize the SSM mixer as a dense, attention-like matrix ``T`` so that the mixer output can be
    computed as ``y = torch.matmul(T, X)``.

    Arguments:
        A_log: (batch, length, n_heads); the per-step log decay used is ``-softplus(A_log)``.
        B: (batch, length, n_heads, d_state)
        C: (batch, length, n_heads, d_state)
        D: optional per-head skip term added on the diagonal; it must be viewable as
            ``(1, n_heads, 1)`` (presumably shape ``(n_heads,)``).
    Return:
        T: (batch, n_heads, length, length) with ``T[..., l, s] = <C_l, B_s> * exp(segsum)[l, s]``
        (zero for ``s > l``), plus ``D`` on the diagonal when given.
    """
    batch_size, length, n_heads, d_state = B.shape
    assert A_log.shape == (batch_size, length, n_heads)
    assert B.shape == C.shape == (batch_size, length, n_heads, d_state)

    # (b, l, h) -> (b, h, l), then pairwise decay factors between positions.
    log_decay = (-F.softplus(A_log)).transpose(1, 2)
    decay = torch.exp(segsum(log_decay))

    # Assembled as (b, h, s, l); the diagonal is symmetric, so D can be added before the final swap.
    mixer_t = torch.einsum("blhn,bshn,bhls->bhsl", C, B, decay)
    if D is not None:
        diag = torch.arange(length)
        mixer_t[:, :, diag, diag] += D.view(1, n_heads, 1)
    return mixer_t.transpose(-1, -2)
def apply_mask_to_padding_states(hidden_states, attention_mask):
    """
    Multiply hidden states by the attention mask so padding positions are zeroed,
    see https://github.com/state-spaces/mamba/issues/66

    Only applied when a mask is given and it has more than one row and more than one column;
    otherwise ``hidden_states`` is returned as-is. The result keeps the input dtype.
    """
    if attention_mask is None:
        return hidden_states
    n_positions = attention_mask.shape[1]
    n_sequences = attention_mask.shape[0]
    if n_positions <= 1 or n_sequences <= 1:
        return hidden_states
    return (hidden_states * attention_mask.unsqueeze(2)).to(hidden_states.dtype)
class Mamba(nn.Module):
    """
    Selective state-space (Mamba) token mixer used by the "m2" hybrid layers.

    `x` and `B` are projected to `d_xb` channels and repeated `repeat_group` times over their
    `d_state`-sized groups, so that they match the `d_inner`-wide `z`, `C` and `dt`. `forward`
    only supports the CUDA fast path (`causal_conv1d` + `mamba_ssm` kernels).
    """

    def __init__(
        self,
        d_model,
        d_inner,
        d_xb=None,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        repeat_kv_before_conv=True,
        conv_bias=True,
        bias=False,
        dt_proj_bias=True,
        use_fast_path=True,  # Fused kernel options
        layer_idx=None,
        device=None,
        dtype=None,
        **kwargs,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_xb = d_xb if d_xb is not None else d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = d_inner if d_inner is not None else int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx
        self.repeat_kv_before_conv = repeat_kv_before_conv
        # Depthwise causal conv, over the repeated (d_inner) or un-repeated (d_xb) x channels.
        conv_channels = self.d_inner if self.repeat_kv_before_conv else self.d_xb
        self.conv1d = nn.Conv1d(
            in_channels=conv_channels,
            out_channels=conv_channels,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=conv_channels,
            padding=d_conv - 1,
            **factory_kwargs,
        )
        self.activation = "silu"
        self.act = nn.SiLU()
        self.num_xb_head = self.d_xb // self.d_state
        self.num_C_head = self.d_inner // self.d_state
        self.repeat_group = self.num_C_head // self.num_xb_head
        self.in_proj = nn.Linear(self.d_model, 2 * self.d_xb + 2 * self.d_inner, bias=bias, **factory_kwargs)
        self.dt_in_proj = nn.Linear(self.d_model, self.dt_rank, bias=bias, **factory_kwargs)
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=dt_proj_bias, **factory_kwargs)
        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError
        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min)) + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        self.dt_proj.bias._no_reinit = True
        # S4D real initialization
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True
        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D._no_weight_decay = True
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Optional[HybridMambaAttentionDynamicCache] = None,
        mamba_mask: Optional[torch.Tensor] = None,
        return_mixer_matrix=False,
        **kwargs,
    ):
        """
        hidden_states: (B, L, D)
        Returns: dict whose "hidden_states" entry has the same shape as the input.

        When `past_key_value` is None (e.g. `use_cache=False`) the layer runs without reading or writing
        any cached conv/SSM state. `mamba_mask` and `return_mixer_matrix` are accepted but not used here.
        """
        assert is_fast_path_available and "cuda" in self.in_proj.weight.device.type, "Only support fast path on cuda"
        cache_position = kwargs.get("cache_position", None)
        batch, seqlen, _ = hidden_states.shape
        use_precomputed_states = False
        #########################################################
        # Quick and dirty to work with CG
        if "inference_params" in kwargs:
            seqlen_offset = kwargs["inference_params"].seqlen_offset
            if seqlen_offset > 0:
                use_precomputed_states = True
        else:
            seqlen_offset = kwargs.get("seqlen_offset", cache_position[0]) if cache_position is not None else 0
            use_precomputed_states = (
                past_key_value is not None
                and past_key_value.has_previous_state
                and seqlen == 1
                and past_key_value.conv_states[self.layer_idx].shape[0]
                == past_key_value.ssm_states[self.layer_idx].shape[0]
                == batch
                and cache_position is not None
                and seqlen_offset > 0
            )
        #########################################################
        # (None, None) when there is no cache.
        ssm_state, conv_state = self._get_states_from_cache(past_key_value, batch)
        if use_precomputed_states:
            # Single-token decode against the cached conv/SSM states.
            out, _, _ = self.step(hidden_states, conv_state, ssm_state)
            return {"hidden_states": out}

        outputs = {}
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        zxbc = self.in_proj(hidden_states)
        z, x, B, C = torch.split(
            zxbc,
            [
                self.d_inner,
                self.d_xb,
                self.d_xb,
                self.d_inner,
            ],
            dim=-1,
        )
        x = rearrange(x, "b l d -> b d l")
        z = rearrange(z, "b l d -> b d l")
        B = rearrange(B, "b l (n_group dstate) -> b n_group l dstate", dstate=self.d_state)
        B = repeat_kv(B, self.repeat_group)  # B, n_group, L, H
        B = rearrange(B, "b n_group l dstate -> b n_group dstate l").contiguous()
        C = rearrange(C, "b l (n_group dstate) -> b n_group dstate l", dstate=self.d_state).contiguous()
        dt = self.dt_proj(self.dt_in_proj(hidden_states))  # B, L, d_inner
        dt = rearrange(dt, "b l d -> b d l")  # B, d_inner, L
        if self.repeat_kv_before_conv:
            x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state)
            x = repeat_kv(x, self.repeat_group)
            x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l")
        # Compute short convolution
        if conv_state is not None:
            # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
            # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
            # Update state (B D W)
            conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))
        if causal_conv1d_fn is None:
            # Keep the (B, D, L) layout expected by the code below and by selective_scan_fn.
            x = self.act(self.conv1d(x)[..., :seqlen])
        else:
            assert self.activation in ["silu", "swish"]
            x = causal_conv1d_fn(
                x=x,
                weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                bias=self.conv1d.bias,
                activation=self.activation,
            )
        if not self.repeat_kv_before_conv:
            x = rearrange(x, "b (n_group dstate) l -> b n_group l dstate", dstate=self.d_state)
            x = repeat_kv(x, self.repeat_group)
            x = rearrange(x, "b n_group l dstate -> b (n_group dstate) l")
        y = selective_scan_fn(
            x,
            dt,
            A,
            B,
            C,
            self.D.float(),
            z=z,
            delta_bias=self.dt_proj.bias.float(),
            delta_softplus=True,
            return_last_state=(ssm_state is not None),
        )
        if ssm_state is not None:
            y, last_state = y
            ssm_state.copy_(rearrange(last_state, "b (h d) n -> b h d n", h=self.num_C_head))
        y = rearrange(y, "b d l -> b l d")
        out = self.out_proj(y)
        outputs["hidden_states"] = out[:, :seqlen, :]
        return outputs

    def step(self, hidden_states, conv_state, ssm_state):
        """
        Decode a single token: (B, 1, D) -> (B, 1, D).

        The caller ignores the returned states, so `conv_state` / `ssm_state` are expected to be updated in place
        by the conv update and `selective_state_update` calls below.
        """
        dtype = hidden_states.dtype
        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
        hidden_states_input = hidden_states.squeeze(1)
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        zxbc = self.in_proj(hidden_states_input)
        z, x, B, C = torch.split(zxbc, [self.d_inner, self.d_xb, self.d_xb, self.d_inner], dim=-1)
        B = rearrange(B, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state)
        B = torch.repeat_interleave(B, dim=1, repeats=self.repeat_group)
        C = rearrange(C, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state).contiguous()
        dt = self.dt_proj(self.dt_in_proj(hidden_states_input))  # B, d_inner
        if self.repeat_kv_before_conv:
            x = rearrange(x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state)
            x = torch.repeat_interleave(x, dim=1, repeats=self.repeat_group)
            x = rearrange(x, "b n_group dstate -> b (n_group dstate)")
        # Conv step
        if causal_conv1d_update is None:
            # Update state (B D W)
            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
            conv_state[:, :, -1] = x
            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
            if self.conv1d.bias is not None:
                x = x + self.conv1d.bias
            x = self.act(x).to(dtype=dtype)
        else:
            x = causal_conv1d_update(
                x,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.activation,
            )
        if not self.repeat_kv_before_conv:
            x = rearrange(x, "b (n_group dstate) -> b n_group dstate", dstate=self.d_state)
            x = torch.repeat_interleave(x, dim=1, repeats=self.repeat_group)
            x = rearrange(x, "b n_group dstate -> b (n_group dstate)")
        x = rearrange(x, "b (h d) -> b h d", h=self.num_C_head)
        dt = rearrange(dt, "b (h d) -> b h d", h=self.num_C_head)
        A = rearrange(A, "(h d) n -> h d n", h=self.num_C_head)
        D = rearrange(self.D, "(h d) -> h d", h=self.num_C_head)
        z = rearrange(z, "b (h d) -> b h d", h=self.num_C_head)
        dt_bias = rearrange(self.dt_proj.bias, "(h d) -> h d", h=self.num_C_head)
        # SSM step
        assert selective_state_update is not None
        y = selective_state_update(ssm_state, x, dt, A, B, C, D, z=z, dt_bias=dt_bias, dt_softplus=True)
        y = rearrange(y, "b h d -> b (h d)")
        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate zeroed `(conv_state, ssm_state)` for this layer on the weights' device."""
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        if self.repeat_kv_before_conv:
            conv_state = torch.zeros(batch_size, self.d_inner, self.d_conv, device=device, dtype=conv_dtype)
        else:
            conv_state = torch.zeros(batch_size, self.d_xb, self.d_conv, device=device, dtype=conv_dtype)
        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
        ssm_state = torch.zeros(
            batch_size, self.num_C_head, self.d_inner // self.num_C_head, self.d_state, device=device, dtype=ssm_dtype
        )
        return conv_state, ssm_state

    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
        """
        Return this layer's `(ssm_state, conv_state)` from the hybrid cache, or `(None, None)` when no
        cache is given. For "m2" layers the cache allocates:

        conv_state: (batch, d_inner, d_conv)
        ssm_state: (batch, num_C_head, d_inner // num_C_head, d_state)

        `batch_size` is currently unused. With `initialize_states=True` both states are zeroed in place.
        """
        assert self.layer_idx is not None
        if inference_params is None:
            return None, None
        # Get states
        ssm_states = inference_params.ssm_states[self.layer_idx]
        conv_states = inference_params.conv_states[self.layer_idx]
        if initialize_states:
            ssm_states.zero_()
            conv_states.zero_()
        return ssm_states, conv_states
class AprielSSMM2DecoderLayer(nn.Module):
    """
    Pre-norm "m2" decoder block.

    The data flow is RMSNorm -> Mamba mixer -> residual add, then RMSNorm -> MistralMLP -> residual add.
    """

    _mixer_class = Mamba

    def __init__(self, config: AprielHConfig, layer_idx: int, device=None, dtype=None, **kwargs):
        super().__init__(**kwargs)
        factory_kwargs = {"device": device, "dtype": dtype}
        self.hidden_size = config.hidden_size
        # Every entry of `config.ssm_cfg` is forwarded to the mixer constructor.
        self.mixer = self._mixer_class(
            d_model=config.hidden_size,
            layer_idx=layer_idx,
            **config.ssm_cfg,
            **factory_kwargs,
        )
        self.mlp = MistralMLP(config)
        self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> tuple[torch.FloatTensor]:
        """
        Run the block and return a 1-tuple `(hidden_states,)`.

        All keyword arguments (cache, cache_position, masks, ...) are passed through to the mixer.
        """
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        mixer_outputs = self.mixer(hidden_states, **kwargs)
        # The mixer output is cast back to the residual's dtype before the skip connection.
        hidden_states = mixer_outputs["hidden_states"].to(residual.dtype) + residual

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return (hidden_states,)
class AprielHybridIdentity(nn.Module):
    """Pass-through block used for "i" entries of the hybrid layout: hidden states are returned untouched."""

    def __init__(self, config: AprielHConfig):
        super().__init__()
        self.config = config

    def forward(self, hidden_states: torch.Tensor, **kwargs):
        # Extra kwargs (cache, masks, ...) are ignored; wrap in a 1-tuple like the other decoder layers.
        passthrough = (hidden_states,)
        return passthrough
class AprielHModel(MistralModel):
    """
    Hybrid decoder of *config.num_hidden_layers* layers whose types follow `config.hybrid_block_layout`.

    Each layout entry selects a layer type:
    - "m2" -> [`AprielSSMM2DecoderLayer`]
    - "t"  -> [`MistralDecoderLayer`]
    - "i"  -> [`AprielHybridIdentity`]

    Args:
        config: AprielHConfig

    Raises:
        ValueError: if the layout contains any other block type.
    """

    def __init__(self, config: AprielHConfig, **kwargs):
        # Build the MistralModel base with an empty decoder stack so no Mistral layers are allocated there;
        # the hybrid layers are installed below.
        config_copy = copy.deepcopy(config)
        config_copy.num_hidden_layers = 0
        super().__init__(config_copy, **kwargs)
        self.config = config
        blocks = []
        logger.info("Loading hybrid model with the following layout: %s", config.hybrid_block_layout)
        for layer_idx, block_type in enumerate(config.hybrid_block_layout):
            if block_type == "m2":
                blocks.append(AprielSSMM2DecoderLayer(config, layer_idx))
            elif block_type == "t":
                blocks.append(MistralDecoderLayer(config, layer_idx))
            elif block_type == "i":
                blocks.append(AprielHybridIdentity(config))
            else:
                raise ValueError(f"Invalid block type: {block_type}")
        self.layers = nn.ModuleList(blocks)
        # Initialize weights and apply final processing
        self.post_init()

    @can_return_tuple
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
    ) -> BaseModelOutputWithPast:
        """
        Run the MistralModel forward over the hybrid layers.

        If caching is enabled and no cache is passed, a `HybridMambaAttentionDynamicCache` is created. After the
        first pass, the returned hybrid cache is marked with `has_previous_state = True` so that Mamba layers
        can switch to single-token decoding on later calls.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        if use_cache and past_key_values is None:
            # for the case where prepare_inputs_for_generation is not called to create the cache (as in fast-llm test)
            batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
            past_key_values = HybridMambaAttentionDynamicCache(self.config, batch_size, self.dtype, device=self.device)
        output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            **flash_attn_kwargs,
        )
        returned_cache = output.past_key_values
        # Only the hybrid cache carries `has_previous_state`; other cache types (or None) are left untouched.
        if isinstance(returned_cache, HybridMambaAttentionDynamicCache) and not returned_cache.has_previous_state:
            returned_cache.has_previous_state = True
        return output
# Typed keyword-argument bundle accepted by `AprielHForCausalLM.forward`.
# It combines transformers' flash-attention kwargs with its loss kwargs.
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
class AprielThinkerSSMHybridPreTrainedModel(PreTrainedModel):
    """Shared `PreTrainedModel` configuration and weight initialization for the Apriel hybrid models."""

    config_class = AprielHConfig
    base_model_prefix = "model"
    _no_split_modules = ["MistralDecoderLayer", "AprielSSMDecoderLayer", "AprielSSMM2DecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """
        Initialize one submodule.

        - Linear and Embedding weights are drawn from N(0, initializer_range).
        - Linear biases are zeroed, except those flagged `_no_reinit`.
        - The padding embedding row is zeroed.
        - RMSNorm weights are set to 1.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            # Mamba flags dt_proj.bias with `_no_reinit` so its softplus-inverse dt initialization survives.
            if module.bias is not None and not getattr(module.bias, "_no_reinit", False):
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, MistralRMSNorm):
            module.weight.data.fill_(1.0)
class AprielHForCausalLM(AprielThinkerSSMHybridPreTrainedModel, GenerationMixin):
    """Apriel hybrid (attention + Mamba) decoder with a linear language-modeling head."""

    _tied_weights_keys = ["lm_head.weight"]
    _tp_plan = {"lm_head": "colwise_rep"}

    def __init__(self, config: AprielHConfig, **kwargs):
        super().__init__(config, **kwargs)
        self.model = AprielHModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        output_router_logits=False,
        cache_position=None,
        position_ids=None,
        use_cache=True,
        **kwargs,
    ):
        """Build per-step generation inputs; creates a `HybridMambaAttentionDynamicCache` on the first step."""
        # Overwritten -- has a unique cache type, `HybridMambaAttentionDynamicCache`
        empty_past_kv = past_key_values is None or not isinstance(past_key_values, HybridMambaAttentionDynamicCache)
        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
        # Exception 1: when passing input_embeds, input_ids may be missing entries
        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
        # (we can't check exception 3 while compiling)
        if not empty_past_kv:
            if inputs_embeds is not None or cache_position[-1] >= input_ids.shape[1]:  # Exception 1 # Exception 3
                input_ids = input_ids[:, -cache_position.shape[0] :]
            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
                input_ids = input_ids[:, cache_position]
        else:
            past_key_values = HybridMambaAttentionDynamicCache(
                self.config, input_ids.shape[0], self.dtype, device=self.device
            )
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if not empty_past_kv:
                position_ids = position_ids[:, -input_ids.shape[1] :]
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and empty_past_kv:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
                "output_router_logits": output_router_logits,
                "cache_position": cache_position,
            }
        )
        return model_inputs

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> AprielHybridCausalOutput:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
        logits_to_keep (`int` or `torch.Tensor`, *optional*):
            If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
            `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
            token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
            If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
            This is useful when using packed tensor format (single dimension for batch and sequence length).

        Returns:
            `AprielHybridCausalOutput` with `loss` (only when `labels` is given), `logits`, `all_hidden_states`
            (when requested), `attention_weights` (when requested and produced by the attention layers) and
            `past_key_values`.

        Example:

        ```python
        >>> from transformers import AutoModelForCausalLM, AutoTokenizer

        >>> # `checkpoint` is the Hub id or local path of an Apriel-H checkpoint shipping this modeling file.
        >>> # The Mamba layers only run on CUDA.
        >>> model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True).to("cuda")
        >>> tokenizer = AutoTokenizer.from_pretrained(checkpoint)

        >>> inputs = tokenizer("Hey, are you conscious? Can you talk to me?", return_tensors="pt").to("cuda")
        >>> generate_ids = model.generate(**inputs, max_length=30)
        >>> text = tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0]
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs: BaseModelOutputWithPast = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            cache_position=cache_position,
            mamba_mask=attention_mask,  # non-expanded mask, forwarded to the Mamba mixers
            **kwargs,
        )
        hidden_states = outputs.last_hidden_state
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])
        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
        return AprielHybridCausalOutput(
            loss=loss,
            logits=logits,
            all_hidden_states=outputs.hidden_states,
            attention_weights=outputs.attentions,
            past_key_values=outputs.past_key_values,
        )
# Names exported by `from ... import *`; the public model entry points of this module.
__all__ = [
    "AprielHForCausalLM",
    "AprielHModel",
]