backends/gaudi/server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py [35:182]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
)
from text_generation_server.layers import (
    TensorParallelMultiAdapterLinear,
    TensorParallelAdapterRowLinear,
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    SpeculativeHead,
    get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import (
    FastLayerNorm,
    FastRMSNorm,
)
from text_generation_server.layers.rotary import (
    PositionRotaryEmbedding,
)
from text_generation_server.utils.weights import UnquantizedWeight
import habana_frameworks.torch as htorch


class Starcoder2Config(PretrainedConfig):
    model_type = "starcoder2"

    def __init__(
        self,
        vocab_size=49152,
        hidden_size=3072,
        intermediate_size=12288,
        num_hidden_layers=30,
        num_attention_heads=24,
        num_key_value_heads=2,
        mlp_type="default",
        hidden_act="gelu_pytorch_tanh",
        max_position_embeddings=4096,
        initializer_range=0.018042,
        norm_type="layer_norm",
        norm_epsilon=1e-5,
        use_cache=True,
        bos_token_id=50256,
        eos_token_id=50256,
        rope_theta=10000.0,
        sliding_window=None,
        attention_dropout=0.0,
        residual_dropout=0.0,
        embedding_dropout=0.0,
        use_bias: bool = True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window
        self.use_bias = use_bias

        # for backward compatibility: older configs may omit num_key_value_heads
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.mlp_type = mlp_type
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.norm_type = norm_type
        self.norm_epsilon = norm_epsilon
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.residual_dropout = residual_dropout
        self.embedding_dropout = embedding_dropout

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

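# Minimal usage sketch (hypothetical values, not taken from this file): any
# field can be overridden, and remaining kwargs flow through to PretrainedConfig:
#   config = Starcoder2Config(num_hidden_layers=2, use_cache=False)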

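# Builds the fused QKV projection for one attention layer: grouped-query
# checkpoints go through _load_gqa, plain multi-head checkpoints use the
# standard column-parallel multi-load; either result is wrapped so per-layer
# adapters (e.g. LoRA) can target q_proj/k_proj/v_proj individually.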
def load_attention(config, prefix, weights, layer_id):
    prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
    head_size = config.hidden_size // config.num_attention_heads
    sizes = [
        head_size * config.num_attention_heads,
        head_size * config.num_key_value_heads,
        head_size * config.num_key_value_heads,
    ]
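    # With the defaults above (hidden_size=3072, 24 heads, 2 KV heads):
    # head_size = 3072 // 24 = 128, so sizes = [3072, 256, 256].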
    if config.num_attention_heads != config.num_key_value_heads:
        base_layer = _load_gqa(config, prefix, weights)
    else:
        base_layer = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=prefixes,
            dim=0,
            weights=weights,
            bias=config.use_bias,
        )
    return TensorParallelMultiAdapterLinear.load(
        base_layer=base_layer,
        layer_id=layer_id,
        layer_names=prefixes,
        sizes=sizes,
        process_group=weights.process_group,
    )


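# Loads and concatenates the column-sharded Q/K/V weights (and biases, when
# enabled) into a single fused projection, sanity-checking the shard shape
# for unquantized weights.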
def _load_gqa(config, prefix: str, weights):
    assert config.hidden_size % config.num_attention_heads == 0
    assert config.num_attention_heads % weights.process_group.size() == 0

    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        dim=0,
    )

    if isinstance(weight, UnquantizedWeight):
        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

        head_size = config.hidden_size // config.num_attention_heads
        num_heads = config.num_attention_heads // weights.process_group.size()
        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
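        # e.g. at tensor-parallel size 2 with the default config: num_heads = 12,
        # num_key_value_heads = 1, so the expected shard shape is
        # [(12 + 2 * 1) * 128, 3072] = [1792, 3072].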
        assert list(weight.weight.shape) == [
            (num_heads + 2 * num_key_value_heads) * head_size,
            config.hidden_size,
        ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

    if config.use_bias:
        w = [
            weights.get_sharded(f"{p}.bias", dim=0)
            for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
        ]
        bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)
    else:
        bias = None

    return TensorParallelColumnLinear(get_linear(weight, bias=bias))

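# Sketch of the intended call site (hedged; `prefix` here is the layer's
# weight prefix, e.g. "model.layers.0.self_attn", and `index` its layer id):
#   query_key_value = load_attention(config, prefix, weights, layer_id=index)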

class Starcoder2Attention(torch.nn.Module):
    def __init__(
        self,
        index: int,
        prefix: str,
        config,
        weights,
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



server/text_generation_server/models/custom_modeling/flash_starcoder2_modeling.py [33:179]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
)
from text_generation_server.layers import (
    TensorParallelMultiAdapterLinear,
    TensorParallelAdapterRowLinear,
    TensorParallelRowLinear,
    TensorParallelColumnLinear,
    TensorParallelEmbedding,
    SpeculativeHead,
    get_linear,
)
from text_generation_server.layers.attention.kv_cache import get_kv_scales
from text_generation_server.layers.layernorm import (
    FastLayerNorm,
    FastRMSNorm,
)
from text_generation_server.layers.rotary import (
    PositionRotaryEmbedding,
)
from text_generation_server.utils.weights import UnquantizedWeight


class Starcoder2Config(PretrainedConfig):
    model_type = "starcoder2"

    def __init__(
        self,
        vocab_size=49152,
        hidden_size=3072,
        intermediate_size=12288,
        num_hidden_layers=30,
        num_attention_heads=24,
        num_key_value_heads=2,
        mlp_type="default",
        hidden_act="gelu_pytorch_tanh",
        max_position_embeddings=4096,
        initializer_range=0.018042,
        norm_type="layer_norm",
        norm_epsilon=1e-5,
        use_cache=True,
        bos_token_id=50256,
        eos_token_id=50256,
        rope_theta=10000.0,
        sliding_window=None,
        attention_dropout=0.0,
        residual_dropout=0.0,
        embedding_dropout=0.0,
        use_bias: bool = True,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.sliding_window = sliding_window
        self.use_bias = use_bias

        # for backward compatibility: older configs may omit num_key_value_heads
        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads

        self.num_key_value_heads = num_key_value_heads
        self.mlp_type = mlp_type
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.norm_type = norm_type
        self.norm_epsilon = norm_epsilon
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.attention_dropout = attention_dropout
        self.residual_dropout = residual_dropout
        self.embedding_dropout = embedding_dropout

        super().__init__(
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

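# Minimal usage sketch (hypothetical values, not taken from this file): any
# field can be overridden, and remaining kwargs flow through to PretrainedConfig:
#   config = Starcoder2Config(num_hidden_layers=2, use_cache=False)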

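# Builds the fused QKV projection for one attention layer: grouped-query
# checkpoints go through _load_gqa, plain multi-head checkpoints use the
# standard column-parallel multi-load; either result is wrapped so per-layer
# adapters (e.g. LoRA) can target q_proj/k_proj/v_proj individually.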
def load_attention(config, prefix, weights, layer_id):
    prefixes = [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
    head_size = config.hidden_size // config.num_attention_heads
    sizes = [
        head_size * config.num_attention_heads,
        head_size * config.num_key_value_heads,
        head_size * config.num_key_value_heads,
    ]
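    # With the defaults above (hidden_size=3072, 24 heads, 2 KV heads):
    # head_size = 3072 // 24 = 128, so sizes = [3072, 256, 256].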
    if config.num_attention_heads != config.num_key_value_heads:
        base_layer = _load_gqa(config, prefix, weights)
    else:
        base_layer = TensorParallelColumnLinear.load_multi(
            config,
            prefixes=prefixes,
            dim=0,
            weights=weights,
            bias=config.use_bias,
        )
    return TensorParallelMultiAdapterLinear.load(
        base_layer=base_layer,
        layer_id=layer_id,
        layer_names=prefixes,
        sizes=sizes,
        process_group=weights.process_group,
    )


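# Loads and concatenates the column-sharded Q/K/V weights (and biases, when
# enabled) into a single fused projection, sanity-checking the shard shape
# for unquantized weights.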
def _load_gqa(config, prefix: str, weights):
    assert config.hidden_size % config.num_attention_heads == 0
    assert config.num_attention_heads % weights.process_group.size() == 0

    weight = weights.get_multi_weights_col(
        prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
        dim=0,
    )

    if isinstance(weight, UnquantizedWeight):
        weight.weight = weight.weight.to(dtype=weights.dtype).to(device=weights.device)

        head_size = config.hidden_size // config.num_attention_heads
        num_heads = config.num_attention_heads // weights.process_group.size()
        num_key_value_heads = config.num_key_value_heads // weights.process_group.size()
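        # e.g. at tensor-parallel size 2 with the default config: num_heads = 12,
        # num_key_value_heads = 1, so the expected shard shape is
        # [(12 + 2 * 1) * 128, 3072] = [1792, 3072].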
        assert list(weight.weight.shape) == [
            (num_heads + 2 * num_key_value_heads) * head_size,
            config.hidden_size,
        ], f"{list(weight.weight.shape)} != {[(num_heads + 2 * config.num_key_value_heads) * head_size, config.hidden_size]}"

    if config.use_bias:
        w = [
            weights.get_sharded(f"{p}.bias", dim=0)
            for p in [f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"]
        ]
        bias = torch.cat(w, dim=0).to(dtype=weights.dtype).to(device=weights.device)
    else:
        bias = None

    return TensorParallelColumnLinear(get_linear(weight, bias=bias))

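# Sketch of the intended call site (hedged; `prefix` here is the layer's
# weight prefix, e.g. "model.layers.0.self_attn", and `index` its layer id):
#   query_key_value = load_attention(config, prefix, weights, layer_id=index)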

class Starcoder2Attention(torch.nn.Module):
    def __init__(
        self,
        index: int,
        prefix: str,
        config,
        weights,
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -