# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is heavily inspired by the original implementation from https://github.com/lucidrains/muse-maskgit-pytorch

import math
from functools import partial
from typing import Callable, Optional

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.utils.checkpoint import checkpoint
from tqdm import tqdm

from .modeling_transformer_v2 import MaskGiTUViT_v2
from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
from .sampling import cosine_schedule, gumbel_sample, mask_by_random_topk, top_k

try:
    import xformers.ops as xops

    is_xformers_available = True
except ImportError:
    is_xformers_available = False

# Alias: `MaskGiTUViT` imported from this module is the `MaskGiTUViT_v2` class from `.modeling_transformer_v2`.
MaskGiTUViT = MaskGiTUViT_v2


# classifier free guidance functions


def uniform(shape, min=0, max=1, device=None):
    """Sample a float32 tensor of ``shape`` uniformly from ``[min, max)``.

    Args:
        shape: output shape.
        min: lower bound of the range (default 0).
        max: upper bound of the range (default 1).
        device: device to allocate the tensor on.

    Returns:
        A float32 tensor of uniform samples.
    """
    # `min`/`max` used to be ignored (range was hard-coded to [0, 1)); the defaults keep that behaviour.
    # `empty` is enough because `uniform_` overwrites every element.
    return torch.empty(shape, device=device, dtype=torch.float32).uniform_(min, max)


def prob_mask_like(shape, prob, device=None):
    """Return a boolean tensor of ``shape`` whose entries are ``True`` with probability ``prob``.

    A ``prob`` equal to 1 or 0 short-circuits to an all-True / all-False tensor without drawing
    random numbers; any other value thresholds samples from ``uniform``.
    """
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    if prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    random_values = uniform(shape, device=device)
    return random_values < prob


def make_attention_mask(
    query_input: torch.Tensor, key_input: torch.Tensor, pairwise_fn: Callable = torch.mul
) -> torch.Tensor:
    """Build a boolean "mask-out" tensor for attention from per-token query/key masks.

    Args:
        query_input: ``[batch, len_q]`` mask; non-zero (or ``True``) marks valid query positions.
        key_input: ``[batch, len_kv]`` mask; non-zero (or ``True``) marks valid key positions.
        pairwise_fn: combines the broadcast query and key masks into ``[batch, len_q, len_kv]``;
            the default ``torch.mul`` acts as a logical AND for 0/1 or boolean masks.

    Returns:
        A ``torch.bool`` tensor of shape ``[batch, 1, len_q, len_kv]`` that is ``True`` where the combined
        mask is zero, i.e. where attention must be masked OUT (suitable for ``torch.masked_fill``).
    """
    # [batch, len_q, len_kv]
    mask = pairwise_fn(
        # [batch, len_q] -> [batch, len_q, 1]
        query_input.unsqueeze(dim=-1),
        # [batch, len_kv] -> [batch, 1, len_kv]
        key_input.unsqueeze(dim=-2),
    )
    # [batch, 1, len_q, len_kv]. This creates the head dim.
    mask = mask.unsqueeze(dim=-3)
    # `logical_not` works for bool and numeric masks alike; the previous `1.0 - mask` raised on bool tensors.
    return torch.logical_not(mask)


try:
    from apex.normalization import FusedRMSNorm as RMSNorm  # noqa
except Exception:

    class RMSNorm(nn.Module):
        """Pure-PyTorch RMS normalization, used when apex's ``FusedRMSNorm`` cannot be imported.

        The mean of squares over the last dimension is computed in float32. The input is divided by
        ``sqrt(mean_square + eps)`` and optionally scaled by a learnable ``weight``.

        Args:
            normalized_shape: shape of the learnable ``weight`` (only created when ``elementwise_affine``).
            eps: added to the mean square before the reciprocal square root.
            elementwise_affine: whether to create and apply ``weight``.
        """

        def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True):
            super().__init__()
            self.elementwise_affine = elementwise_affine
            if elementwise_affine:
                self.weight = nn.Parameter(torch.ones(normalized_shape))
            self.variance_epsilon = eps

        def forward(self, input):
            original_dtype = input.dtype
            mean_square = input.to(torch.float32).pow(2).mean(-1, keepdim=True)
            normed = input * torch.rsqrt(mean_square + self.variance_epsilon)

            if not self.elementwise_affine:
                # no weight: return in the caller's dtype
                return normed.to(original_dtype)

            # a reduced-precision weight pulls the activations down to its dtype before scaling
            if self.weight.dtype in (torch.float16, torch.bfloat16):
                normed = normed.to(self.weight.dtype)
            return normed * self.weight


def sinusoidal_enocde(features, embedding_dim, max_positions=10000):
    """Encode a 1-D tensor of scalars with sinusoidal features.

    Frequencies are ``exp(-k * log(max_positions) / (embedding_dim // 2))`` for ``k`` in
    ``[0, embedding_dim // 2)``. The output is ``[cos(f * freq) | sin(f * freq)]`` of shape
    ``(len(features), embedding_dim)``, with one zero column appended when ``embedding_dim`` is odd.
    """
    half_dim = embedding_dim // 2
    log_step = math.log(max_positions) / half_dim
    exponents = torch.arange(half_dim, device=features.device, dtype=torch.float32)
    freqs = torch.exp(exponents * -log_step)

    angles = features.unsqueeze(1) * freqs.unsqueeze(0)
    encoding = torch.cat((torch.cos(angles), torch.sin(angles)), dim=1)

    if embedding_dim % 2:
        # odd width: zero-pad the last column
        encoding = F.pad(encoding, (0, 1), mode="constant")
    return encoding


# layer norm with an optional bias (disabled by default via use_bias=False)
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with an optional (off by default) bias.

    Args:
        dim: size of the normalized last dimension.
        eps: numerical-stability epsilon.
        use_bias: create a learnable bias (only when ``elementwise_affine``).
        elementwise_affine: create a learnable weight (and bias if ``use_bias``); otherwise both are ``None``.
    """

    def __init__(self, dim, eps=1e-5, use_bias=False, elementwise_affine=True):
        super().__init__()
        self.dim = dim
        self.eps = eps

        weight = bias = None
        if elementwise_affine:
            weight = nn.Parameter(torch.ones(dim))
            if use_bias:
                bias = nn.Parameter(torch.zeros(dim))
        self.weight = weight
        self.bias = bias

    def forward(self, x):
        return F.layer_norm(x, normalized_shape=(self.dim,), weight=self.weight, bias=self.bias, eps=self.eps)


class AdaLNModulation(nn.Module):
    """Adaptive scale/shift conditioning: ``h * (1 + scale) + shift`` with ``scale, shift`` predicted from
    ``silu(cond_embeds)`` by a single linear layer.

    Args:
        cond_embed_dim: width of the conditioning embedding.
        hidden_size: width of the modulated features (scale and shift each have this size).
        use_bias: whether the mapping layer has a bias.
    """

    def __init__(self, cond_embed_dim, hidden_size, use_bias=False):
        super().__init__()
        self.mapper = nn.Linear(cond_embed_dim, hidden_size * 2, bias=use_bias)

    def forward(self, hidden_states, cond_embeds):
        modulation = self.mapper(F.silu(cond_embeds))
        scale, shift = modulation.chunk(2, dim=1)

        if hidden_states.dim() > 3:
            # channels on dim 1 (e.g. a (B, C, H, W) feature map): broadcast over the trailing two dims
            index = (slice(None), slice(None), None, None)
        else:
            # features on the last dim (e.g. (B, L, C)): broadcast over dim 1
            index = (slice(None), None)
        scale, shift = scale[index], shift[index]

        return hidden_states * (1 + scale) + shift


class Attention(nn.Module):
    """Multi-head scaled dot-product attention with separate query/key/value/output projections.

    Acts as self-attention when ``encoder_hidden_states`` is ``None`` and as cross-attention otherwise.
    Can optionally dispatch to xformers' memory-efficient kernel.

    Args:
        hidden_size: width of the query input and of the output; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        encoder_hidden_size: width of the key/value input; defaults to ``hidden_size``.
        attention_dropout: dropout probability on the attention weights (passed as ``p`` to xformers).
        use_bias: whether the linear projections have a bias.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads, encoder_hidden_size=None, attention_dropout=0.0, use_bias=False):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.attention_dropout = attention_dropout
        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.hidden_size} and"
                f" `num_heads`: {self.num_heads})."
            )
        # sqrt(head_dim) as a 0-dim tensor in the default dtype; logits are multiplied by its reciprocal
        self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())

        self.query = nn.Linear(self.hidden_size, self.hidden_size, bias=use_bias)

        # keys/values may come from an encoder whose width differs from hidden_size
        kv_hidden_size = self.hidden_size if encoder_hidden_size is None else encoder_hidden_size
        self.key = nn.Linear(kv_hidden_size, self.hidden_size, bias=use_bias)
        self.value = nn.Linear(kv_hidden_size, self.hidden_size, bias=use_bias)

        self.out = nn.Linear(self.hidden_size, self.hidden_size, bias=use_bias)
        self.dropout = nn.Dropout(attention_dropout)

        self.use_memory_efficient_attention_xformers = False
        self.xformers_attention_op = None

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
        """Enable/disable xformers memory-efficient attention.

        Raises:
            ImportError: when enabling while xformers could not be imported.
        """
        if use_memory_efficient_attention_xformers and not is_xformers_available:
            raise ImportError("Please install xformers to use memory efficient attention")
        self.use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        self.xformers_attention_op = attention_op

    def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_mask=None):
        """Attend from ``hidden_states`` to itself or to ``encoder_hidden_states``.

        Args:
            hidden_states: ``(batch, q_len, hidden_size)`` query input.
            encoder_hidden_states: optional ``(batch, kv_len, encoder_hidden_size)`` key/value input.
            encoder_attention_mask: optional per-key mask (non-zero = attend, zero = ignore), expanded with
                ``make_attention_mask``. Not supported together with xformers.

        Returns:
            ``(batch, q_len, hidden_size)`` tensor.

        Raises:
            ValueError: if a mask is given while xformers attention is enabled.
        """
        if encoder_attention_mask is not None and self.use_memory_efficient_attention_xformers:
            raise ValueError("Memory efficient attention does not yet support encoder attention mask")

        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch, q_seq_len, _ = hidden_states.shape
        kv_seq_len = q_seq_len if encoder_hidden_states is None else encoder_hidden_states.shape[1]

        query = self.query(hidden_states)
        key = self.key(context)
        value = self.value(context)

        query = query.view(batch, q_seq_len, self.num_heads, self.head_dim)  # (B, T, nh, hs)
        key = key.view(batch, kv_seq_len, self.num_heads, self.head_dim)  # (B, S, nh, hs)
        value = value.view(batch, kv_seq_len, self.num_heads, self.head_dim)  # (B, S, nh, hs)

        if self.use_memory_efficient_attention_xformers:
            attn_output = xops.memory_efficient_attention(
                query, key, value, op=self.xformers_attention_op, p=self.attention_dropout if self.training else 0.0
            )
            attn_output = attn_output.view(batch, q_seq_len, self.hidden_size)
        else:
            attention_mask = None
            if encoder_attention_mask is not None:
                # every query position is valid; only the key side is masked
                src_attn_mask = torch.ones(batch, q_seq_len, dtype=torch.long, device=query.device)
                # make_attention_mask has no `dtype` parameter (passing one raised TypeError); it returns a
                # boolean [batch, 1, q_len, kv_len] "mask-out" tensor consumed by masked_fill in attention()
                attention_mask = make_attention_mask(src_attn_mask, encoder_attention_mask)
            attn_output = self.attention(query, key, value, attention_mask)

        attn_output = self.out(attn_output)
        return attn_output

    def attention(self, query, key, value, attention_mask=None):
        """Reference attention on ``(B, T, nh, hs)`` / ``(B, S, nh, hs)`` inputs; returns ``(B, T, hidden_size)``.

        ``attention_mask``, if given, must broadcast to ``(B, nh, T, S)`` and be ``True`` where scores are
        masked out.
        """
        batch, seq_len = query.shape[:2]
        kv_seq_len = key.shape[1]
        query, key, value = map(lambda t: t.transpose(1, 2).contiguous(), (query, key, value))  # (B, nh, T, hs)

        # scaled scores q @ k^T / sqrt(head_dim), computed as one batched matmul over batch * heads
        attn_weights = torch.baddbmm(
            input=torch.zeros(batch * self.num_heads, seq_len, kv_seq_len, dtype=query.dtype, device=query.device),
            batch1=query.view(batch * self.num_heads, seq_len, self.head_dim),
            batch2=key.view(batch * self.num_heads, kv_seq_len, self.head_dim).transpose(1, 2),
            alpha=1 / self.scale_attn,
        )
        attn_weights = attn_weights.view(batch, self.num_heads, seq_len, kv_seq_len)  # (B, nh, T, S)
        # Apply the attention mask: masked scores get the dtype's minimum so softmax gives them ~0 weight
        if attention_mask is not None:
            attn_weights = torch.masked_fill(attn_weights, attention_mask, torch.finfo(query.dtype).min)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, value)  # (B, nh, T, S) x (B, nh, S, hs) -> (B, nh, T, hs)
        # re-assemble all head outputs side by side
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch, seq_len, self.hidden_size)
        return attn_output


# U-ViT blocks
# Adapted from https://github.com/dome272/Paella/blob/main/src_distributed/modules.py


class AttentionBlock2D(nn.Module):
    """Attention over the spatial positions of a channels-first feature map.

    The ``(batch, hidden_size, height, width)`` input is flattened to ``(batch, height * width, hidden_size)``,
    passed through two pre-norm residual attention sub-layers, and reshaped back. When
    ``encoder_hidden_size != hidden_size``, encoder states are first projected by ``kv_mapper``
    (SiLU followed by Linear).

    NOTE(review): both sub-layers receive ``encoder_hidden_states`` as keys/values, so the first one
    (``self.attention``, labelled "self attention" historically) acts as a second cross-attention.
    Kept as-is because changing it changes the outputs of weights trained with this block; confirm intent.
    """

    def __init__(
        self,
        hidden_size,
        num_heads,
        encoder_hidden_size,
        attention_dropout=0.0,
        norm_type="layernorm",
        layer_norm_eps=1e-6,
        ln_elementwise_affine=True,
        use_bias=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size

        if norm_type == "layernorm":
            norm_cls = partial(LayerNorm, use_bias=use_bias)
        else:
            norm_cls = RMSNorm
        norm_kwargs = dict(eps=layer_norm_eps, elementwise_affine=ln_elementwise_affine)
        attn_kwargs = dict(attention_dropout=attention_dropout, use_bias=use_bias)

        self.attn_layer_norm = norm_cls(hidden_size, **norm_kwargs)
        self.attention = Attention(hidden_size, num_heads, **attn_kwargs)
        self.crossattn_layer_norm = norm_cls(hidden_size, **norm_kwargs)
        self.crossattention = Attention(hidden_size, num_heads, **attn_kwargs)

        # project encoder states to this layer's width only when the widths differ
        self.kv_mapper = (
            nn.Linear(encoder_hidden_size, hidden_size, bias=use_bias) if encoder_hidden_size != hidden_size else None
        )

    def forward(self, hidden_states, encoder_hidden_states, encoder_attention_mask=None):
        batch_size, channels, height, width = hidden_states.shape
        # (B, C, H, W) -> (B, H*W, C)
        tokens = hidden_states.view(batch_size, channels, height * width).permute(0, 2, 1)

        if self.kv_mapper is not None:
            encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))

        sublayers = (
            (self.attn_layer_norm, self.attention),
            (self.crossattn_layer_norm, self.crossattention),
        )
        for norm, attn in sublayers:
            # pre-norm residual: x + attn(norm(x), encoder states)
            tokens = attn(norm(tokens), encoder_hidden_states, encoder_attention_mask) + tokens

        # (B, H*W, C) -> (B, C, H, W)
        return tokens.permute(0, 2, 1).view(batch_size, channels, height, width)


class Norm2D(nn.Module):
    """Apply a last-dimension norm to a channels-first ``(batch, channels, height, width)`` tensor.

    The input is permuted to channels-last, normalized over the channel dimension, then permuted back.

    Args:
        dim: number of channels.
        eps: numerical-stability epsilon passed to the norm.
        use_bias: LayerNorm bias flag (unused for RMSNorm).
        norm_type: ``"layernorm"`` or ``"rmsnorm"``.
        elementwise_affine: whether the norm has a learnable scale.

    Raises:
        ValueError: for any other ``norm_type``; previously ``self.norm`` was silently left unset and the
            first forward call failed with an AttributeError.
    """

    def __init__(self, dim, eps=1e-5, use_bias=False, norm_type="layernorm", elementwise_affine=True):
        super().__init__()
        if norm_type == "layernorm":
            self.norm = LayerNorm(dim, eps, use_bias, elementwise_affine=elementwise_affine)
        elif norm_type == "rmsnorm":
            self.norm = RMSNorm(dim, eps, elementwise_affine=elementwise_affine)
        else:
            raise ValueError(f"Unsupported norm_type {norm_type!r}; expected 'layernorm' or 'rmsnorm'")

    def forward(self, x):
        # (B, C, H, W) -> (B, H, W, C) -> normalize over C -> (B, C, H, W)
        return self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)


class GlobalResponseNorm(nn.Module):
    """Global Response Normalization (GRN) from ConvNeXt-V2.

    Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105

    The L2 norm is taken over dims 1 and 2 and divided by its mean over the last dim (channels-last layout).
    ``gamma`` and ``beta`` start at zero, so the layer initially returns its input unchanged.
    """

    def __init__(self, dim):
        super().__init__()
        param_shape = (1, 1, 1, dim)
        self.gamma = nn.Parameter(torch.zeros(param_shape))
        self.beta = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        spatial_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        relative_norm = spatial_norm / (spatial_norm.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * relative_norm) + self.beta + x


class ResBlock(nn.Module):
    """ConvNeXt-style residual block on channels-first feature maps with optional skip input and AdaLN.

    depthwise conv (over ``in_channels + skip_channels`` inputs) -> Norm2D -> channels-last MLP
    (Linear, GELU, GlobalResponseNorm, Dropout, Linear) -> residual add -> optional AdaLN modulation.

    Args:
        in_channels: channels of ``x`` and of the output.
        skip_channels: channels of ``x_skip`` concatenated to ``x``; ``None`` means no skip input and is
            treated as 0 (previously ``None`` raised a TypeError). Because the depthwise conv uses
            ``groups=in_channels``, ``in_channels + skip_channels`` must be divisible by ``in_channels``.
        kernel_size: depthwise conv kernel size ("same" padding of ``kernel_size // 2``).
        dropout: dropout inside the MLP.
        norm_type, ln_elementwise_affine, use_bias: forwarded to Norm2D / the linear layers.
        add_cond_embeds: create an AdaLN modulation layer (required for passing ``cond_embeds`` to forward).
        cond_embed_dim: width of ``cond_embeds``.
        res_ffn_factor: MLP expansion factor.
    """

    def __init__(
        self,
        in_channels,
        skip_channels=None,
        kernel_size=3,
        dropout=0.0,
        norm_type="layernorm",
        ln_elementwise_affine=True,
        add_cond_embeds=False,
        cond_embed_dim=None,
        use_bias=False,
        res_ffn_factor=4,
        **kwargs,
    ):
        super().__init__()
        # the default means "no skip connection"; `in_channels + None` would raise
        skip_channels = 0 if skip_channels is None else skip_channels
        ffn_channels = int(in_channels * res_ffn_factor)

        self.depthwise = nn.Conv2d(
            in_channels + skip_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=in_channels,
            bias=use_bias,
        )
        self.norm = Norm2D(
            in_channels, eps=1e-6, norm_type=norm_type, use_bias=use_bias, elementwise_affine=ln_elementwise_affine
        )
        self.channelwise = nn.Sequential(
            nn.Linear(in_channels, ffn_channels, bias=use_bias),
            nn.GELU(),
            GlobalResponseNorm(ffn_channels),
            nn.Dropout(dropout),
            nn.Linear(ffn_channels, in_channels, bias=use_bias),
        )

        if add_cond_embeds:
            self.adaLN_modulation = AdaLNModulation(
                cond_embed_dim=cond_embed_dim, hidden_size=in_channels, use_bias=use_bias
            )

    def forward(self, x, x_skip=None, cond_embeds=None):
        """Apply the block.

        Args:
            x: channels-first input with ``in_channels`` channels.
            x_skip: optional channels-first tensor concatenated to ``x`` along channels.
            cond_embeds: optional conditioning; only valid when built with ``add_cond_embeds=True``.
        """
        x_res = x
        if x_skip is not None:
            x = torch.cat([x, x_skip], dim=1)
        # conv + norm run channels-first; the MLP runs channels-last (B, H, W, C)
        x = self.norm(self.depthwise(x)).permute(0, 2, 3, 1)
        x = self.channelwise(x).permute(0, 3, 1, 2)
        x = x + x_res
        if cond_embeds is not None:
            x = self.adaLN_modulation(x, cond_embeds)
        return x


class ResnetBlockVanilla(nn.Module):
    """Classic two-conv residual block: (GroupNorm -> SiLU -> 3x3 conv) twice, dropout before the second conv.

    When ``in_channels != out_channels`` the residual is projected by a 3x3 ``conv_shortcut``
    (if ``conv_shortcut`` is true) or a 1x1 ``nin_shortcut``. GroupNorm uses 32 groups, so channel
    counts must be divisible by 32.
    """

    def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, use_bias=False, **kwargs):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=use_bias)

        self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=use_bias)

        if in_channels == out_channels:
            return
        # channel counts differ: the residual needs a projection
        if conv_shortcut:
            self.conv_shortcut = torch.nn.Conv2d(
                in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=use_bias
            )
        else:
            self.nin_shortcut = torch.nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=use_bias
            )

    def forward(self, hidden_states, **kwargs):
        h = self.conv1(F.silu(self.norm1(hidden_states)))
        h = self.conv2(self.dropout(F.silu(self.norm2(h))))

        shortcut = hidden_states
        if self.in_channels != self.out_channels:
            projection = self.conv_shortcut if self.use_conv_shortcut else self.nin_shortcut
            shortcut = projection(hidden_states)

        return shortcut + h


class DownsampleBlock(nn.Module):
    """U-ViT encoder stage: optional 2x downsample, then ``num_res_blocks`` ResBlocks, each optionally followed
    by an AttentionBlock2D.

    Args:
        input_channels: channels of the incoming feature map.
        output_channels: channels after the downsample conv; the stage keeps ``input_channels`` when
            ``add_downsample`` is false.
        skip_channels: channels of ``x_skip`` concatenated inside every ResBlock.
        num_res_blocks: number of ResBlock (and attention, if enabled) layers.
        kernel_size, res_ffn_factor, dropout, norm_type, ln_elementwise_affine, use_bias: forwarded to sub-blocks.
        add_downsample: prepend Norm2D + a kernel-2, stride-2 conv.
        add_cond_embeds, cond_embed_dim: enable AdaLN modulation in the ResBlocks.
        has_attention, num_heads, encoder_hidden_size: add an AttentionBlock2D after each ResBlock.
    """

    def __init__(
        self,
        input_channels,
        output_channels=None,
        skip_channels=None,
        num_res_blocks=4,
        kernel_size=3,
        res_ffn_factor=4,
        dropout=0.0,
        norm_type="layernorm",
        ln_elementwise_affine=True,
        add_downsample=True,
        add_cond_embeds=False,
        cond_embed_dim=None,
        has_attention=False,
        num_heads=None,
        encoder_hidden_size=None,
        use_bias=False,
        **kwargs,
    ):
        super().__init__()
        self.add_downsample = add_downsample
        self.has_attention = has_attention
        if add_downsample:
            self.downsample = nn.Sequential(
                Norm2D(
                    input_channels,
                    eps=1e-6,
                    use_bias=use_bias,
                    norm_type=norm_type,
                    elementwise_affine=ln_elementwise_affine,
                ),
                nn.Conv2d(input_channels, output_channels, kernel_size=2, stride=2, bias=use_bias),
            )
            self.input_channels = output_channels
        else:
            self.input_channels = input_channels

        self.res_blocks = nn.ModuleList(
            [
                ResBlock(
                    self.input_channels,
                    skip_channels=skip_channels,
                    kernel_size=kernel_size,
                    dropout=dropout,
                    norm_type=norm_type,
                    ln_elementwise_affine=ln_elementwise_affine,
                    add_cond_embeds=add_cond_embeds,
                    cond_embed_dim=cond_embed_dim,
                    use_bias=use_bias,
                    res_ffn_factor=res_ffn_factor,
                )
                for _ in range(num_res_blocks)
            ]
        )

        if has_attention:
            self.attention_blocks = nn.ModuleList(
                [
                    AttentionBlock2D(
                        hidden_size=self.input_channels,
                        num_heads=num_heads,
                        encoder_hidden_size=encoder_hidden_size,
                        attention_dropout=dropout,
                        norm_type=norm_type,
                        ln_elementwise_affine=ln_elementwise_affine,
                        use_bias=use_bias,
                    )
                    for _ in range(num_res_blocks)
                ]
            )

        self.gradient_checkpointing = False

    def forward(self, x, x_skip=None, cond_embeds=None, encoder_hidden_states=None, **kwargs):
        """Run the stage.

        Returns:
            ``(x, output_states)``: the final feature map and a tuple holding the output of every
            res (+ attention) layer.
        """
        if self.add_downsample:
            x = self.downsample(x)

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)

            return custom_forward

        use_checkpointing = self.training and self.gradient_checkpointing
        output_states = ()
        for i, res_block in enumerate(self.res_blocks):
            if use_checkpointing:
                # cond_embeds must be forwarded here too: it used to be dropped, so AdaLN modulation was
                # silently skipped whenever gradient checkpointing was enabled
                x = torch.utils.checkpoint.checkpoint(create_custom_forward(res_block), x, x_skip, cond_embeds)
                if self.has_attention:
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(self.attention_blocks[i]), x, encoder_hidden_states
                    )
            else:
                x = res_block(x, x_skip, cond_embeds=cond_embeds)
                if self.has_attention:
                    x = self.attention_blocks[i](x, encoder_hidden_states)

            output_states += (x,)
        return x, output_states


class UpsampleBlock(nn.Module):
    """U-ViT decoder stage: ``num_res_blocks`` ResBlocks (each optionally followed by an AttentionBlock2D),
    then an optional 2x transposed-conv upsample.

    Args:
        input_channels: channels processed by the res/attention blocks.
        output_channels: channels after upsampling; defaults to ``input_channels``.
        skip_channels: channels of the skip input, consumed only by the first ResBlock.
        num_res_blocks: number of ResBlock (and attention, if enabled) layers.
        kernel_size, res_ffn_factor, dropout, norm_type, ln_elementwise_affine, use_bias: forwarded to sub-blocks.
        add_upsample: append Norm2D + a kernel-2, stride-2 transposed conv.
        add_cond_embeds, cond_embed_dim: enable AdaLN modulation in the ResBlocks.
        has_attention, num_heads, encoder_hidden_size: add an AttentionBlock2D after each ResBlock.
    """

    def __init__(
        self,
        input_channels,
        output_channels=None,
        skip_channels=None,
        num_res_blocks=4,
        kernel_size=3,
        res_ffn_factor=4,
        dropout=0.0,
        norm_type="layernorm",
        ln_elementwise_affine=True,
        add_upsample=True,
        add_cond_embeds=False,
        cond_embed_dim=None,
        has_attention=False,
        num_heads=None,
        encoder_hidden_size=None,
        use_bias=False,
        **kwargs,
    ):
        super().__init__()
        self.add_upsample = add_upsample
        self.has_attention = has_attention
        self.input_channels = input_channels
        self.output_channels = output_channels if output_channels is not None else input_channels

        self.res_blocks = nn.ModuleList(
            [
                ResBlock(
                    self.input_channels,
                    # only the first block concatenates the skip input
                    skip_channels=skip_channels if i == 0 else 0,
                    kernel_size=kernel_size,
                    dropout=dropout,
                    norm_type=norm_type,
                    ln_elementwise_affine=ln_elementwise_affine,
                    add_cond_embeds=add_cond_embeds,
                    cond_embed_dim=cond_embed_dim,
                    use_bias=use_bias,
                    res_ffn_factor=res_ffn_factor,
                )
                for i in range(num_res_blocks)
            ]
        )

        if has_attention:
            self.attention_blocks = nn.ModuleList(
                [
                    AttentionBlock2D(
                        hidden_size=self.input_channels,
                        num_heads=num_heads,
                        encoder_hidden_size=encoder_hidden_size,
                        attention_dropout=dropout,
                        norm_type=norm_type,
                        ln_elementwise_affine=ln_elementwise_affine,
                        use_bias=use_bias,
                    )
                    for _ in range(num_res_blocks)
                ]
            )

        if add_upsample:
            self.upsample = nn.Sequential(
                Norm2D(
                    self.input_channels,
                    eps=1e-6,
                    norm_type=norm_type,
                    use_bias=use_bias,
                    elementwise_affine=ln_elementwise_affine,
                ),
                nn.ConvTranspose2d(self.input_channels, self.output_channels, kernel_size=2, stride=2, bias=use_bias),
            )

        self.gradient_checkpointing = False

    def forward(self, x, x_skip=None, cond_embeds=None, encoder_hidden_states=None, **kwargs):
        """Run the stage.

        Args:
            x: channels-first input with ``input_channels`` channels.
            x_skip: optional indexable of skip tensors; only ``x_skip[0]`` is used (by the first ResBlock).
            cond_embeds: optional conditioning for AdaLN modulation.
            encoder_hidden_states: keys/values for the attention blocks.
        """

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)

            return custom_forward

        use_checkpointing = self.training and self.gradient_checkpointing
        for i, res_block in enumerate(self.res_blocks):
            x_res = x_skip[0] if i == 0 and x_skip is not None else None

            if use_checkpointing:
                # cond_embeds must be forwarded here too: it used to be dropped, so AdaLN modulation was
                # silently skipped whenever gradient checkpointing was enabled
                x = torch.utils.checkpoint.checkpoint(create_custom_forward(res_block), x, x_res, cond_embeds)
                if self.has_attention:
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(self.attention_blocks[i]), x, encoder_hidden_states
                    )
            else:
                x = res_block(x, x_res, cond_embeds=cond_embeds)
                if self.has_attention:
                    x = self.attention_blocks[i](x, encoder_hidden_states)

        if self.add_upsample:
            x = self.upsample(x)
        return x


class DownsampleBlockVanilla(nn.Module):
    """Stack of ResnetBlockVanilla blocks followed by an optional stride-2 3x3 conv downsample.

    Args:
        input_channels: channels of the block input.
        output_channels: channels produced by every res block and by the downsample conv; defaults to
            ``input_channels`` when ``None`` (consistent with ResnetBlockVanilla / UpsampleBlock).
        num_res_blocks: number of res blocks.
        dropout: dropout inside each res block.
        add_downsample: append the pad + stride-2 conv downsample.
        use_bias: bias flag for the convolutions.
    """

    def __init__(
        self,
        input_channels,
        output_channels=None,
        num_res_blocks=4,
        dropout=0.0,
        add_downsample=True,
        use_bias=False,
        **kwargs,
    ):
        super().__init__()
        self.add_downsample = add_downsample
        if output_channels is None:
            output_channels = input_channels

        self.res_blocks = nn.ModuleList(
            [
                ResnetBlockVanilla(
                    in_channels=input_channels if i == 0 else output_channels,
                    out_channels=output_channels,
                    dropout=dropout,
                    use_bias=use_bias,
                )
                for i in range(num_res_blocks)
            ]
        )

        if add_downsample:
            self.downsample_conv = nn.Conv2d(output_channels, output_channels, 3, stride=2, bias=use_bias)

        self.gradient_checkpointing = False

    def forward(self, x, **kwargs):
        """Run the res blocks, then the optional downsample.

        Returns:
            ``(x, output_states)``: the final tensor and a tuple with each res block's output, plus the
            downsampled tensor as the last entry when ``add_downsample`` is true.
        """

        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)

            return custom_forward

        use_checkpointing = self.training and self.gradient_checkpointing
        output_states = ()
        for res_block in self.res_blocks:
            if use_checkpointing:
                x = torch.utils.checkpoint.checkpoint(create_custom_forward(res_block), x)
            else:
                x = res_block(x)
            output_states += (x,)

        if self.add_downsample:
            # pad right/bottom by one so the unpadded 3x3 stride-2 conv halves even spatial sizes exactly
            x = F.pad(x, (0, 1, 0, 1), mode="constant", value=0)
            x = self.downsample_conv(x)
            output_states += (x,)

        return x, output_states


class UpsampleBlockVanilla(nn.Module):
    """Decoder stage of ResnetBlockVanilla blocks, each fed the running feature map concatenated with a skip
    state popped from the end of ``x_skip``, followed by an optional 2x nearest upsample + 3x3 conv.

    Block ``i`` takes ``(skip_channels if i == 0 else output_channels) +
    (input_channels if i == num_res_blocks - 1 else output_channels)`` input channels.

    NOTE(review): unlike DownsampleBlockVanilla, ``use_bias`` is not forwarded to the res blocks and
    ``upsample_conv`` always has a bias. Left unchanged because it affects the parameter set
    (checkpoint compatibility); confirm whether it is intended.
    """

    def __init__(
        self,
        input_channels,
        output_channels,
        skip_channels=None,
        num_res_blocks=4,
        dropout=0.0,
        add_upsample=True,
        use_bias=False,
        **kwargs,
    ):
        super().__init__()
        self.add_upsample = add_upsample

        last_index = num_res_blocks - 1
        self.res_blocks = nn.ModuleList(
            [
                ResnetBlockVanilla(
                    in_channels=(skip_channels if i == 0 else output_channels)
                    + (input_channels if i == last_index else output_channels),
                    out_channels=output_channels,
                    dropout=dropout,
                )
                for i in range(num_res_blocks)
            ]
        )

        if add_upsample:
            self.upsample_conv = nn.Conv2d(output_channels, output_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, x, x_skip, **kwargs):
        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)

            return custom_forward

        for i, res_block in enumerate(self.res_blocks):
            # skip states are consumed from the end: last-in, first-out
            x = torch.cat([x, x_skip[-1 - i]], dim=1)
            if self.training and self.gradient_checkpointing:
                x = torch.utils.checkpoint.checkpoint(create_custom_forward(res_block), x)
            else:
                x = res_block(x)

        if not self.add_upsample:
            return x

        if x.shape[0] >= 64:
            # NOTE(review): presumably a workaround for nearest-upsampling issues with large batches; keep
            x = x.contiguous()
        x = F.interpolate(x, scale_factor=2.0, mode="nearest")
        return self.upsample_conv(x)


# End U-ViT blocks


# NormFormer-style feed-forward: GLU by default, or a plain GELU MLP with ffn_type="vanilla"
class FeedForward(nn.Module):
    """Pre-norm feed-forward block.

    LayerNorm -> optional AdaLN modulation -> GELU(wi_0(x)), multiplied by wi_1(x) when ``ffn_type == "glu"``
    -> optional NormFormer mid-norm -> dropout -> ``wo`` projection back to ``hidden_size``.

    NOTE(review): ``pre_mlp_layer_norm`` is always a LayerNorm; only the NormFormer mid-norm follows
    ``norm_type``. Preserved as-is.
    """

    def __init__(
        self,
        hidden_size,
        intermediate_size,
        hidden_dropout=0.0,
        norm_type="layernorm",
        layer_norm_eps=1e-5,
        ln_elementwise_affine=True,
        use_normformer=True,
        add_cond_embeds=False,
        cond_embed_dim=None,
        use_bias=False,
        ffn_type="glu",  # glu or vanilla
    ):
        super().__init__()
        self.use_normformer = use_normformer
        self.ffn_type = ffn_type

        self.pre_mlp_layer_norm = LayerNorm(
            hidden_size, eps=layer_norm_eps, use_bias=use_bias, elementwise_affine=ln_elementwise_affine
        )
        self.wi_0 = nn.Linear(hidden_size, intermediate_size, bias=use_bias)
        if ffn_type == "glu":
            # gating projection multiplied with the GELU branch
            self.wi_1 = nn.Linear(hidden_size, intermediate_size, bias=use_bias)
        if use_normformer:
            if norm_type == "layernorm":
                self.mid_mlp_layer_norm = LayerNorm(
                    intermediate_size, eps=layer_norm_eps, use_bias=use_bias, elementwise_affine=ln_elementwise_affine
                )
            else:
                self.mid_mlp_layer_norm = RMSNorm(
                    intermediate_size, eps=layer_norm_eps, elementwise_affine=ln_elementwise_affine
                )
        self.wo = nn.Linear(intermediate_size, hidden_size, bias=use_bias)
        self.dropout = nn.Dropout(hidden_dropout)
        if add_cond_embeds:
            self.adaLN_modulation = AdaLNModulation(
                cond_embed_dim=cond_embed_dim, hidden_size=hidden_size, use_bias=use_bias
            )

    def forward(self, hidden_states: torch.FloatTensor, cond_embeds=None) -> torch.FloatTensor:
        normed = self.pre_mlp_layer_norm(hidden_states)
        if cond_embeds is not None:
            normed = self.adaLN_modulation(normed, cond_embeds)

        inner = F.gelu(self.wi_0(normed))
        if self.ffn_type == "glu":
            inner = inner * self.wi_1(normed)
        if self.use_normformer:
            inner = self.mid_mlp_layer_norm(inner)

        return self.wo(self.dropout(inner))


# PreLN Transformer layer
class TransformerLayer(nn.Module):
    """Pre-LayerNorm transformer block: self-attention, optional cross-attention, then a feed-forward block.

    Each attention sub-block normalizes its input, optionally applies AdaLN modulation from
    `cond_embeds`, runs the attention module, optionally applies a "normformer" post-norm to the
    attention output, and adds the result back onto the residual stream. The feed-forward block
    (`FeedForward`) does its own pre-norm / modulation; its output is added to the residual here.
    """

    def __init__(
        self,
        hidden_size,
        intermediate_size,
        num_attention_heads,
        encoder_hidden_size=1024,
        add_cross_attention=False,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        norm_type="layernorm",
        layer_norm_eps=1e-5,
        ln_elementwise_affine=True,
        use_normformer=True,
        add_cond_embeds=False,
        cond_embed_dim=None,
        ffn_type="glu",
        use_bias=False,
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.use_normformer = use_normformer

        # Any norm_type other than "layernorm" selects RMSNorm (which is not given use_bias).
        norm_cls = partial(LayerNorm, use_bias=use_bias) if norm_type == "layernorm" else RMSNorm
        self.attn_layer_norm = norm_cls(self.hidden_size, eps=layer_norm_eps, elementwise_affine=ln_elementwise_affine)
        self.attention = Attention(
            self.hidden_size, self.num_attention_heads, attention_dropout=attention_dropout, use_bias=use_bias
        )
        if use_normformer:
            # Normformer: extra norm on the attention output before the residual add.
            self.post_attn_layer_norm = norm_cls(
                self.hidden_size, eps=layer_norm_eps, elementwise_affine=ln_elementwise_affine
            )
        # NOTE(review): arguments are positional and must follow FeedForward's parameter order, which
        # ends with (..., cond_embed_dim, use_bias, ffn_type) — the reverse of this class's own
        # (ffn_type, use_bias) order. Keep the two in sync if either signature changes.
        self.ffn = FeedForward(
            self.hidden_size,
            self.intermediate_size,
            hidden_dropout,
            norm_type,
            layer_norm_eps,
            ln_elementwise_affine,
            use_normformer,
            add_cond_embeds,
            cond_embed_dim,
            use_bias,
            ffn_type,
        )

        if add_cross_attention:
            self.crossattn_layer_norm = norm_cls(
                self.hidden_size, eps=layer_norm_eps, elementwise_affine=ln_elementwise_affine
            )
            # NOTE(review): positional args assumed to map to Attention's
            # (hidden_size, num_heads, encoder_hidden_size, attention_dropout, use_bias) — confirm
            # against the Attention signature.
            self.crossattention = Attention(
                self.hidden_size, self.num_attention_heads, encoder_hidden_size, attention_dropout, use_bias
            )
            if use_normformer:
                self.post_crossattn_layer_norm = norm_cls(
                    self.hidden_size, eps=layer_norm_eps, elementwise_affine=ln_elementwise_affine
                )

        if add_cond_embeds:
            # AdaLN modulation applied after each attention pre-norm when cond_embeds are passed.
            self.self_attn_adaLN_modulation = AdaLNModulation(
                cond_embed_dim=cond_embed_dim, hidden_size=hidden_size, use_bias=use_bias
            )
            if add_cross_attention:
                self.cross_attn_adaLN_modulation = AdaLNModulation(
                    cond_embed_dim=cond_embed_dim,
                    hidden_size=hidden_size,
                    use_bias=use_bias,
                )

    def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_mask=None, cond_embeds=None):
        """Run self-attention, optional cross-attention and the feed-forward block, each with a residual.

        Args:
            hidden_states: Residual-stream input.
            encoder_hidden_states: When given, the cross-attention sub-block runs; the layer must have
                been built with `add_cross_attention=True` (otherwise `crossattn_layer_norm` is missing).
            encoder_attention_mask: Forwarded only to the cross-attention module.
            cond_embeds: When given, AdaLN modulation is applied before each attention sub-layer and
                inside the FFN; requires `add_cond_embeds=True` at construction.

        Returns:
            The updated residual stream.
        """
        residual = hidden_states

        hidden_states = self.attn_layer_norm(hidden_states)
        if cond_embeds is not None:
            hidden_states = self.self_attn_adaLN_modulation(hidden_states, cond_embeds)
        attention_output = self.attention(hidden_states)
        if self.use_normformer:
            attention_output = self.post_attn_layer_norm(attention_output)
        hidden_states = residual + attention_output

        if encoder_hidden_states is not None:
            residual = hidden_states
            # TODO: should norm be applied to encoder_hidden_states as well?
            hidden_states = self.crossattn_layer_norm(hidden_states)
            if cond_embeds is not None:
                hidden_states = self.cross_attn_adaLN_modulation(hidden_states, cond_embeds)
            attention_output = self.crossattention(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
            )
            if self.use_normformer:
                attention_output = self.post_crossattn_layer_norm(attention_output)
            hidden_states = residual + attention_output

        residual = hidden_states
        # The FFN normalizes its own input, so the raw residual stream is passed in.
        hidden_states = self.ffn(hidden_states, cond_embeds=cond_embeds)
        hidden_states = residual + hidden_states
        return hidden_states


class Embed(nn.Module):
    """Token embeddings plus learned absolute position embeddings.

    Optionally normalizes the summed embeddings and projects them from `embedding_size` to
    `hidden_size`; dropout is always applied last.

    Args:
        vocab_size: Number of token ids.
        embedding_size: Width of both embedding tables.
        hidden_size: Output width when `use_embeddings_project` is set.
        hidden_dropout: Dropout probability applied to the final embeddings.
        max_position_embeddings: Number of rows in the position table (longest supported sequence).
        norm_type: "layernorm" selects LayerNorm; any other value selects RMSNorm.
        layer_norm_eps: Epsilon of the optional embedding norm.
        use_bias: Bias flag for the LayerNorm and the projection.
        layer_norm_embedddings: Normalize the summed embeddings (historical spelling kept).
        use_embeddings_project: Add a linear map embedding_size -> hidden_size.
    """

    def __init__(
        self,
        vocab_size,
        embedding_size,
        hidden_size,
        hidden_dropout=0.0,
        max_position_embeddings=512,
        norm_type="layernorm",
        layer_norm_eps=1e-5,
        use_bias=False,
        layer_norm_embedddings=False,
        use_embeddings_project=False,
    ):
        super().__init__()

        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.hidden_dropout = hidden_dropout
        self.max_position_embeddings = max_position_embeddings
        self.layer_norm_embedddings = layer_norm_embedddings
        self.use_embeddings_project = use_embeddings_project

        # Submodules are created in a fixed order: it determines the state_dict order and the
        # RNG sequence consumed by default initialization.
        self.word_embeddings = nn.Embedding(vocab_size, embedding_size)
        self.position_embeddings = nn.Embedding(max_position_embeddings, embedding_size)
        self.dropout = nn.Dropout(hidden_dropout)

        if layer_norm_embedddings:
            if norm_type == "layernorm":
                self.embeddings_ln = LayerNorm(embedding_size, eps=layer_norm_eps, use_bias=use_bias)
            else:
                self.embeddings_ln = RMSNorm(embedding_size, eps=layer_norm_eps)

        if use_embeddings_project:
            self.embedding_hidden_mapping = nn.Linear(embedding_size, hidden_size, bias=use_bias)

    def forward(self, input_ids):
        """Embed `input_ids`; positions 0..n-1 are counted along the last dimension."""
        positions = torch.arange(input_ids.shape[-1], device=input_ids.device).unsqueeze(0)
        embeddings = self.word_embeddings(input_ids) + self.position_embeddings(positions)

        if self.layer_norm_embedddings:
            embeddings = self.embeddings_ln(embeddings)
        if self.use_embeddings_project:
            embeddings = self.embedding_hidden_mapping(embeddings)

        return self.dropout(embeddings)


class MlmLayer(nn.Module):
    """Token prediction head: dense -> GELU -> optional norm -> vocabulary logits.

    Args:
        hidden_size: Input (and intermediate) width.
        vocab_size: Number of output logits per position.
        norm_type: "layernorm" selects LayerNorm; any other value selects RMSNorm.
        layer_norm_eps: Epsilon of the optional norm.
        use_mlm_layernorm: Whether to normalize after the GELU.
        use_bias: Bias flag for the linear layers and the LayerNorm.
    """

    def __init__(
        self,
        hidden_size,
        vocab_size,
        norm_type="layernorm",
        layer_norm_eps=1e-5,
        use_mlm_layernorm=True,
        use_bias=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.use_mlm_layernorm = use_mlm_layernorm

        # Creation order (dense, norm, logits) is kept fixed for state_dict / seeded-init stability.
        self.mlm_dense = nn.Linear(hidden_size, hidden_size, bias=use_bias)
        if use_mlm_layernorm:
            if norm_type == "layernorm":
                self.mlm_ln = LayerNorm(hidden_size, eps=layer_norm_eps, use_bias=use_bias)
            else:
                self.mlm_ln = RMSNorm(hidden_size, eps=layer_norm_eps)
        self.to_logits = nn.Linear(hidden_size, vocab_size, bias=use_bias)

    def forward(self, hidden_states):
        """Map hidden states to vocabulary logits (last dimension becomes vocab_size)."""
        features = F.gelu(self.mlm_dense(hidden_states))
        if self.use_mlm_layernorm:
            features = self.mlm_ln(features)
        return self.to_logits(features)


class ConvEmbed(nn.Module):
    """Embed a square grid of token ids, then patchify with pixel-unshuffle and a 1x1 convolution.

    Tokens are embedded and normalized per token, laid out as an image, pixel-unshuffled by
    `patch_size` (when > 1), and projected to `hidden_size`. When enabled, learned position
    embeddings (one per patch) are added.
    """

    def __init__(
        self,
        vocab_size,
        embedding_size,
        hidden_size,
        patch_size=2,
        max_position_embeddings=256,
        norm_type="layernorm",
        ln_elementwise_affine=True,
        layer_norm_embedddings=False,
        layer_norm_eps=1e-5,
        use_position_embeddings=True,
        use_bias=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.max_position_embeddings = max_position_embeddings
        self.use_position_embeddings = use_position_embeddings
        self.layer_norm_embedddings = layer_norm_embedddings

        # Per-token embedding and its norm, applied before the tokens are arranged as an image.
        self.embeddings = nn.Embedding(vocab_size, embedding_size)
        norm_cls = partial(LayerNorm, use_bias=use_bias) if norm_type == "layernorm" else RMSNorm
        self.layer_norm = norm_cls(embedding_size, eps=layer_norm_eps, elementwise_affine=ln_elementwise_affine)
        if patch_size > 1:
            # Folds each patch_size x patch_size block of tokens into the channel dimension.
            self.pixel_unshuffle = nn.PixelUnshuffle(patch_size)
        # 1x1 conv from the unshuffled channel count (embedding_size * patch_size**2) to hidden_size.
        self.conv = nn.Conv2d(embedding_size * (patch_size**2), hidden_size, kernel_size=1, bias=use_bias)
        if use_position_embeddings:
            # One learned position per patch; must cover (H / patch_size) * (W / patch_size) positions.
            self.position_embeddings = nn.Embedding(self.max_position_embeddings, hidden_size)
        if self.layer_norm_embedddings:
            self.embeddings_ln = Norm2D(
                hidden_size, eps=layer_norm_eps, norm_type=norm_type, elementwise_affine=ln_elementwise_affine
            )

    def forward(self, input_ids):
        """Embed `input_ids` of shape (batch, seq_length).

        `seq_length` is treated as a square grid (height = width = int(sqrt(seq_length))), so it is
        expected to be a perfect square whose side is divisible by `patch_size`.

        With `use_position_embeddings` the result is (batch, num_patches, hidden_size).
        NOTE(review): when `use_position_embeddings` is False the flatten below is skipped and the
        result stays channel-first, (batch, hidden_size, H / patch_size, W / patch_size); and when it
        is True, `embeddings_ln` (Norm2D) receives a 3-D tensor. Confirm both layouts are intended.
        """
        batch_size, seq_length = input_ids.shape
        height, width = int(seq_length**0.5), int(seq_length**0.5)
        input_ids = input_ids.view(-1, height, width)
        embeddings = self.embeddings(input_ids)
        embeddings = self.layer_norm(embeddings)
        # (batch, H, W, embedding_size) -> (batch, embedding_size, H, W) for the conv layers.
        embeddings = embeddings.permute(0, 3, 1, 2)
        if self.patch_size > 1:
            embeddings = self.pixel_unshuffle(embeddings)
        embeddings = self.conv(embeddings)
        if self.use_position_embeddings:
            # Back to channels-last and flatten the patch grid into a sequence.
            embeddings = embeddings.permute(0, 2, 3, 1).view(batch_size, -1, self.hidden_size)
            position_ids = torch.arange(embeddings.shape[1])[None, :].to(input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings
        if self.layer_norm_embedddings:
            embeddings = self.embeddings_ln(embeddings)
        return embeddings


class ConvMlmLayer(nn.Module):
    """Convolutional prediction head that undoes ConvEmbed-style patching and emits per-token logits.

    A 1x1 conv expands each position to `patch_size**2` sub-positions of width `embedding_size`,
    pixel-shuffle spreads them back onto the full token grid, and after a 2-D norm a second 1x1
    conv produces `vocab_size` logits per token.
    """

    def __init__(
        self,
        vocab_size,
        embedding_size,
        hidden_size,
        patch_size=2,
        norm_type="layernorm",
        ln_elementwise_affine=True,
        layer_norm_eps=1e-5,
        use_bias=False,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.patch_size = patch_size

        # Creation order (conv1, pixel_shuffle, layer_norm, conv2) is kept fixed.
        self.conv1 = nn.Conv2d(hidden_size, embedding_size * patch_size**2, kernel_size=1, bias=use_bias)
        if patch_size > 1:
            self.pixel_shuffle = nn.PixelShuffle(patch_size)
        self.layer_norm = Norm2D(
            embedding_size,
            norm_type=norm_type,
            eps=layer_norm_eps,
            use_bias=use_bias,
            elementwise_affine=ln_elementwise_affine,
        )
        self.conv2 = nn.Conv2d(embedding_size, vocab_size, kernel_size=1, bias=use_bias)

    def forward(self, hidden_states):
        """Map (batch, seq, hidden) states on a square grid to (batch, tokens, vocab_size) logits."""
        batch_size, seq_length, channels = hidden_states.shape
        side = int(seq_length**0.5)

        # (batch, seq, hidden) -> (batch, hidden, side, side)
        grid = hidden_states.view(batch_size, side, side, channels).permute(0, 3, 1, 2)
        grid = self.conv1(grid)
        if self.patch_size > 1:
            grid = self.pixel_shuffle(grid)
        logits = self.conv2(self.layer_norm(grid))

        # (batch, vocab, H, W) -> (batch, H * W, vocab)
        return logits.permute(0, 2, 3, 1).view(batch_size, -1, self.vocab_size)


class MaskGitTransformer(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        vocab_size,  # codebook_size + 1 (for the mask token), for class-conditioned generation it'll be codebook_size + num_classes + 1
        hidden_size=768,
        embedding_size=None,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        max_position_embeddings=256,  # for class-conditioned generation it'll be 256 + 1 (for the class token)
        add_cross_attention=False,
        encoder_hidden_size=1024,  # T5-large
        project_encoder_hidden_states=False,
        initializer_range=0.02,
        norm_type="layernorm",  # or rmsnorm
        layer_norm_eps=1e-5,
        use_normformer=True,
        use_encoder_layernorm=True,
        use_mlm_layer=True,
        use_mlm_layernorm=True,
        use_bias=False,
        codebook_size=1024,
        num_vq_tokens=256,
        num_classes=None,  # set for class-conditioned generation
        use_codebook_size_for_output=False,
        use_conv_in_out=False,
        patch_size=1,
        **kwargs,
    ):
        """Build the masked-token transformer.

        The model has four parts:
        - a token embedding: `Embed`, or `ConvEmbed` when `use_conv_in_out` is set;
        - an optional projection of the encoder states;
        - `num_hidden_layers` pre-LN `TransformerLayer`s;
        - an optional final norm and an output head: `MlmLayer` / `ConvMlmLayer`, or a single linear layer.

        The `register_to_config` decorator captures the arguments. `mask_token_id` is registered
        explicitly as `vocab_size - 1`. `**kwargs` are not used in this body.
        """
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.embedding_size = embedding_size or hidden_size
        # The last vocabulary id is reserved for the [MASK] token.
        self.register_to_config(mask_token_id=vocab_size - 1)

        norm_cls = partial(LayerNorm, use_bias=use_bias) if norm_type == "layernorm" else RMSNorm

        if use_conv_in_out:
            # Use the resolved width: the raw `embedding_size` argument may be None.
            self.embed = ConvEmbed(
                vocab_size,
                self.embedding_size,
                hidden_size,
                patch_size=patch_size,
                norm_type=norm_type,
                layer_norm_eps=layer_norm_eps,
                use_bias=use_bias,
            )
        else:
            # Trailing options go by keyword. Embed's positional order after max_position_embeddings
            # is (norm_type, layer_norm_eps, use_bias); passing them positionally in any other
            # order silently swaps their meaning.
            self.embed = Embed(
                self.vocab_size,
                self.hidden_size,
                self.hidden_size,
                self.hidden_dropout,
                self.max_position_embeddings,
                norm_type=norm_type,
                layer_norm_eps=layer_norm_eps,
                use_bias=use_bias,
            )

        # Project encoder states to the model width. This is only useful when cross-attention
        # layers exist to consume them.
        if add_cross_attention and project_encoder_hidden_states:
            self.encoder_proj = nn.Linear(encoder_hidden_size, hidden_size, bias=use_bias)
            self.encoder_proj_layer_norm = norm_cls(hidden_size, eps=layer_norm_eps)
            encoder_hidden_size = hidden_size

        self.transformer_layers = nn.ModuleList(
            [
                TransformerLayer(
                    hidden_size=self.hidden_size,
                    intermediate_size=self.intermediate_size,
                    num_attention_heads=self.num_attention_heads,
                    encoder_hidden_size=encoder_hidden_size,
                    add_cross_attention=add_cross_attention,
                    hidden_dropout=self.hidden_dropout,
                    attention_dropout=self.attention_dropout,
                    norm_type=norm_type,
                    layer_norm_eps=layer_norm_eps,
                    use_normformer=use_normformer,
                    use_bias=use_bias,
                )
                for _ in range(self.num_hidden_layers)
            ]
        )
        if use_encoder_layernorm:
            self.encoder_layer_norm = norm_cls(self.hidden_size, eps=layer_norm_eps)

        # Optionally predict only codebook entries (drop mask/class ids from the output space).
        self.output_size = codebook_size if use_codebook_size_for_output else self.vocab_size
        if use_mlm_layer:
            if use_conv_in_out:
                self.mlm_layer = ConvMlmLayer(
                    self.output_size,
                    self.embedding_size,
                    hidden_size,
                    patch_size=patch_size,
                    norm_type=norm_type,
                    layer_norm_eps=layer_norm_eps,
                    use_bias=use_bias,
                )
            else:
                self.mlm_layer = MlmLayer(
                    self.hidden_size, self.output_size, norm_type, layer_norm_eps, use_mlm_layernorm, use_bias
                )
        else:
            self.to_logits = nn.Linear(self.hidden_size, self.output_size, bias=use_bias)

        self.gradient_checkpointing = False

        self.apply(self._init_weights)

    def _init_weights(self, module):
        """
        Initialize the weights according to the original implementation.
        https://github.com/google-research/maskgit/blob/main/maskgit/nets/maskgit_transformer.py#L37

        Linear and Embedding weights are drawn from a truncated normal with std
        `config.initializer_range`, and Linear biases are zeroed. nn.LayerNorm / RMSNorm get
        weight 1 and bias 0 where those parameters exist. All other modules are left untouched.
        """
        # TODO: make this configurable
        if isinstance(module, (nn.Linear, nn.Embedding)):
            nn.init.trunc_normal_(module.weight, std=self.config.initializer_range)
            if isinstance(module, nn.Linear) and module.bias is not None:
                module.bias.data.zero_()
            return

        # NOTE(review): this file's own `LayerNorm` class only matches here if it subclasses
        # nn.LayerNorm. Confirm that it otherwise initializes its own parameters.
        if not isinstance(module, (nn.LayerNorm, RMSNorm)):
            return

        weight = getattr(module, "weight", None)
        if weight is not None:
            weight.data.fill_(1.0)
        bias = getattr(module, "bias", None)
        if bias is not None:
            bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        """Turn activation checkpointing of the transformer layers on or off.

        This hook uses the `(module, value)` signature of ModelMixin-style enable/disable helpers.
        Those helpers presumably call it through `self.apply(partial(..., value=True/False))`; TODO:
        confirm in modeling_utils. The flag is stored on this model and read by `forward`, so
        `module` is unused. Previously the flag was always set to True, so disabling checkpointing
        silently enabled it.
        """
        self.gradient_checkpointing = value

    def forward(
        self,
        input_ids,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        label_smoothing=0.0,
        cond_dropout_prob=0.0,
        **kwargs,
    ):
        """Compute token logits and, when `labels` are given, the cross-entropy loss.

        Args:
            input_ids: Token ids consumed by `self.embed`.
            encoder_hidden_states: Conditioning states for cross-attention. Required when the model
                was built with `add_cross_attention=True`.
            encoder_attention_mask: Forwarded unchanged to every transformer layer.
            labels: Optional target ids. Positions equal to -100 are ignored by the loss.
            label_smoothing: Label smoothing for the cross-entropy loss.
            cond_dropout_prob: In training mode, the probability of zeroing a sample's entire
                conditioning (classifier-free guidance dropout).
            **kwargs: Accepted and ignored.

        Returns:
            `logits`, or `(logits, loss)` when `labels` is given.

        Raises:
            ValueError: If cross-attention is configured but `encoder_hidden_states` is missing.
        """
        config = self.config
        if config.add_cross_attention and encoder_hidden_states is None:
            raise ValueError("If `add_cross_attention` is True, `encoder_hidden_states` should be provided.")

        hidden_states = self.embed(input_ids)

        conditioned = encoder_hidden_states is not None
        if conditioned and config.project_encoder_hidden_states:
            projected = self.encoder_proj(encoder_hidden_states)
            encoder_hidden_states = self.encoder_proj_layer_norm(projected)

        # Classifier-free guidance training: zero the conditioning of randomly chosen samples.
        if conditioned and self.training and cond_dropout_prob > 0.0:
            keep_mask = prob_mask_like(
                (encoder_hidden_states.shape[0], 1, 1), 1.0 - cond_dropout_prob, encoder_hidden_states.device
            )
            encoder_hidden_states = encoder_hidden_states * keep_mask

        for layer in self.transformer_layers:
            if not self.gradient_checkpointing:
                hidden_states = layer(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                )
                continue
            # checkpoint() passes its extra arguments positionally. The layer's positional order is
            # (hidden_states, encoder_hidden_states, encoder_attention_mask).
            hidden_states = checkpoint(layer, hidden_states, encoder_hidden_states, encoder_attention_mask)

        if config.use_encoder_layernorm:
            hidden_states = self.encoder_layer_norm(hidden_states)

        logits = self.mlm_layer(hidden_states) if config.use_mlm_layer else self.to_logits(hidden_states)

        if labels is None:
            return logits
        loss = F.cross_entropy(
            logits.view(-1, self.output_size), labels.view(-1), ignore_index=-100, label_smoothing=label_smoothing
        )
        return logits, loss

    def generate(
        self,
        input_ids: torch.LongTensor = None,
        class_ids: torch.LongTensor = None,
        encoder_hidden_states: torch.FloatTensor = None,
        temperature=1.0,
        topk_filter_thres=0.9,
        can_remask_prev_masked=False,  # TODO: implement this
        timesteps=18,  # ideal number of steps is 18 in maskgit paper
        guidance_scale=3,
        noise_schedule: Callable = cosine_schedule,
        use_tqdm=True,
    ):
        """Iteratively decode image tokens (lucidrains-style MaskGIT sampling).

        Each of the `timesteps` steps does the following:
        - re-masks the `noise_schedule(t)` fraction of positions with the highest scores, i.e. the
          least-confident predictions from the previous step;
        - predicts logits restricted to the first `codebook_size` entries;
        - fills masked positions by Gumbel sampling from top-k filtered logits at a linearly
          annealed temperature.

        Args:
            input_ids: Optional starting token ids, expected to be (batch, num_vq_tokens). When None,
                every position starts as `mask_token_id`.
            class_ids: Optional class labels, one per sample. They are shifted by `codebook_size` and
                prepended as a conditioning token. The caller's tensor is not modified.
            encoder_hidden_states: Optional conditioning. When `guidance_scale > 0`, it enables
                classifier-free guidance against an all-zero unconditional embedding.
            temperature: Starting sampling temperature, annealed linearly to 0 at the last step.
            topk_filter_thres: Threshold forwarded to `top_k`.
            can_remask_prev_masked: Currently ignored (not implemented).
            timesteps: Number of decoding steps.
            guidance_scale: Classifier-free guidance weight.
            noise_schedule: Maps t in [0, 1] to the fraction of tokens to re-mask.
            use_tqdm: Show a progress bar.

        Returns:
            The decoded token ids (class token removed).
        """
        mask_token_id = self.config.mask_token_id
        seq_len = self.config.num_vq_tokens

        batch_size = len(class_ids) if class_ids is not None else encoder_hidden_states.shape[0]
        shape = (batch_size, seq_len)

        # Shift the class ids by the codebook size. This is done out-of-place so the caller's tensor
        # is left untouched.
        if class_ids is not None:
            class_ids = class_ids + self.config.codebook_size
        # Start from all image tokens masked unless the caller supplied starting ids.
        if input_ids is None:
            input_ids = torch.ones(shape, dtype=torch.long, device=self.device) * mask_token_id
        scores = torch.zeros(shape, dtype=torch.float32, device=self.device)

        starting_temperature = temperature

        iterate_over = zip(torch.linspace(0, 1, timesteps, device=self.device), reversed(range(timesteps)))

        if use_tqdm:
            iterate_over = tqdm(iterate_over, total=timesteps)

        for timestep, steps_until_x0 in iterate_over:
            rand_mask_prob = noise_schedule(timestep)
            num_token_masked = max(int((rand_mask_prob * seq_len).item()), 1)

            # Re-mask the positions with the highest scores (lowest confidence).
            masked_indices = scores.topk(num_token_masked, dim=-1).indices
            input_ids = input_ids.scatter(1, masked_indices, mask_token_id)

            # prepend class token to input_ids
            if class_ids is not None:
                input_ids = torch.cat([class_ids[:, None], input_ids], dim=1)

            # classifier free guidance
            if encoder_hidden_states is not None and guidance_scale > 0:
                uncond_encoder_states = torch.zeros_like(encoder_hidden_states)
                model_input = torch.cat([input_ids] * 2)
                condition = torch.cat([encoder_hidden_states, uncond_encoder_states])
                cond_logits, uncond_logits = self(model_input, encoder_hidden_states=condition).chunk(2)
                cond_logits = cond_logits[..., : self.config.codebook_size]
                uncond_logits = uncond_logits[..., : self.config.codebook_size]
                logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
            else:
                logits = self(input_ids, encoder_hidden_states=encoder_hidden_states)
                logits = logits[..., : self.config.codebook_size]

            # remove class token
            if class_ids is not None:
                input_ids = input_ids[:, 1:]
                logits = logits[:, 1:]

            filtered_logits = top_k(logits, topk_filter_thres)

            temperature = starting_temperature * (steps_until_x0 / timesteps)  # temperature is annealed

            pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)

            # Only masked positions take the new prediction; the others keep their current ids.
            is_mask = input_ids == mask_token_id

            input_ids = torch.where(is_mask, pred_ids, input_ids)

            probs_without_temperature = F.softmax(logits, dim=-1)

            # Score = 1 - probability of the sampled id. Higher scores get re-masked next step.
            scores = 1 - probs_without_temperature.gather(2, pred_ids[..., None])
            scores = scores.squeeze(-1)
        return input_ids

    def generate2(
        self,
        input_ids: torch.LongTensor = None,
        class_ids: torch.LongTensor = None,
        encoder_hidden_states: torch.FloatTensor = None,
        negative_embeds: torch.FloatTensor = None,
        temperature=1.0,
        timesteps=18,  # ideal number of steps is 18 in maskgit paper
        guidance_scale=0,
        noise_schedule=cosine_schedule,
        generator: torch.Generator = None,
        **kwargs,
    ):
        """
        Generate 1:1 similar to the original MaskGit repo
        https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79

        Each step does the following:
        - samples every position from the (optionally guided) logits;
        - keeps the samples only where the input was masked;
        - re-masks `noise_schedule(ratio)` of the sequence among the positions that were unknown,
          choosing those with the lowest noised confidence.

        Args:
            input_ids: Optional starting ids. When None, every position is `mask_token_id`.
            class_ids: Optional class labels. They are shifted by `codebook_size` and prepended as a
                token. The caller's tensor is not modified.
            encoder_hidden_states: Optional conditioning for cross-attention.
            negative_embeds: Unconditional embeddings for classifier-free guidance. Zeros when None.
            temperature: Initial choice temperature for the masking noise. Step k uses
                `temperature * (1 - ratio_k)`, as in the reference implementation; it does not
                compound across steps.
            timesteps: Number of decoding steps.
            guidance_scale: Classifier-free guidance weight. Guidance is used only when it is > 0
                and `encoder_hidden_states` is given.
            noise_schedule: Maps the step ratio (a tensor) to the mask ratio.
            generator: Optional RNG for token sampling and the masking noise.
            **kwargs: Accepted and ignored.

        Returns:
            The sampled token ids after the final step (class token removed).
        """
        # begin with all image token ids masked
        mask_token_id = self.config.mask_token_id
        seq_len = self.config.num_vq_tokens

        batch_size = len(class_ids) if class_ids is not None else encoder_hidden_states.shape[0]
        shape = (batch_size, seq_len)

        # Shift the class ids by the codebook size. This is done out-of-place so the caller's tensor
        # is left untouched.
        if class_ids is not None:
            class_ids = class_ids + self.config.codebook_size

        # initialize with all image tokens masked
        if input_ids is None:
            input_ids = torch.ones(shape, dtype=torch.long, device=self.device) * mask_token_id

        # classifier free guidance
        if encoder_hidden_states is not None and guidance_scale > 0:
            if negative_embeds is None:
                uncond_encoder_states = torch.zeros_like(encoder_hidden_states)
            else:
                uncond_encoder_states = negative_embeds
            condition = torch.cat([encoder_hidden_states, uncond_encoder_states])
            model_conds = {"encoder_hidden_states": condition}

        for step in range(timesteps):
            # prepend class token to input_ids
            if class_ids is not None:
                input_ids = torch.cat([class_ids[:, None], input_ids], dim=1)

            if encoder_hidden_states is not None and guidance_scale > 0:
                model_input = torch.cat([input_ids] * 2)
                cond_logits, uncond_logits = self(model_input, **model_conds).chunk(2)
                cond_logits = cond_logits[..., : self.config.codebook_size]
                uncond_logits = uncond_logits[..., : self.config.codebook_size]
                logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
            else:
                logits = self(input_ids, encoder_hidden_states=encoder_hidden_states)
                logits = logits[..., : self.config.codebook_size]

            # remove class token
            if class_ids is not None:
                input_ids = input_ids[:, 1:]
                logits = logits[:, 1:]

            # Samples the ids using categorical sampling: [batch_size, seq_length].
            probs = logits.softmax(dim=-1)
            sampled = probs.reshape(-1, logits.size(-1))
            sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1])

            # Just updates the masked tokens.
            unknown_map = input_ids == mask_token_id
            sampled_ids = torch.where(unknown_map, sampled_ids, input_ids)
            # Defines the mask ratio for the next round. The number to mask out is
            # determined by mask_ratio * unknown_number_in_the_beginning.
            ratio = 1.0 * (step + 1) / timesteps
            mask_ratio = noise_schedule(torch.tensor(ratio))
            # Computes the probabilities of each selected tokens.
            selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
            selected_probs = selected_probs.squeeze(-1)

            # Ignores the tokens given in the input by overwriting their confidence.
            selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
            # Gets mask lens for each sample in the batch according to the mask ratio.
            mask_len = (seq_len * mask_ratio).floor().unsqueeze(0).to(logits.device)
            # Keeps at least one of prediction in this round and also masks out at least
            # one and for the next iteration
            mask_len = torch.max(
                torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
            )

            # Adds noise for randomness. As in the reference implementation, the per-step choice
            # temperature is the *initial* temperature scaled by (1 - ratio). Reassigning
            # `temperature` here would compound the decay across steps.
            choice_temperature = temperature * (1.0 - ratio)
            masking = mask_by_random_topk(mask_len, selected_probs, choice_temperature, generator=generator)
            # Masks tokens with lower confidence.
            input_ids = torch.where(masking, mask_token_id, sampled_ids)

        return sampled_ids
