backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py [532:729]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
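        # residual around attention, then a pre-norm MLP with a second residual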
        hidden_states = hidden_states + attn_out
        norm2_out, _ = self.norm2(hidden_states)
        mlp_out = self.mlp(norm2_out)
        hidden_states = hidden_states + mlp_out
        return hidden_states


class Qwen2_5VLPatchMerger(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()
        self.hidden_size = config.hidden_size * (config.spatial_merge_size**2)
        self.patch_merger_ln_q = FastRMSNorm.load(
            prefix=f"{prefix}.ln_q",
            weights=weights,
            eps=1e-6,
        )
        self.fc1 = TensorParallelColumnLinear.load(
            prefix=f"{prefix}.mlp.0", weights=weights, config=config, bias=True
        )
        self.fc2 = TensorParallelRowLinear.load(
            prefix=f"{prefix}.mlp.2", weights=weights, config=config, bias=True
        )

    def forward(self, hidden_states) -> torch.Tensor:
        hidden_states, _ = self.patch_merger_ln_q(hidden_states)
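        # fold each group of spatial_merge_size**2 patch embeddings into one vector
        # (self.hidden_size == config.hidden_size * spatial_merge_size**2)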
        hidden_states = hidden_states.view(-1, self.hidden_size)
        hidden_states = self.fc1(hidden_states)
        hidden_states = F.gelu(hidden_states)
        hidden_states = self.fc2(hidden_states)
        return hidden_states


class Qwen2_5VisionModel(nn.Module):
    def __init__(self, *, prefix, config, weights):
        super().__init__()

        self.spatial_merge_size = config.spatial_merge_size
        kernel_size = [config.temporal_patch_size, config.patch_size, config.patch_size]
        self.patch_embedding = nn.Conv3d(
            in_channels=config.in_channels,
            out_channels=config.hidden_size,
            kernel_size=kernel_size,
            stride=kernel_size,
            bias=False,
        )
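        # stride == kernel_size, so the convolution extracts non-overlapping patches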
        self.patch_embedding.weight = nn.Parameter(
            weights.get_tensor(f"{prefix}.patch_embed.proj.weight"), requires_grad=False
        )
        head_dim = config.hidden_size // config.num_heads

        theta = 10000.0
        dim = head_dim // 2
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
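        # inv_freq holds head_dim // 4 frequencies; in forward, each patch receives
        # head_dim // 4 rotary angles for its h position and head_dim // 4 for its
        # w position, i.e. head_dim // 2 angles per head in total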

        self.blocks = nn.ModuleList(
            [
                Qwen2_5VLVisionBlock(
                    prefix=f"{prefix}.blocks.{i}",
                    config=config,
                    weights=weights,
                )
                for i in range(config.depth)
            ]
        )
        self.merger = Qwen2_5VLPatchMerger(
            prefix=f"{prefix}.merger",
            config=config,
            weights=weights,
        )

        self.temporal_patch_size = config.temporal_patch_size
        self.spatial_patch_size = config.spatial_patch_size
        self.in_channels = config.in_channels
        self.embed_dim = config.hidden_size
        self.window_size = config.window_size
        self.patch_size = config.patch_size
        self.spatial_merge_unit = config.spatial_merge_size * config.spatial_merge_size
        self.fullatt_block_indexes = config.fullatt_block_indexes

    def apply_class_embedding(self, hidden_state: torch.Tensor) -> torch.Tensor:
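        # NOTE: self.class_embedding is never initialized in __init__, so this
        # helper would raise if called; it is not referenced anywhere in this section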
        batch_size, _, hidden_size = hidden_state.shape
        class_embedding = self.class_embedding.expand(batch_size, 1, hidden_size)
        hidden_state = torch.cat([class_embedding, hidden_state], dim=1)
        return hidden_state

    def get_window_index(self, grid_thw):
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        vit_merger_window_size = (
            self.window_size // self.spatial_merge_size // self.patch_size
        )
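        # edge length of an attention window, measured in merged tokens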

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.spatial_merge_size,
                grid_w // self.spatial_merge_size,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(
                grid_t, llm_grid_h, llm_grid_w
            )
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
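            # pad the merged grid to whole windows using a -100 sentinel; if a side
            # already divides evenly, a full extra window of -100 padding is added,
            # but those entries are filtered out of window_index below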
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
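            # window lengths are accumulated in un-merged patch units, hence the
            # spatial_merge_unit factor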
            cu_seqlens_tmp = (
                seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            )
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)

        return window_index, cu_window_seqlens

    def forward(
        self,
        pixel_values: torch.Tensor,
        grid_thw: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
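        # NOTE: grid_thw is annotated Optional but is used below without a None check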

        # reshape flat pixel values into (num_patches, in_channels, temporal_patch,
        # spatial_patch, spatial_patch) for the Conv3d patch embedding
        shape = (
            -1,
            self.in_channels,
            self.temporal_patch_size,
            self.spatial_patch_size,
            self.spatial_patch_size,
        )
        pixel_values = pixel_values.view(shape).to(self.patch_embedding.weight.dtype)
        hidden_states = self.patch_embedding(pixel_values).view(-1, self.embed_dim)
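        # hidden_states: (num_patches, embed_dim), one row per temporal/spatial patch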
        # TODO: revisit to see if we can avoid some of these reshapes

        # build per-patch (h, w) position ids from grid_thw; the reshape/permute
        # below makes each spatial_merge_size x spatial_merge_size group contiguous
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))

        pos_ids = torch.cat(pos_ids, dim=0)

        max_grid_size = grid_thw[:, 1:].max()

        # turn the (h, w) position ids into rotary angles
        seq = torch.arange(
            max_grid_size, device=self.inv_freq.device, dtype=self.inv_freq.dtype
        )
        rotary_pos_emb_full = torch.outer(seq, self.inv_freq)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        seq_len = hidden_states.shape[0]
        patch_shape = (seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        og_shape = (seq_len, -1)
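        # reorder patches (and their rotary angles) into window-major order; viewing
        # as (n // spatial_merge_unit, spatial_merge_unit, -1) keeps each merged
        # spatial_merge_size**2 cell together during the gather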

        hidden_states = hidden_states.view(patch_shape)[window_index, :, :].view(
            og_shape
        )
        rotary_pos_emb = rotary_pos_emb.view(patch_shape)[window_index, :, :].view(
            og_shape
        )

        rotary_pos_emb = rotary_pos_emb.to(device=hidden_states.device)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
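
Side note on the rotary math above: a minimal standalone sketch (not taken from either file) of how the 2D rotary angles are assembled; head_dim = 80 and the 32-position grid edge are hypothetical values chosen for illustration:

    import torch

    head_dim = 80                     # hypothetical; the model uses hidden_size // num_heads
    dim = head_dim // 2
    inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
    seq = torch.arange(32, dtype=torch.float)   # positions along one grid axis
    freqs = torch.outer(seq, inv_freq)          # (32, head_dim // 4)

    # a patch at (h=3, w=5) gets its h-angles and w-angles concatenated,
    # mirroring rotary_pos_emb_full[pos_ids].flatten(1) in the code above
    pos = torch.tensor([[3, 5]])
    angles = freqs[pos].flatten(1)              # (1, head_dim // 2)
    assert angles.shape == (1, head_dim // 2)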



server/text_generation_server/models/custom_modeling/qwen2_5_vl.py [566:762]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
(duplicate of the block above: body identical to backends/gaudi/server/text_generation_server/models/custom_modeling/qwen2_5_vl.py [532:729]; omitted here)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
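
For the window partition in get_window_index, a minimal standalone sketch with hypothetical Qwen2.5-VL-style values (window_size=112, spatial_merge_size=2, patch_size=14, a 12x12 patch grid; none of these are read from the files above):

    import torch
    import torch.nn.functional as F

    window = 112 // 2 // 14           # 4 merged tokens per window edge
    grid_t, llm_h, llm_w = 1, 6, 6    # 12x12 patches -> 6x6 merged tokens

    index = torch.arange(grid_t * llm_h * llm_w).reshape(grid_t, llm_h, llm_w)
    pad_h = window - llm_h % window   # 2
    pad_w = window - llm_w % window   # 2
    index = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
    nwin_h = (llm_h + pad_h) // window
    nwin_w = (llm_w + pad_w) // window
    # split the padded 8x8 grid into 2x2 windows of 4x4 merged tokens each
    index = index.reshape(grid_t, nwin_h, window, nwin_w, window).permute(0, 1, 3, 2, 4)
    seqlens = (index != -100).sum([3, 4]).reshape(-1)
    print(seqlens.tolist())           # [16, 8, 8, 4] valid tokens per window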



