def __init__()

in backends/gaudi/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py


    def __init__(self, prefix, config, weights):
        super().__init__()

        process_group = weights.process_group
        self.tp_rank = process_group.rank()
        self.tp_world_size = process_group.size()

        # Skip fp8 quant for first and last layers
        self.layers = nn.ModuleList()
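        # Indices of decoder layers that use cross-attention (e.g. Mllama); plain
        # Llama configs have no such attribute, so this defaults to an empty list.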
        self.cross_attention_layers = getattr(config, "cross_attention_layers", [])

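        # A single rotary embedding is built once and shared by every self-attention layer below.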
        rotary_emb = PositionRotaryEmbedding.static(
            config=config,
            dim=config.hidden_size // config.num_attention_heads,
            base=config.rope_theta,
            device=weights.device,
        )
        with no_fp8(weights):
            self.layers.append(
                FlashLlamaLayer(
                    index=0,
                    prefix=f"{prefix}.layers.0",
                    config=config,
                    weights=weights,
                    rotary_emb=rotary_emb,
                )
            )

        # Middle layers; the first and last layers are constructed separately (with fp8 quantization skipped)
        for layer_id in range(1, config.num_hidden_layers - 1):
            if layer_id in self.cross_attention_layers:
                # Cross-attention decoder layer; imported here from the mllama
                # modeling module rather than at module level.
                from text_generation_server.models.custom_modeling.flash_mllama import (
                    FlashLlamaCrossLayer,
                )

                self.layers.append(
                    FlashLlamaCrossLayer(
                        index=layer_id,
                        prefix=f"{prefix}.layers.{layer_id}",
                        config=config,
                        weights=weights,
                    )
                )
            else:
                self.layers.append(
                    FlashLlamaLayer(
                        index=layer_id,
                        prefix=f"{prefix}.layers.{layer_id}",
                        config=config,
                        weights=weights,
                        rotary_emb=rotary_emb,
                    )
                )

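        # The last layer is also kept out of fp8 quantization.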
        with no_fp8(weights):
            last_layer_id = config.num_hidden_layers - 1
            self.layers.append(
                FlashLlamaLayer(
                    index=last_layer_id,
                    prefix=f"{prefix}.layers.{last_layer_id}",
                    config=config,
                    weights=weights,
                    rotary_emb=rotary_emb,
                )
            )

        self.norm = FastRMSNorm.load(
            prefix=f"{prefix}.norm",
            weights=weights,
            eps=config.rms_norm_eps,
        )

        self.gradient_checkpointing = False

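        # Attention geometry is read from the first layer, which is always a standard self-attention layer.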
        self.head_size = self.layers[0].self_attn.head_size
        self.num_heads = self.layers[0].self_attn.num_heads
        self.num_key_value_heads = self.layers[0].self_attn.num_key_value_heads
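
For reference, below is a minimal, self-contained sketch of the construction pattern used above: the first and last decoder layers are built inside a no_fp8 context so they are not fp8-quantized, and any index listed in cross_attention_layers gets a cross-attention layer instead of a standard one. DummyWeights, DummyLayer, DummyCrossLayer and build_layers are hypothetical stand-ins (only torch is required); the real no_fp8 helper and layer classes in text_generation_server do considerably more.

from contextlib import contextmanager

import torch.nn as nn


class DummyWeights:
    """Hypothetical stand-in for the TGI `weights` object: only carries a quantize flag."""

    def __init__(self, quantize="fp8"):
        self.quantize = quantize


@contextmanager
def no_fp8(weights):
    # Temporarily clear the quantization flag while a layer is built,
    # mirroring how the first and last layers above stay unquantized.
    saved = weights.quantize
    weights.quantize = None
    try:
        yield
    finally:
        weights.quantize = saved


class DummyLayer(nn.Module):
    def __init__(self, index, quantize):
        super().__init__()
        self.index = index
        self.quantize = quantize  # records the flag that was active at construction time


class DummyCrossLayer(DummyLayer):
    pass


def build_layers(num_hidden_layers, cross_attention_layers, weights):
    layers = nn.ModuleList()
    # First layer: fp8 quantization skipped.
    with no_fp8(weights):
        layers.append(DummyLayer(0, weights.quantize))
    # Middle layers: indices in cross_attention_layers get the cross-attention class.
    for layer_id in range(1, num_hidden_layers - 1):
        cls = DummyCrossLayer if layer_id in cross_attention_layers else DummyLayer
        layers.append(cls(layer_id, weights.quantize))
    # Last layer: fp8 quantization skipped again.
    with no_fp8(weights):
        layers.append(DummyLayer(num_hidden_layers - 1, weights.quantize))
    return layers


if __name__ == "__main__":
    weights = DummyWeights()
    for layer in build_layers(num_hidden_layers=6, cross_attention_layers=[3], weights=weights):
        print(layer.index, type(layer).__name__, layer.quantize)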