def tiled_decode_gaudi()

in optimum/habana/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py

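For context, the snippet below relies on a few names from the enclosing module. A minimal sketch of the imports it needs, inferred from the code rather than copied from the file:

from typing import Union

import torch

from diffusers.models.autoencoders.vae import DecoderOutput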

def tiled_decode_gaudi(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
    r"""
    Adapted from: https://github.com/huggingface/diffusers/blob/v0.31.0/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py#L1374
    Decode a batch of video latents using a tiled decoder.

    Args:
        z (`torch.Tensor`): Input batch of latent vectors.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

    Returns:
        [`~models.vae.DecoderOutput`] or `tuple`:
            If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
            returned.
    """
    # Rough memory assessment:
    #   - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
    #   - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
    #   - Assume fp16 (2 bytes per value).
    # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
    #
    # Memory assessment when using tiling:
    #   - Assume everything as above but now HxW is 240x360 by tiling in half
    # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB

    print("run gaudi pipelined tiled decode!")
    batch_size, num_channels, num_frames, height, width = z.shape

    overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
    overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
    blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
    blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
    row_limit_height = self.tile_sample_min_height - blend_extent_height
    row_limit_width = self.tile_sample_min_width - blend_extent_width
    frame_batch_size = self.num_latent_frames_batch_size
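
    # In short: overlap_* is the stride (in latent pixels) between tile origins, so consecutive
    # tiles share part of their input; blend_extent_* is the number of decoded sample rows/cols
    # blended with the neighbouring tile to hide seams; row_limit_* is how many rows/cols of each
    # decoded tile are kept in the final output.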

    # Habana lazy-mode helper; htcore.mark_step() below delimits and triggers execution of the accumulated graph.
    import habana_frameworks.torch.core as htcore

    # Split z into overlapping tiles and decode them separately.
    # The tiles have an overlap to avoid seams between tiles.
    rows = []
    for i in range(0, height, overlap_height):
        row = []
        for j in range(0, width, overlap_width):
            num_batches = max(num_frames // frame_batch_size, 1)
            conv_cache = None
            time = []

            for k in range(num_batches):
                # The first batch absorbs any remainder frames left over after splitting
                # num_frames into groups of frame_batch_size; later batches are shifted by
                # that remainder so the whole frame range is covered exactly once.
                remaining_frames = num_frames % frame_batch_size
                start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
                end_frame = frame_batch_size * (k + 1) + remaining_frames
                tile = z[
                    :,
                    :,
                    start_frame:end_frame,
                    i : i + self.tile_latent_min_height,
                    j : j + self.tile_latent_min_width,
                ].clone()
                if self.post_quant_conv is not None:
                    tile = self.post_quant_conv(tile)
                tile, conv_cache = self.decoder(tile, conv_cache=conv_cache)
                time.append(tile.clone())
                # In HPU lazy mode, mark_step() forces execution of the ops accumulated so far,
                # so each frame batch is decoded before the graph for the next one is built.
                htcore.mark_step()

            row.append(torch.cat(time, dim=2))
        rows.append(row)

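    # rows[i][j] now holds the decoded tile (all frames along dim 2) for spatial tile (i, j);
    # neighbouring tiles still overlap and are blended below.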
    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            # Blend the tile above and the tile to the left into the current tile,
            # then crop it to row_limit and append it to the result row.
            if i > 0:
                tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
            if j > 0:
                tile = self.blend_h(row[j - 1], tile, blend_extent_width)
            result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
        result_rows.append(torch.cat(result_row, dim=4))

    dec = torch.cat(result_rows, dim=3)

    if not return_dict:
        return (dec,)

    return DecoderOutput(sample=dec)
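
A minimal usage sketch (not part of the original file): it assumes this function is bound onto a diffusers AutoencoderKLCogVideoX instance as its tiled_decode method (mirroring how optimum-habana patches diffusers models), and that a Gaudi/HPU device with the Habana PyTorch stack is available. The checkpoint name, dtype, tile sizes and latent shape are illustrative assumptions.

import torch
import habana_frameworks.torch.core as htcore  # loads the Habana PyTorch plugin so the "hpu" device is available

from diffusers import AutoencoderKLCogVideoX

from optimum.habana.diffusers.models.autoencoders.autoencoder_kl_cogvideox import tiled_decode_gaudi

# Assumption: bind the Gaudi-pipelined variant above in place of the upstream tiled_decode.
AutoencoderKLCogVideoX.tiled_decode = tiled_decode_gaudi

vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.bfloat16  # illustrative checkpoint and dtype
).to("hpu")

# Decode in 240x360-pixel tiles, matching the memory estimate in the comment above.
vae.enable_tiling(tile_sample_min_height=240, tile_sample_min_width=360)

# Illustrative latent batch: [batch, channels, latent_frames, latent_height, latent_width].
latents = torch.randn(1, 16, 13, 60, 90, dtype=torch.bfloat16, device="hpu")
with torch.no_grad():
    video = vae.tiled_decode(latents, return_dict=True).sample  # pixel-space frames, [B, 3, T, H, W]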