# File size: 1,099 Bytes
# 501730b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from transformers import MistralConfig
from transformers.utils import logging

# Module-level logger obtained through the transformers logging helper.
logger = logging.get_logger(__name__)

# Fallback SSM settings. AprielHConfig fills any key missing from a
# user-supplied ``ssm_cfg`` with the value listed here.
ssm_config_default = {
    # state / head / block sizing
    "d_state": 64,
    "n_qk_heads": 32,
    "expand": 1,
    "chunk_size": 128,
    # activation and bias switches
    "activation": "identity",
    "bias": False,
    # convolution width and inner/projection dimensions
    "d_conv": 4,
    "d_inner": 32 * 128,  # 4096
    "d_xb": None,  # placeholder: will be set to model dim
    # dt_* settings
    "dt_rank": "auto",
    "dt_min": 1e-3,
    "dt_max": 1e-1,
    "dt_init": "random",
    "dt_scale": 1.0,
    "dt_init_floor": 0.0001,
    "conv_bias": True,
}


class AprielHConfig(MistralConfig):
    """Configuration for the ``apriel_h`` model type, built on ``MistralConfig``.

    In addition to the Mistral settings forwarded through ``**kwargs``, the
    config stores a hybrid block layout and a dictionary of SSM settings.

    Args:
        hybrid_block_layout: Block layout specification, stored unchanged on
            the instance. When omitted (``None``) a fresh ``["m2"]`` list is
            used. The meaning of the entries is defined by the model code,
            which is not visible here.
        ssm_cfg: Optional mapping of SSM settings. It is copied, and every key
            of ``ssm_config_default`` that it lacks is added with the default
            value. The caller's mapping is never modified.
        **kwargs: Forwarded to ``MistralConfig.__init__``.
    """

    model_type = "apriel_h"

    def __init__(self, hybrid_block_layout=None, ssm_cfg=None, **kwargs):
        super().__init__(**kwargs)

        # Build the default per instance: a list literal in the signature
        # would be a single object shared by every config using the default.
        if hybrid_block_layout is None:
            hybrid_block_layout = ["m2"]
        self.hybrid_block_layout = hybrid_block_layout

        # Same fallback as transformers 4.51.3. getattr() keeps this working
        # when the parent config does not define ``head_dim`` at all.
        self.head_dim = getattr(self, "head_dim", None) or self.hidden_size // self.num_attention_heads

        # Work on a private copy so that neither the caller's dict nor the
        # module-level ``ssm_config_default`` is aliased or mutated.
        merged_cfg = dict(ssm_cfg) if ssm_cfg else {}
        for key, default_value in ssm_config_default.items():
            # Only fill in what is missing; explicit user values always win.
            merged_cfg.setdefault(key, default_value)
        self.ssm_cfg = merged_cfg