backends/gaudi/server/text_generation_server/models/custom_modeling/flash_mllama.py [347:693]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        return hidden_states, encoder_states


class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
    """Add a tanh-gated embedding, selected by aspect-ratio id, to the input.

    The embedding table is loaded from ``{prefix}.embedding`` and the gate
    from ``{prefix}.gate``; the gate is frozen (``requires_grad=False``).
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.max_aspect_ratio_id = config.max_aspect_ratio_id

        self.embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.embedding", weights=weights
        )
        gate_tensor = weights.get_tensor(f"{prefix}.gate")
        self.gate = nn.Parameter(gate_tensor, requires_grad=False)

    def forward(
        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
    ) -> torch.Tensor:
        """Return ``hidden_state`` plus the gated per-tile embedding.

        The looked-up rows are reshaped to
        ``(-1, max_num_tiles, 1, hidden_size)`` so that they broadcast over
        the third axis of ``hidden_state`` (assumed to be the patch axis --
        confirm against the caller).
        """
        per_tile_shape = (-1, self.max_num_tiles, 1, self.hidden_size)
        tile_embeddings = self.embedding(aspect_ratio_ids).reshape(per_tile_shape)

        # The gate is applied unconditionally.
        return hidden_state + tile_embeddings * self.gate.tanh()


class MllamaPrecomputedPositionEmbedding(nn.Module):
    """Add gated position and tile-position embeddings to the input.

    Two terms are added in ``forward``:

    * a position embedding scaled by ``1 - tanh(gate)``, computed once at
      construction time and stored as ``gated_position_embedding``;
    * a tile embedding looked up by aspect-ratio id and scaled by
      ``tanh(gate)``.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        # One extra slot on top of the patch grid; the vision model prepends a
        # class token before calling this module.
        self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
        self.hidden_size = config.hidden_size
        self.scale = config.hidden_size**-0.5

        gate_tensor = weights.get_tensor(f"{prefix}.gate")
        self.gate = nn.Parameter(gate_tensor, requires_grad=False)

        # The position table is only needed to build the pre-gated tensor, so
        # it is kept local rather than stored on the module.
        position_table = nn.Parameter(
            weights.get_tensor(f"{prefix}.embedding"), requires_grad=False
        )
        keep_fraction = 1 - self.gate.tanh()
        self.gated_position_embedding = keep_fraction * position_table

        self.tile_embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.tile_embedding", weights=weights
        )

    def forward(
        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
    ) -> torch.Tensor:
        """Return ``hidden_state`` with both gated embeddings added.

        ``hidden_state`` must broadcast against
        ``(1, 1, num_patches, hidden_size)`` and match
        ``(hidden_state.shape[0], max_num_tiles, num_patches, hidden_size)``.
        """
        patch_shape = (self.num_patches, self.hidden_size)

        # Shared position embedding, broadcast over the two leading axes.
        hidden_state = hidden_state + self.gated_position_embedding.view(
            1, 1, *patch_shape
        )

        # Tile position embedding selected by aspect-ratio id.
        leading = hidden_state.shape[0]
        tile_table = self.tile_embedding(aspect_ratio_ids).reshape(
            leading, self.max_num_tiles, *patch_shape
        )
        return hidden_state + self.gate.tanh() * tile_table


class MllamaVisionModel(nn.Module):
    """Vision tower: patch embedding, tile/position embeddings, two encoders.

    ``forward`` runs the patch convolution, adds the pre-tile and gated
    positional embeddings, runs the local ``transformer``, then the
    ``global_transformer``, and finally concatenates the selected
    intermediate outputs of the local encoder to the result.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.num_channels = config.num_channels
        self.intermediate_layers_indices = config.intermediate_layers_indices

        # Patch grid size plus one; forward() prepends one class token.
        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
        self.scale = config.hidden_size**-0.5
        self.dtype = weights.dtype

        def frozen(name):
            # Wrap a checkpoint tensor as a non-trainable parameter.
            return nn.Parameter(
                weights.get_tensor(f"{prefix}.{name}"), requires_grad=False
            )

        shared = {"config": config, "weights": weights}

        # Non-overlapping convolution: kernel and stride both equal the patch size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
            bias=False,
        )
        self.patch_embedding.weight = frozen("patch_embedding.weight")

        self.class_embedding = frozen("class_embedding")

        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
            prefix=f"{prefix}.gated_positional_embedding", **shared
        )
        self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            prefix=f"{prefix}.pre_tile_positional_embedding", **shared
        )
        self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            prefix=f"{prefix}.post_tile_positional_embedding", **shared
        )

        # Layer norms (eps is the torch default).
        self.layernorm_pre = nn.LayerNorm.load(
            prefix=f"{prefix}.layernorm_pre", weights=weights, eps=1e-05
        )
        self.layernorm_post = nn.LayerNorm.load(
            prefix=f"{prefix}.layernorm_post", weights=weights, eps=1e-05
        )

        # Encoders: an ungated local stack followed by a gated global stack.
        self.transformer = MllamaVisionEncoder(
            prefix=f"{prefix}.transformer",
            is_gated=False,
            num_layers=config.num_hidden_layers,
            **shared,
        )
        self.global_transformer = MllamaVisionEncoder(
            prefix=f"{prefix}.global_transformer",
            is_gated=True,
            num_layers=config.num_global_layers,
            **shared,
        )

    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Prepend the class embedding along dim 1 of a 3-D ``hidden_state``."""
        rows, _, width = hidden_state.shape
        cls_tokens = self.class_embedding.expand(rows, 1, width)
        return torch.cat([cls_tokens, hidden_state], dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        aspect_ratio_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Encode image tiles.

        Args:
            pixel_values: 6-D tensor unpacked as ``(batch_size,
                num_concurrent_media, num_tiles, num_channels, height, width)``.
            aspect_ratio_ids: ids reshaped to
                ``(batch_size * num_concurrent_media, -1)`` and passed to the
                tile/position embedding modules.
            attention_mask: aspect-ratio mask, or ``None`` to run both
                encoders without a mask.

        Returns:
            Tensor of shape ``(batch_size, num_concurrent_media, num_tiles,
            num_patches, -1)`` where the last axis is the global encoder
            output followed by the selected intermediate outputs.
        """
        (
            batch_size,
            num_concurrent_media,
            num_tiles,
            num_channels,
            height,
            width,
        ) = pixel_values.shape
        num_media = batch_size * num_concurrent_media

        pixel_values = pixel_values.reshape(
            num_media * num_tiles, num_channels, height, width
        )
        aspect_ratio_ids = aspect_ratio_ids.reshape(num_media, -1)

        # Patch embedding: (tiles, dim, h', w') -> (tiles, patches, dim).
        hidden_state = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
        _, num_patches, dim = hidden_state.shape

        # Tile embeddings.
        hidden_state = hidden_state.reshape(num_media, num_tiles, -1, dim)
        hidden_state = self.pre_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )

        # Class token: one extra position per tile.
        hidden_state = hidden_state.reshape(num_media * num_tiles, num_patches, dim)
        hidden_state = self.apply_class_embedding(hidden_state)
        num_patches += 1

        # Position embeddings, then the pre-encoder layer norm.
        hidden_state = hidden_state.reshape(num_media, num_tiles, num_patches, dim)
        hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)
        hidden_state = self.layernorm_pre(hidden_state)

        # Zero-pad the patch axis (dim -2) on the right up to a multiple of 8.
        num_padding_patches = -hidden_state.shape[-2] % 8
        hidden_state = F.pad(
            hidden_state, (0, 0, 0, num_padding_patches), mode="constant", value=0
        )
        padded_patches = hidden_state.shape[2]

        if attention_mask is not None:
            attention_mask = attention_mask.reshape(num_media, -1)
            attention_mask = _prepare_aspect_ratio_attention_mask(
                aspect_ratio_mask=attention_mask,
                num_patches=self.num_patches,
                target_length=padded_patches,
                dtype=self.dtype,
            )

        # Local encoder over the flattened (tile, patch) sequence.
        hidden_state = hidden_state.view(num_media, -1, dim)
        hidden_state, all_intermediate_hidden_states = self.transformer(
            hidden_state,
            attention_mask=attention_mask,
        )
        wanted_layers = self.intermediate_layers_indices
        intermediate_hidden_states = torch.stack(
            [
                layer_output
                for layer_idx, layer_output in enumerate(
                    all_intermediate_hidden_states
                )
                if layer_idx in wanted_layers
            ],
            dim=-1,
        )

        # Global encoder.
        hidden_state = self.layernorm_post(hidden_state)
        hidden_state = hidden_state.reshape(num_media, num_tiles, padded_patches, dim)
        hidden_state = self.post_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )
        hidden_state = hidden_state.reshape(num_media, num_tiles * padded_patches, dim)
        hidden_state, _ = self.global_transformer(
            hidden_state, attention_mask=attention_mask
        )

        # Drop the padded positions again; they sit after the first
        # ``num_patches`` entries of the patch axis.
        hidden_state = hidden_state.reshape(num_media, num_tiles, padded_patches, dim)
        hidden_state = hidden_state[:, :, :num_patches]
        hidden_state = hidden_state.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, dim
        )

        # Same un-padding for the stacked intermediate outputs; the trailing
        # (dim, layer) axes are merged into one.
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            num_media, num_tiles, padded_patches, -1
        )
        intermediate_hidden_states = intermediate_hidden_states[:, :, :num_patches]
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, -1
        )
        return torch.cat([hidden_state, intermediate_hidden_states], dim=-1)


class MllamaTextCrossAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, *, prefix, config, weights, layer_idx):
        """Load the projections and query/key norms for one cross-attention layer.

        ``head_size`` and ``num_key_value_groups`` are derived from the head
        counts in ``config``; ``num_heads`` and ``num_key_value_heads`` are
        then stored divided by the size of ``weights.process_group``.
        """
        super().__init__()
        self.config = config
        total_heads = self.config.num_attention_heads
        total_kv_heads = self.config.num_key_value_heads
        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        self.head_size = config.hidden_size // total_heads
        self.num_key_value_groups = total_heads // total_kv_heads
        self.layer_idx = layer_idx

        # Per-shard head counts.
        world_size = weights.process_group.size()
        self.num_heads = total_heads // world_size
        self.num_key_value_heads = total_kv_heads // world_size

        def load_column(name):
            # Bias-free column-parallel projection stored under `{prefix}.{name}`.
            return TensorParallelColumnLinear.load(
                config,
                prefix=f"{prefix}.{name}",
                weights=weights,
                bias=False,
            )

        def load_norm(name):
            # RMS norm stored under `{prefix}.{name}`.
            return MllamaTextRMSNorm.load(
                prefix=f"{prefix}.{name}", weights=weights, eps=config.rms_norm_eps
            )

        self.q_proj = load_column("q_proj")
        self.k_proj = load_column("k_proj")
        self.v_proj = load_column("v_proj")
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

        self.q_norm = load_norm("q_norm")
        self.k_norm = load_norm("k_norm")
        self.softmax_scale = self.head_size**-0.5

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: Optional[torch.Tensor] = None,
        # past_key_value=None,
        # attention_mask: Optional[torch.Tensor] = None,
        # cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



server/text_generation_server/models/custom_modeling/mllama.py [335:681]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        return hidden_states, encoder_states


class MllamaPrecomputedAspectRatioEmbedding(nn.Module):
    """Add a tanh-gated embedding, selected by aspect-ratio id, to the input.

    The embedding table is loaded from ``{prefix}.embedding`` and the gate
    from ``{prefix}.gate``; the gate is frozen (``requires_grad=False``).
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.max_aspect_ratio_id = config.max_aspect_ratio_id

        self.embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.embedding", weights=weights
        )
        gate_tensor = weights.get_tensor(f"{prefix}.gate")
        self.gate = nn.Parameter(gate_tensor, requires_grad=False)

    def forward(
        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
    ) -> torch.Tensor:
        """Return ``hidden_state`` plus the gated per-tile embedding.

        The looked-up rows are reshaped to
        ``(-1, max_num_tiles, 1, hidden_size)`` so that they broadcast over
        the third axis of ``hidden_state`` (assumed to be the patch axis --
        confirm against the caller).
        """
        per_tile_shape = (-1, self.max_num_tiles, 1, self.hidden_size)
        tile_embeddings = self.embedding(aspect_ratio_ids).reshape(per_tile_shape)

        # The gate is applied unconditionally.
        return hidden_state + tile_embeddings * self.gate.tanh()


class MllamaPrecomputedPositionEmbedding(nn.Module):
    """Add gated position and tile-position embeddings to the input.

    Two terms are added in ``forward``:

    * a position embedding scaled by ``1 - tanh(gate)``, computed once at
      construction time and stored as ``gated_position_embedding``;
    * a tile embedding looked up by aspect-ratio id and scaled by
      ``tanh(gate)``.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.max_num_tiles = config.max_num_tiles
        self.max_aspect_ratio_id = config.max_aspect_ratio_id
        # One extra slot on top of the patch grid; the vision model prepends a
        # class token before calling this module.
        self.num_patches = (config.image_size // config.patch_size) ** 2 + 1
        self.hidden_size = config.hidden_size
        self.scale = config.hidden_size**-0.5

        gate_tensor = weights.get_tensor(f"{prefix}.gate")
        self.gate = nn.Parameter(gate_tensor, requires_grad=False)

        # The position table is only needed to build the pre-gated tensor, so
        # it is kept local rather than stored on the module.
        position_table = nn.Parameter(
            weights.get_tensor(f"{prefix}.embedding"), requires_grad=False
        )
        keep_fraction = 1 - self.gate.tanh()
        self.gated_position_embedding = keep_fraction * position_table

        self.tile_embedding = TensorParallelEmbedding(
            prefix=f"{prefix}.tile_embedding", weights=weights
        )

    def forward(
        self, hidden_state: torch.Tensor, aspect_ratio_ids: torch.Tensor
    ) -> torch.Tensor:
        """Return ``hidden_state`` with both gated embeddings added.

        ``hidden_state`` must broadcast against
        ``(1, 1, num_patches, hidden_size)`` and match
        ``(hidden_state.shape[0], max_num_tiles, num_patches, hidden_size)``.
        """
        patch_shape = (self.num_patches, self.hidden_size)

        # Shared position embedding, broadcast over the two leading axes.
        hidden_state = hidden_state + self.gated_position_embedding.view(
            1, 1, *patch_shape
        )

        # Tile position embedding selected by aspect-ratio id.
        leading = hidden_state.shape[0]
        tile_table = self.tile_embedding(aspect_ratio_ids).reshape(
            leading, self.max_num_tiles, *patch_shape
        )
        return hidden_state + self.gate.tanh() * tile_table


class MllamaVisionModel(nn.Module):
    """Vision tower: patch embedding, tile/position embeddings, two encoders.

    ``forward`` runs the patch convolution, adds the pre-tile and gated
    positional embeddings, runs the local ``transformer``, then the
    ``global_transformer``, and finally concatenates the selected
    intermediate outputs of the local encoder to the result.
    """

    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.image_size = config.image_size
        self.patch_size = config.patch_size
        self.max_num_tiles = config.max_num_tiles
        self.hidden_size = config.hidden_size
        self.num_channels = config.num_channels
        self.intermediate_layers_indices = config.intermediate_layers_indices

        # Patch grid size plus one; forward() prepends one class token.
        self.num_patches = (self.image_size // self.patch_size) ** 2 + 1
        self.scale = config.hidden_size**-0.5
        self.dtype = weights.dtype

        def frozen(name):
            # Wrap a checkpoint tensor as a non-trainable parameter.
            return nn.Parameter(
                weights.get_tensor(f"{prefix}.{name}"), requires_grad=False
            )

        shared = {"config": config, "weights": weights}

        # Non-overlapping convolution: kernel and stride both equal the patch size.
        self.patch_embedding = nn.Conv2d(
            in_channels=config.num_channels,
            out_channels=self.hidden_size,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
            bias=False,
        )
        self.patch_embedding.weight = frozen("patch_embedding.weight")

        self.class_embedding = frozen("class_embedding")

        self.gated_positional_embedding = MllamaPrecomputedPositionEmbedding(
            prefix=f"{prefix}.gated_positional_embedding", **shared
        )
        self.pre_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            prefix=f"{prefix}.pre_tile_positional_embedding", **shared
        )
        self.post_tile_positional_embedding = MllamaPrecomputedAspectRatioEmbedding(
            prefix=f"{prefix}.post_tile_positional_embedding", **shared
        )

        # Layer norms (eps is the torch default).
        self.layernorm_pre = nn.LayerNorm.load(
            prefix=f"{prefix}.layernorm_pre", weights=weights, eps=1e-05
        )
        self.layernorm_post = nn.LayerNorm.load(
            prefix=f"{prefix}.layernorm_post", weights=weights, eps=1e-05
        )

        # Encoders: an ungated local stack followed by a gated global stack.
        self.transformer = MllamaVisionEncoder(
            prefix=f"{prefix}.transformer",
            is_gated=False,
            num_layers=config.num_hidden_layers,
            **shared,
        )
        self.global_transformer = MllamaVisionEncoder(
            prefix=f"{prefix}.global_transformer",
            is_gated=True,
            num_layers=config.num_global_layers,
            **shared,
        )

    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
        """Prepend the class embedding along dim 1 of a 3-D ``hidden_state``."""
        rows, _, width = hidden_state.shape
        cls_tokens = self.class_embedding.expand(rows, 1, width)
        return torch.cat([cls_tokens, hidden_state], dim=1)

    def forward(
        self,
        pixel_values: torch.Tensor,
        aspect_ratio_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Encode image tiles.

        Args:
            pixel_values: 6-D tensor unpacked as ``(batch_size,
                num_concurrent_media, num_tiles, num_channels, height, width)``.
            aspect_ratio_ids: ids reshaped to
                ``(batch_size * num_concurrent_media, -1)`` and passed to the
                tile/position embedding modules.
            attention_mask: aspect-ratio mask, or ``None`` to run both
                encoders without a mask.

        Returns:
            Tensor of shape ``(batch_size, num_concurrent_media, num_tiles,
            num_patches, -1)`` where the last axis is the global encoder
            output followed by the selected intermediate outputs.
        """
        (
            batch_size,
            num_concurrent_media,
            num_tiles,
            num_channels,
            height,
            width,
        ) = pixel_values.shape
        num_media = batch_size * num_concurrent_media

        pixel_values = pixel_values.reshape(
            num_media * num_tiles, num_channels, height, width
        )
        aspect_ratio_ids = aspect_ratio_ids.reshape(num_media, -1)

        # Patch embedding: (tiles, dim, h', w') -> (tiles, patches, dim).
        hidden_state = self.patch_embedding(pixel_values).flatten(2).transpose(1, 2)
        _, num_patches, dim = hidden_state.shape

        # Tile embeddings.
        hidden_state = hidden_state.reshape(num_media, num_tiles, -1, dim)
        hidden_state = self.pre_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )

        # Class token: one extra position per tile.
        hidden_state = hidden_state.reshape(num_media * num_tiles, num_patches, dim)
        hidden_state = self.apply_class_embedding(hidden_state)
        num_patches += 1

        # Position embeddings, then the pre-encoder layer norm.
        hidden_state = hidden_state.reshape(num_media, num_tiles, num_patches, dim)
        hidden_state = self.gated_positional_embedding(hidden_state, aspect_ratio_ids)
        hidden_state = self.layernorm_pre(hidden_state)

        # Zero-pad the patch axis (dim -2) on the right up to a multiple of 8.
        num_padding_patches = -hidden_state.shape[-2] % 8
        hidden_state = F.pad(
            hidden_state, (0, 0, 0, num_padding_patches), mode="constant", value=0
        )
        padded_patches = hidden_state.shape[2]

        if attention_mask is not None:
            attention_mask = attention_mask.reshape(num_media, -1)
            attention_mask = _prepare_aspect_ratio_attention_mask(
                aspect_ratio_mask=attention_mask,
                num_patches=self.num_patches,
                target_length=padded_patches,
                dtype=self.dtype,
            )

        # Local encoder over the flattened (tile, patch) sequence.
        hidden_state = hidden_state.view(num_media, -1, dim)
        hidden_state, all_intermediate_hidden_states = self.transformer(
            hidden_state,
            attention_mask=attention_mask,
        )
        wanted_layers = self.intermediate_layers_indices
        intermediate_hidden_states = torch.stack(
            [
                layer_output
                for layer_idx, layer_output in enumerate(
                    all_intermediate_hidden_states
                )
                if layer_idx in wanted_layers
            ],
            dim=-1,
        )

        # Global encoder.
        hidden_state = self.layernorm_post(hidden_state)
        hidden_state = hidden_state.reshape(num_media, num_tiles, padded_patches, dim)
        hidden_state = self.post_tile_positional_embedding(
            hidden_state, aspect_ratio_ids
        )
        hidden_state = hidden_state.reshape(num_media, num_tiles * padded_patches, dim)
        hidden_state, _ = self.global_transformer(
            hidden_state, attention_mask=attention_mask
        )

        # Drop the padded positions again; they sit after the first
        # ``num_patches`` entries of the patch axis.
        hidden_state = hidden_state.reshape(num_media, num_tiles, padded_patches, dim)
        hidden_state = hidden_state[:, :, :num_patches]
        hidden_state = hidden_state.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, dim
        )

        # Same un-padding for the stacked intermediate outputs; the trailing
        # (dim, layer) axes are merged into one.
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            num_media, num_tiles, padded_patches, -1
        )
        intermediate_hidden_states = intermediate_hidden_states[:, :, :num_patches]
        intermediate_hidden_states = intermediate_hidden_states.reshape(
            batch_size, num_concurrent_media, num_tiles, num_patches, -1
        )
        return torch.cat([hidden_state, intermediate_hidden_states], dim=-1)


class MllamaTextCrossAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, *, prefix, config, weights, layer_idx):
        """Load the projections and query/key norms for one cross-attention layer.

        ``head_size`` and ``num_key_value_groups`` are derived from the head
        counts in ``config``; ``num_heads`` and ``num_key_value_heads`` are
        then stored divided by the size of ``weights.process_group``.
        """
        super().__init__()
        self.config = config
        total_heads = self.config.num_attention_heads
        total_kv_heads = self.config.num_key_value_heads
        self.dropout = config.dropout
        self.hidden_size = config.hidden_size
        self.head_size = config.hidden_size // total_heads
        self.num_key_value_groups = total_heads // total_kv_heads
        self.layer_idx = layer_idx

        # Per-shard head counts.
        world_size = weights.process_group.size()
        self.num_heads = total_heads // world_size
        self.num_key_value_heads = total_kv_heads // world_size

        def load_column(name):
            # Bias-free column-parallel projection stored under `{prefix}.{name}`.
            return TensorParallelColumnLinear.load(
                config,
                prefix=f"{prefix}.{name}",
                weights=weights,
                bias=False,
            )

        def load_norm(name):
            # RMS norm stored under `{prefix}.{name}`.
            return MllamaTextRMSNorm.load(
                prefix=f"{prefix}.{name}", weights=weights, eps=config.rms_norm_eps
            )

        self.q_proj = load_column("q_proj")
        self.k_proj = load_column("k_proj")
        self.v_proj = load_column("v_proj")
        self.o_proj = TensorParallelRowLinear.load(
            config,
            prefix=f"{prefix}.o_proj",
            weights=weights,
            bias=False,
        )

        self.q_norm = load_norm("q_norm")
        self.k_norm = load_norm("k_norm")
        self.softmax_scale = self.head_size**-0.5

    def forward(
        self,
        hidden_states: torch.Tensor,
        cross_attention_states: Optional[torch.Tensor] = None,
        # past_key_value=None,
        # attention_mask: Optional[torch.Tensor] = None,
        # cache_position: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



