# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is heavily inspired by the original implementation from https://github.com/lucidrains/muse-maskgit-pytorch

import dataclasses
import math
import numbers
import warnings
from dataclasses import dataclass
from typing import Callable, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint

from .modeling_utils import ConfigMixin, ModelMixin
from .sampling import cosine_schedule, mask_by_random_topk

# Optional dependency: xformers supplies the memory-efficient attention kernel that
# `Attention.forward` calls when it has been enabled through
# `set_use_memory_efficient_attention_xformers`.
try:
    import xformers.ops as xops

    is_xformers_available = True
except ImportError:
    is_xformers_available = False

# Optional dependency: flash_attn fused residual-add + RMSNorm kernel.
# Left as `None` when the import fails so that callers can test for availability.
try:
    from flash_attn.ops.rms_norm import dropout_add_rms_norm
except ImportError:
    dropout_add_rms_norm = None

# Optional dependency: flash_attn fused residual-add + LayerNorm kernel (`None` if missing).
try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
    dropout_add_layer_norm = None

# Optional dependency: flash_attn fused MLP kernel (`None` if missing).
try:
    from flash_attn.ops.fused_dense import fused_mlp_func
except ImportError:
    fused_mlp_func = None

# NOTE: this is a module-level side effect on the global warnings filter: matching
# `UserWarning`s are shown only the first time they occur.
warnings.simplefilter("once", UserWarning)


def sinusoidal_encode(features, embedding_dim, max_positions=10000):
    """Encode a 1-D tensor of scalars with cosine/sine features.

    Args:
        features: 1-D tensor; each entry is encoded independently.
        embedding_dim: width of the returned encoding. When odd, one zero column is
            appended so the output still has `embedding_dim` columns.
        max_positions: base of the geometric frequency progression.

    Returns:
        Tensor of shape `(len(features), embedding_dim)` laid out as
        `[cos(f * w_0..w_k), sin(f * w_0..w_k), (optional zero pad)]`.
    """
    num_freqs = embedding_dim // 2
    log_step = math.log(max_positions) / num_freqs

    # Frequencies decay geometrically from 1 towards 1 / max_positions.
    positions = torch.arange(0, num_freqs, device=features.device, dtype=torch.float32)
    freqs = torch.exp(positions.mul(-log_step))

    angles = features[:, None] * freqs[None, :]
    encoded = torch.cat([angles.cos(), angles.sin()], dim=1)

    if embedding_dim % 2 == 1:
        # Odd width: pad one trailing zero column.
        encoded = nn.functional.pad(encoded, (0, 1), mode="constant")

    return encoded


@dataclass
class MaskGiTUViT_v2Config:
    """Hyper-parameters consumed by `MaskGiTUViT_v2` and its sub-modules.

    `config_from_legacy_kwargs` builds instances of this class from arbitrary keyword
    arguments, ignoring keys that are not fields here and converting
    `block_out_channels` to a list.
    """

    # global config
    # Width of the transformer stack; also the output width of `encoder_proj` and `cond_embed`.
    hidden_size: int = 1024
    # Passed as `bias=` to the Linear / Conv layers built from this config.
    use_bias: bool = False
    # Dropout probability used in `ResBlock` and the feed-forward layers.
    hidden_dropout: float = 0.0

    # conditioning dimensions
    # `cond_embed` takes `micro_cond_embed_dim + cond_embed_dim` input features.
    cond_embed_dim: int = 768
    # `embedding_dim` passed to `sinusoidal_encode` for every micro-conditioning scalar.
    micro_cond_encode_dim: int = 256
    # Total width of the encoded micro-conditioning after reshaping to one row per sample.
    micro_cond_embed_dim: int = 1280
    # Input width of `encoder_proj` (last dim of `encoder_hidden_states`).
    encoder_hidden_size: int = 768

    # num tokens
    vocab_size: int = 8256  # codebook_size + 1 (for the mask token) rounded
    # NOTE: `MaskGiTUViT_v2.__init__` re-registers this as `vocab_size - 1`.
    mask_token_id: int = 8255
    codebook_size: int = 8192

    # `DownsampleBlock` and `UpsampleBlock`
    in_channels: int = 768
    # `MaskGiTUViT_v2.__init__` asserts that this has exactly one entry.
    block_out_channels: Tuple[int] = (768,)
    # Number of (ResBlock, AttentionBlock2D) pairs per down/up block.
    num_res_blocks: int = 3
    # When True, the down/up blocks add a stride-2 conv / transposed conv.
    force_down_up_sample: bool = False
    block_num_heads: int = 12

    # `TransformerLayer`
    num_hidden_layers: int = 22
    num_attention_heads: int = 16

    # `Attention`
    attention_dropout: float = 0.0

    # `FeedForward`
    intermediate_size: int = 2816
    # Selects `FusedGeLUFeedForward` instead of `GLUFeedForward`.
    use_fused_mlp: bool = False

    # `Norm`
    # Either "layernorm" or "rmsnorm" (see the `Norm` factory).
    norm_type: str = "rmsnorm"
    layer_norm_eps: float = 1e-6
    # When False the norm layers have no learnable weight (or bias).
    ln_elementwise_affine: bool = True
    # Use the flash_attn fused residual + norm kernels instead of the unfused fallbacks.
    use_fused_residual_norm: bool = False

    # Legacy: kept for compatibility with pipeline
    add_cond_embeds: bool = True
    add_micro_cond_embeds: bool = True


def config_from_legacy_kwargs(**kwargs):
    """Build a `MaskGiTUViT_v2Config` from loosely-typed (legacy) keyword arguments.

    Keys that are not fields of `MaskGiTUViT_v2Config` are ignored, and missing fields
    fall back to the dataclass defaults. `block_num_heads` may be given either as an int
    or as a single-element tuple/list (legacy format), which is unwrapped.

    Args:
        **kwargs: arbitrary configuration values.

    Returns:
        A `MaskGiTUViT_v2Config` whose `block_out_channels` is a list.

    Raises:
        ValueError: if `block_num_heads` is a tuple/list that does not hold exactly one item.
        TypeError: if `block_num_heads` is neither an int nor a tuple/list.
    """
    # Explicit raises instead of `assert` so the validation survives `python -O`.
    if "block_num_heads" in kwargs:
        block_num_heads = kwargs["block_num_heads"]
        if isinstance(block_num_heads, (tuple, list)):
            if len(block_num_heads) != 1:
                raise ValueError(
                    f"`block_num_heads` must contain exactly one value, got {len(block_num_heads)}"
                )
            kwargs["block_num_heads"] = block_num_heads[0]
        elif not isinstance(block_num_heads, int):
            raise TypeError(
                f"`block_num_heads` must be an int or a single-element tuple/list, "
                f"got {type(block_num_heads).__name__}"
            )

    # Keep only the values the config dataclass knows about; the rest use defaults.
    known_fields = {field.name for field in dataclasses.fields(MaskGiTUViT_v2Config)}
    selected = {name: value for name, value in kwargs.items() if name in known_fields}

    config = MaskGiTUViT_v2Config(**selected)
    config.block_out_channels = list(config.block_out_channels)

    return config


class MaskGiTUViT_v2(ModelMixin, ConfigMixin):
    """MaskGiT-style U-ViT token predictor (v2).

    Data flow in `forward`: `ConvEmbed` -> one `DownsampleBlock` -> projection to
    `hidden_size` -> stack of `TransformerLayer`s -> projection back -> one
    `UpsampleBlock` -> `ConvMlmLayer` producing the logits.
    """

    _supports_gradient_checkpointing = True

    def __init__(self, **kwargs):
        """Build the model from keyword arguments (see `config_from_legacy_kwargs`)."""
        super().__init__()

        config = config_from_legacy_kwargs(**kwargs)
        self.register_to_config(**dataclasses.asdict(config))
        # The mask token is always the last vocabulary entry, whatever was passed in.
        self.register_to_config(mask_token_id=self.config.vocab_size - 1)

        # TODO: Allow enabling fused norm using a function (like we do for xformers attention)
        # NOTE(review): only the layer-norm kernel is probed here, even when
        # `norm_type == "rmsnorm"`; `RMSNorm.forward` performs its own availability check.
        if self.config.use_fused_residual_norm and dropout_add_layer_norm is None:
            warnings.warn("Cannot use fused layer norm. Please install flash_attn. Falling back to unfused layer norm", UserWarning)
            self.register_to_config(use_fused_residual_norm=False)

        assert len(self.config.block_out_channels) == 1

        # Legacy: kept for compatibility with pipeline
        self.output_size = self.config.codebook_size

        self.encoder_proj = nn.Linear(
            self.config.encoder_hidden_size, self.config.hidden_size, bias=self.config.use_bias
        )
        self.encoder_proj_layer_norm = Norm(self.config.hidden_size, self.config)

        self.embed = ConvEmbed(self.config)

        self.cond_embed = nn.Sequential(
            nn.Linear(
                self.config.micro_cond_embed_dim + self.config.cond_embed_dim,
                self.config.hidden_size,
                bias=self.config.use_bias,
            ),
            nn.SiLU(),
            nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=self.config.use_bias),
        )

        self.down_blocks = nn.ModuleList([DownsampleBlock(self.config.block_out_channels[0], self.config)])

        self.project_to_hidden_norm = Norm(self.config.block_out_channels[-1], self.config)
        self.project_to_hidden = nn.Linear(
            self.config.block_out_channels[-1], self.config.hidden_size, bias=self.config.use_bias
        )

        self.transformer_layers = nn.ModuleList(
            [TransformerLayer(self.config) for _ in range(self.config.num_hidden_layers)]
        )

        self.project_from_hidden_norm = Norm(self.config.hidden_size, self.config)
        self.project_from_hidden = nn.Linear(
            self.config.hidden_size, self.config.block_out_channels[-1], bias=self.config.use_bias
        )

        self.up_blocks = nn.ModuleList([UpsampleBlock(self.config.block_out_channels[0], self.config)])

        self.mlm_layer = ConvMlmLayer(self.config)

        self.gradient_checkpointing = False

        # --- WEIGHT INIT ---
        self.apply(self._init_weights)  # General init
        nn.init.xavier_uniform_(self.embed.conv.weight, 0.02)  # inputs
        nn.init.normal_(self.embed.embeddings.weight, std=np.sqrt(1 / self.config.vocab_size))
        nn.init.constant_(self.mlm_layer.conv1.weight, 0)  # output
        # Start the output projection from a copy of the input embeddings of the
        # codebook entries, reshaped to a 1x1 conv kernel.
        self.mlm_layer.conv2.weight.data = self.embed.embeddings.weight.data[
            : self.config.codebook_size, :, None, None
        ].clone()

        # init AdaLNModulation.mapper layers to 0
        for m in self.modules():
            if isinstance(m, AdaLNModulation):
                nn.init.constant_(m.mapper.weight, 0)
                # NOTE(review): this inspects `m.bias`, not `m.mapper.bias`; presumably the
                # mapper bias was intended (already zeroed by `_init_weights` if the mapper
                # is an `nn.Linear`) -- confirm against `AdaLNModulation`.
                if hasattr(m, "bias") and m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _init_weights(self, module):
        """
        Initialize the weights according to the original implementation.
        https://github.com/google-research/maskgit/blob/main/maskgit/nets/maskgit_transformer.py#L37
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            nn.init.trunc_normal_(module.weight, std=0.02)
        elif isinstance(module, (LayerNorm, RMSNorm)):
            if hasattr(module, "weight") and module.weight is not None:
                module.weight.data.fill_(1.0)
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()

    def forward(
        self,
        input_ids,
        encoder_hidden_states,
        cond_embeds,
        micro_conds,
        labels=None,
        label_smoothing=0.0,
        loss_weight=None,
    ):
        """Predict token logits and, optionally, the training loss.

        Args:
            input_ids: `(batch, seq_len)` token ids; `ConvEmbed` reshapes them to a square
                grid, so `seq_len` is expected to be a perfect square.
            encoder_hidden_states: conditioning sequence whose last dimension is
                `config.encoder_hidden_size`.
            cond_embeds: pooled conditioning, concatenated with the encoded micro
                conditioning along dim 1 before `cond_embed`.
            micro_conds: scalars that are flattened, sinusoidally encoded with
                `config.micro_cond_encode_dim` features each, and reshaped to one row per
                sample.
            labels: optional target ids; entries equal to -100 are ignored by the loss.
            label_smoothing: forwarded to `F.cross_entropy`.
            loss_weight: optional per-token weights; when given, the loss is the weighted
                average of the unreduced cross entropy.

        Returns:
            `logits`, or `(logits, loss)` when `labels` is provided.
        """
        encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
        encoder_hidden_states, _ = self.encoder_proj_layer_norm(encoder_hidden_states)

        micro_cond_embeds = sinusoidal_encode(micro_conds.flatten(), self.config.micro_cond_encode_dim)
        micro_cond_embeds = micro_cond_embeds.reshape((input_ids.shape[0], -1))

        cond_embeds = torch.cat([cond_embeds, micro_cond_embeds], dim=1)
        cond_embeds = cond_embeds.to(dtype=self.dtype)
        cond_embeds = self.cond_embed(cond_embeds).to(encoder_hidden_states.dtype)

        hidden_states = self.embed(input_ids)

        hidden_states = self.down_blocks[0](
            hidden_states, cond_embeds=cond_embeds, encoder_hidden_states=encoder_hidden_states
        )

        # (B, C, H, W) -> (B, H*W, C) for the transformer stack.
        batch_size, channels, height, width = hidden_states.shape
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)

        hidden_states, _ = self.project_to_hidden_norm(hidden_states)
        hidden_states = self.project_to_hidden(hidden_states)

        # Each layer returns its output plus the running (pre-norm) residual stream.
        transformer_residual = None

        for layer in self.transformer_layers:
            if self.training and self.gradient_checkpointing:
                layer_ = lambda *args: checkpoint(layer, *args)
            else:
                layer_ = layer

            hidden_states, transformer_residual = layer_(
                hidden_states,
                encoder_hidden_states,
                cond_embeds,
                transformer_residual,
            )

        # The residual stays None only when there are no transformer layers.
        if transformer_residual is not None:
            hidden_states = hidden_states + transformer_residual

        hidden_states, _ = self.project_from_hidden_norm(hidden_states)
        hidden_states = self.project_from_hidden(hidden_states)

        # (B, H*W, C) -> (B, C, H, W) for the convolutional up block.
        hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)

        assert len(self.up_blocks) == 1
        hidden_states = self.up_blocks[0](
            hidden_states, cond_embeds=cond_embeds, encoder_hidden_states=encoder_hidden_states
        )

        batch_size, channels, height, width = hidden_states.shape
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)
        logits = self.mlm_layer(hidden_states)

        if labels is not None:
            reduction = "none" if loss_weight is not None else "mean"
            loss = F.cross_entropy(
                # `codebook_size` lives on the config; this class never sets
                # `self.codebook_size` itself.
                logits.view(-1, self.config.codebook_size),
                labels.view(-1),
                ignore_index=-100,
                label_smoothing=label_smoothing,
                reduction=reduction,
            )
            if loss_weight is not None:
                loss_weight = loss_weight.view(-1)
                loss = ((loss * loss_weight).sum(dim=-1) / loss_weight.sum(dim=-1)).mean()
            return logits, loss

        return logits

    def _set_gradient_checkpointing(self, module, value=False):
        """Toggle gradient checkpointing on the model and on down/up sample blocks."""
        self.gradient_checkpointing = value
        if isinstance(module, (DownsampleBlock, UpsampleBlock)):
            module.gradient_checkpointing = value

    # Legacy: kept for compatibility with pipeline
    def generate(self):
        """Legacy stub; always fails. Use `generate2` instead."""
        assert False

    def generate2(
        self,
        encoder_hidden_states: torch.FloatTensor,
        cond_embeds: torch.FloatTensor,
        micro_conds: torch.FloatTensor,
        empty_embeds: torch.FloatTensor,
        empty_cond_embeds: torch.FloatTensor,
        input_ids: torch.LongTensor = None,
        negative_embeds: torch.FloatTensor = None,
        negative_cond_embeds: torch.FloatTensor = None,
        temperature=1.0,
        timesteps=18,  # ideal number of steps is 18 in maskgit paper
        guidance_scale=0,
        guidance_schedule=None,
        noise_schedule=cosine_schedule,
        generator: torch.Generator = None,
        return_intermediate=False,
        seq_len=None,
        use_tqdm=None,
        # Legacy: kept for compatibility with pipeline
        topk_filter_thres=None,
        noise_type=None,
        predict_all_tokens=None,
    ):
        """Iteratively decode tokens by repeated sampling and confidence-based re-masking.

        Args:
            encoder_hidden_states: conditional encoder states; its first dimension defines
                the batch size.
            cond_embeds: conditional pooled embeddings.
            micro_conds: micro conditioning; a single row is repeated for the whole batch.
            empty_embeds: unconditional encoder states, used for guidance when
                `negative_embeds` is None.
            empty_cond_embeds: unconditional pooled embeddings, used for guidance when
                `negative_cond_embeds` is None.
            input_ids: optional starting tokens; defaults to a fully masked
                `(batch, seq_len)` tensor.
            negative_embeds: optional replacement for `empty_embeds`.
            negative_cond_embeds: optional replacement for `empty_cond_embeds`.
            temperature: either a scalar (annealed linearly to 0.01) or a `(start, end)`
                pair annealed linearly over `timesteps`.
            timesteps: number of decoding iterations.
            guidance_scale: classifier-free guidance strength; guidance is applied only
                when this is > 0.
            guidance_schedule: "linear", "cosine", or anything else for a constant scale.
            noise_schedule: maps the progress ratio in (0, 1] to the fraction of tokens to
                keep masked for the next step.
            generator: optional RNG for sampling.
            return_intermediate: also return the raw samples of every step.
            seq_len: sequence length; defaults to 256.
            use_tqdm: show a progress bar when truthy.
            topk_filter_thres, noise_type, predict_all_tokens: unused legacy arguments.

        Returns:
            The token ids sampled in the last step (given tokens are preserved), or
            `(sampled_ids, intermediate)` when `return_intermediate` is True.
        """
        batch_size = encoder_hidden_states.shape[0]

        if seq_len is None:
            seq_len = 256

        shape = (batch_size, seq_len)

        if isinstance(temperature, (tuple, list)):
            temperatures = torch.linspace(temperature[0], temperature[1], timesteps)
        else:
            temperatures = torch.linspace(temperature, 0.01, timesteps)

        if input_ids is None:
            input_ids = torch.ones(shape, dtype=torch.long, device=self.device) * self.config.mask_token_id

        if return_intermediate:
            intermediate = []

        if guidance_schedule == "linear":
            guidance_scales = torch.linspace(0, guidance_scale, timesteps)
        elif guidance_schedule == "cosine":
            guidance_scales = []
            for step in range(timesteps):
                ratio = 1.0 * (step + 1) / timesteps
                scale = cosine_schedule(torch.tensor(1 - ratio)) * guidance_scale
                guidance_scales.append(scale.floor())
            guidance_scales = torch.tensor(guidance_scales)
        else:
            guidance_scales = torch.ones(timesteps) * guidance_scale

        if micro_conds.shape[0] == 1:
            micro_conds = micro_conds.repeat(batch_size, 1).to(input_ids.device)

        use_guidance = guidance_scale > 0

        if use_guidance:
            # Stack conditional and unconditional inputs so one forward pass serves both.

            # encoder_hidden_states

            if negative_embeds is None:
                uncond_encoder_states = empty_embeds
            else:
                uncond_encoder_states = negative_embeds

            if uncond_encoder_states.shape[0] == 1:
                uncond_encoder_states = uncond_encoder_states.expand(batch_size, -1, -1)

            encoder_hidden_states = torch.cat([encoder_hidden_states, uncond_encoder_states])

            # cond_embeds

            if negative_cond_embeds is None:
                uncond_embeds = empty_cond_embeds
            else:
                uncond_embeds = negative_cond_embeds

            if uncond_embeds.shape[0] == 1:
                uncond_embeds = uncond_embeds.expand(batch_size, -1)

            cond_embeds = torch.cat([cond_embeds, uncond_embeds])

            # micro_conds

            micro_conds = torch.cat([micro_conds, micro_conds], dim=0)

        if use_tqdm:
            from tqdm.auto import tqdm

            timesteps_iter = tqdm(range(timesteps))
        else:
            timesteps_iter = range(timesteps)

        for step in timesteps_iter:
            # Without guidance the model is run on the plain batch. Previously
            # `model_input` was only bound in the guided case, so the default
            # `guidance_scale=0` crashed with an unbound local variable.
            if use_guidance:
                model_input = torch.cat([input_ids] * 2)
            else:
                model_input = input_ids

            model_output = self(
                model_input,
                micro_conds=micro_conds,
                cond_embeds=cond_embeds,
                encoder_hidden_states=encoder_hidden_states,
            )

            if use_guidance:
                cond_logits, uncond_logits = model_output.chunk(2)
                cond_logits = cond_logits[..., : self.config.codebook_size]
                uncond_logits = uncond_logits[..., : self.config.codebook_size]
                logits = uncond_logits + guidance_scales[step] * (cond_logits - uncond_logits)
            else:
                logits = model_output

            logits = logits[..., : self.config.codebook_size]

            probs = logits.softmax(dim=-1)

            sampled = probs.reshape(-1, logits.size(-1))
            sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1])

            if return_intermediate:
                intermediate.append(sampled_ids)

            # Only positions that are still masked take the freshly sampled token.
            unknown_map = input_ids == self.config.mask_token_id
            sampled_ids = torch.where(unknown_map, sampled_ids, input_ids)
            # Defines the mask ratio for the next round. The number to mask out is
            # determined by mask_ratio * unknown_number_in_the_beginning.
            ratio = 1.0 * (step + 1) / timesteps
            mask_ratio = noise_schedule(torch.tensor(ratio))

            # Gets mask lens for each sample in the batch according to the mask ratio.
            mask_len = (seq_len * mask_ratio).floor().unsqueeze(0).to(logits.device)
            # Keeps at least one of prediction in this round and also masks out at least
            # one and for the next iteration
            mask_len = torch.max(
                torch.tensor([1], device=logits.device),
                torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len),
            )

            selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
            selected_probs = selected_probs.squeeze(-1)
            # Ignores the tokens given in the input by overwriting their confidence.
            selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
            step_temperature = temperatures[step]
            masking = mask_by_random_topk(mask_len, selected_probs, step_temperature, generator=generator)
            # Masks tokens with lower confidence.
            input_ids = torch.where(masking, self.config.mask_token_id, sampled_ids)

        if return_intermediate:
            return sampled_ids, intermediate

        return sampled_ids


# embedding blocks


class ConvEmbed(nn.Module):
    """Embed token ids, normalize them, and project to a 2D feature map with a 1x1 conv."""

    def __init__(self, config: MaskGiTUViT_v2Config):
        super().__init__()
        in_channels = config.in_channels
        out_channels = config.block_out_channels[0]

        self.embeddings = nn.Embedding(config.vocab_size, in_channels)
        self.layer_norm = Norm(in_channels, config)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=config.use_bias)

    def forward(self, input_ids):
        """Map `(batch, seq_length)` ids to a `(batch, out_channels, side, side)` feature map.

        `side` is `int(seq_length ** 0.5)`, i.e. the tokens are laid out on a square grid.
        """
        _, seq_length = input_ids.shape
        side = int(seq_length**0.5)

        token_grid = input_ids.view(-1, side, side)
        normed, _ = self.layer_norm(self.embeddings(token_grid))

        # channels-last -> channels-first for the convolution
        return self.conv(normed.permute(0, 3, 1, 2))


# down/upsample blocks


class DownsampleBlock(nn.Module):
    """Optional stride-2 downsampling followed by (ResBlock, AttentionBlock2D) pairs."""

    def __init__(self, channels, config: MaskGiTUViT_v2Config):
        super().__init__()

        # The strided conv is only created when explicitly requested by the config.
        self.downsample = (
            nn.Sequential(
                Norm2D(channels, config),
                nn.Conv2d(channels, channels, kernel_size=2, stride=2, bias=config.use_bias),
            )
            if config.force_down_up_sample
            else None
        )

        num_blocks = config.num_res_blocks
        self.res_blocks = nn.ModuleList([ResBlock(channels, config) for _ in range(num_blocks)])
        self.attention_blocks = nn.ModuleList([AttentionBlock2D(channels, config) for _ in range(num_blocks)])

        self.gradient_checkpointing = False

    def forward(self, x, cond_embeds, encoder_hidden_states):
        """Run the block; sub-blocks are checkpointed while training if enabled."""
        if self.downsample is not None:
            x = self.downsample(x)

        use_checkpoint = self.training and self.gradient_checkpointing

        for res_block, attention_block in zip(self.res_blocks, self.attention_blocks):
            if use_checkpoint:
                x = checkpoint(res_block, x, cond_embeds)
                x = checkpoint(attention_block, x, encoder_hidden_states)
            else:
                x = res_block(x, cond_embeds)
                x = attention_block(x, encoder_hidden_states)

        return x


class UpsampleBlock(nn.Module):
    """(ResBlock, AttentionBlock2D) pairs followed by optional stride-2 upsampling."""

    def __init__(
        self,
        channels: int,
        config: MaskGiTUViT_v2Config,
    ):
        super().__init__()

        num_blocks = config.num_res_blocks
        self.res_blocks = nn.ModuleList([ResBlock(channels, config) for _ in range(num_blocks)])
        self.attention_blocks = nn.ModuleList([AttentionBlock2D(channels, config) for _ in range(num_blocks)])

        # The transposed conv is only created when explicitly requested by the config.
        self.upsample = (
            nn.Sequential(
                Norm2D(channels, config),
                nn.ConvTranspose2d(channels, channels, kernel_size=2, stride=2, bias=config.use_bias),
            )
            if config.force_down_up_sample
            else None
        )

        self.gradient_checkpointing = False

    def forward(self, x, cond_embeds, encoder_hidden_states):
        """Run the block; sub-blocks are checkpointed while training if enabled."""
        use_checkpoint = self.training and self.gradient_checkpointing

        for res_block, attention_block in zip(self.res_blocks, self.attention_blocks):
            if use_checkpoint:
                x = checkpoint(res_block, x, cond_embeds)
                x = checkpoint(attention_block, x, encoder_hidden_states)
            else:
                x = res_block(x, cond_embeds)
                x = attention_block(x, encoder_hidden_states)

        if self.upsample is None:
            return x

        return self.upsample(x)


class ResBlock(nn.Module):
    """Depthwise conv + channel-wise MLP with a skip connection and AdaLN modulation."""

    def __init__(
        self,
        channels,
        config: MaskGiTUViT_v2Config,
        res_ffn_factor=4,
    ):
        super().__init__()
        use_bias = config.use_bias
        ffn_channels = int(channels * res_ffn_factor)

        # groups == channels makes this a depthwise 3x3 convolution.
        self.depthwise = nn.Conv2d(
            channels, channels, kernel_size=3, padding=1, groups=channels, bias=use_bias
        )
        self.norm = Norm2D(channels, config)
        self.channelwise = nn.Sequential(
            nn.Linear(channels, ffn_channels, bias=use_bias),
            nn.GELU(),
            GlobalResponseNorm(ffn_channels),
            nn.Dropout(config.hidden_dropout),
            nn.Linear(ffn_channels, channels, bias=use_bias),
        )
        self.adaLN_modulation = AdaLNModulation(channels, config)

    def forward(self, x, cond_embeds):
        """Apply the block to a channels-first feature map `x`, modulated by `cond_embeds`."""
        shortcut = x

        features = self.norm(self.depthwise(x))
        # The channel-wise MLP operates on the last dimension, so go channels-last.
        features = features.permute(0, 2, 3, 1)
        features = self.channelwise(features)
        features = features.permute(0, 3, 1, 2)

        features = features + shortcut
        return self.adaLN_modulation(features, cond_embeds)


# norm blocks


class Norm2D(nn.Module):
    """Apply the configured norm over the channel axis of a channels-first 4D tensor."""

    def __init__(self, dim, config: MaskGiTUViT_v2Config):
        super().__init__()
        self.norm = Norm(dim, config)

    def forward(self, x):
        # The wrapped norm works on the last dimension: move channels last and back.
        channels_last = x.permute(0, 2, 3, 1)
        normed, _ = self.norm(channels_last)
        return normed.permute(0, 3, 1, 2)


def Norm(dim, config: MaskGiTUViT_v2Config):
    """Create the normalization layer selected by `config.norm_type`.

    Args:
        dim: normalized shape (int or sequence of ints) passed to the norm layer.
        config: model configuration; `norm_type` must be "layernorm" or "rmsnorm".

    Returns:
        A `LayerNorm` or `RMSNorm` module.

    Raises:
        ValueError: if `config.norm_type` is not a supported value.
    """
    if config.norm_type == "layernorm":
        return LayerNorm(dim, config)
    if config.norm_type == "rmsnorm":
        return RMSNorm(dim, config)

    # An explicit raise (rather than `assert False`) still fires under `python -O`,
    # instead of silently returning None.
    raise ValueError(
        f"Unsupported norm_type {config.norm_type!r}; expected 'layernorm' or 'rmsnorm'"
    )


class RMSNorm(nn.Module):
    """RMS normalization with an optional fused residual-add path (flash_attn)."""

    def __init__(self, dim, config: MaskGiTUViT_v2Config):
        super().__init__()

        self.config = config

        shape = (dim,) if isinstance(dim, numbers.Integral) else dim
        self.dim = torch.Size(shape)

        # Learnable scale only when element-wise affine is enabled.
        self.weight = nn.Parameter(torch.ones(shape)) if self.config.ln_elementwise_affine else None

    def forward(self, input, residual=None):
        """Return `(normalized, prenorm_residual)` for `input` (+ optional `residual`)."""
        eps = self.config.layer_norm_eps

        if not self.config.use_fused_residual_norm:
            return unfused_rms_norm(input, residual, self.weight, eps)

        if dropout_add_rms_norm is None:
            raise ImportError("Please install flash_attn to use fused rms norm")

        return dropout_add_rms_norm(input, residual, self.weight, None, dropout_p=0.0, epsilon=eps, prenorm=True)


def unfused_rms_norm(input, residual, weight, eps):
    """Add an optional residual, then RMS-normalize over the last dimension.

    Args:
        input: tensor to normalize.
        residual: optional tensor added to `input` before normalization.
        weight: optional scale; when it is fp16/bf16 the normalized tensor is cast to
            that dtype before scaling.
        eps: value added to the mean square for numerical stability.

    Returns:
        `(normalized, prenorm_residual)` where `prenorm_residual` is `input` (+ `residual`)
        before normalization.
    """
    summed = input if residual is None else input + residual
    prenorm_residual = summed

    original_dtype = summed.dtype
    # The statistics are always computed in float32.
    mean_square = summed.to(torch.float32).pow(2).mean(-1, keepdim=True)
    normed = summed * torch.rsqrt(mean_square + eps)

    if weight is None:
        # No scale: restore the incoming dtype.
        return normed.to(original_dtype), prenorm_residual

    # convert into half-precision if necessary
    if weight.dtype in [torch.float16, torch.bfloat16]:
        normed = normed.to(weight.dtype)

    return normed * weight, prenorm_residual


class LayerNorm(nn.Module):
    """Layer normalization with an optional fused residual-add path (flash_attn)."""

    def __init__(self, dim, config: MaskGiTUViT_v2Config):
        super().__init__()

        self.config = config

        shape = (dim,) if isinstance(dim, numbers.Integral) else dim
        self.dim = torch.Size(shape)

        if not self.config.ln_elementwise_affine:
            self.weight = None
            self.bias = None
        else:
            self.weight = nn.Parameter(torch.ones(shape))
            # The bias exists only when the config enables biases globally.
            self.bias = nn.Parameter(torch.zeros(shape)) if self.config.use_bias else None

    def forward(self, input, residual=None):
        """Return `(normalized, prenorm_residual)` for `input` (+ optional `residual`)."""
        eps = self.config.layer_norm_eps

        if not self.config.use_fused_residual_norm:
            return unfused_layer_norm(input, residual, self.dim, self.weight, self.bias, eps)

        if dropout_add_layer_norm is None:
            raise ImportError("Please install flash_attn to use fused layer norm")

        return dropout_add_layer_norm(
            x0=input,
            residual=residual,
            weight=self.weight,
            bias=self.bias,
            epsilon=eps,
            dropout_p=0.0,
            prenorm=True,
        )


def unfused_layer_norm(input, residual, dim, weight, bias, eps):
    """Add an optional residual, then apply `F.layer_norm`.

    Returns:
        `(normalized, prenorm_residual)` where `prenorm_residual` is `input` (+ `residual`)
        before normalization.
    """
    summed = input if residual is None else input + residual
    normalized = F.layer_norm(summed, dim, weight, bias, eps)
    return normalized, summed


class GlobalResponseNorm(nn.Module):
    """Global response normalization for channels-last 4D tensors."""

    # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
    def __init__(self, dim):
        super().__init__()
        param_shape = (1, 1, 1, dim)
        # Both parameters start at zero, so the layer is initially the identity.
        self.gamma = nn.Parameter(torch.zeros(param_shape))
        self.beta = nn.Parameter(torch.zeros(param_shape))

    def forward(self, x):
        # L2 norm over dims 1 and 2, one value per sample and channel.
        spatial_norm = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        # Divide by the mean norm across channels.
        channel_mean = spatial_norm.mean(dim=-1, keepdim=True)
        relative_norm = spatial_norm / (channel_mean + 1e-6)
        return self.gamma * (x * relative_norm) + self.beta + x


# attention/transformer blocks


class TransformerLayer(nn.Module):
    """Self-attention, cross-attention and feed-forward with a running pre-norm residual."""

    def __init__(self, config: MaskGiTUViT_v2Config):
        super().__init__()
        width = config.hidden_size
        heads = config.num_attention_heads

        self.attn_layer_norm = Norm(width, config)

        self.self_attn_adaLN_modulation = AdaLNModulation(width, config)
        self.attention = Attention(width, width, heads, config)

        self.crossattn_layer_norm = Norm(width, config)
        self.crossattention = Attention(width, width, heads, config)
        self.cross_attn_adaLN_modulation = AdaLNModulation(width, config)

        self.ffn = FeedForward(config)

    def forward(self, hidden_states, encoder_hidden_states, cond_embeds, residual=None):
        """Return `(hidden_states, residual)`; the residual stream is threaded through the norms."""
        # Self-attention: the normalized, modulated states serve as query and context.
        normed, residual = self.attn_layer_norm(hidden_states, residual=residual)
        normed = self.self_attn_adaLN_modulation(normed, cond_embeds)
        self_attn_out = self.attention(normed, normed)

        # Cross-attention over the encoder states.
        normed, residual = self.crossattn_layer_norm(self_attn_out, residual=residual)
        normed = self.cross_attn_adaLN_modulation(normed, cond_embeds)
        cross_attn_out = self.crossattention(normed, encoder_hidden_states)

        # The feed-forward module applies its own pre-norm and updates the residual.
        output, residual = self.ffn(cross_attn_out, cond_embeds=cond_embeds, residual=residual)

        return output, residual


class AttentionBlock2D(nn.Module):
    """Two stacked cross-attention layers applied to a channels-first feature map."""

    def __init__(self, hidden_size: int, config: MaskGiTUViT_v2Config):
        super().__init__()

        # Encoder states arrive with `config.hidden_size` features; remap them only when
        # this block runs at a different width.
        self.kv_mapper = (
            nn.Linear(config.hidden_size, hidden_size, bias=config.use_bias)
            if config.hidden_size != hidden_size
            else None
        )

        context_dim = hidden_size
        num_heads = config.block_num_heads

        # NOTE: this is actually a cross attention layer, but keeping the naming from v1 to
        # keep the state dicts compatible
        self.attn_layer_norm = Norm(hidden_size, config)
        self.attention = Attention(hidden_size, context_dim, num_heads, config)

        self.crossattn_layer_norm = Norm(hidden_size, config)
        self.crossattention = Attention(hidden_size, context_dim, num_heads, config)

    def forward(self, hidden_states, encoder_hidden_states):
        batch_size, channels, height, width = hidden_states.shape
        # (B, C, H, W) -> (B, H*W, C)
        tokens = hidden_states.view(batch_size, channels, height * width).permute(0, 2, 1)

        context = encoder_hidden_states
        if self.kv_mapper is not None:
            context = self.kv_mapper(F.silu(context))

        # NOTE: This is actually a cross attention layer
        tokens, residual = self.attn_layer_norm(tokens)
        tokens = self.attention(tokens, context)

        tokens, residual = self.crossattn_layer_norm(tokens, residual)
        tokens = self.crossattention(tokens, context)
        tokens = tokens + residual

        # (B, H*W, C) -> (B, C, H, W)
        return tokens.permute(0, 2, 1).view(batch_size, channels, height, width)


class Attention(nn.Module):
    """Multi-head attention of `hidden_states` (queries) over `context` (keys/values)."""

    def __init__(self, hidden_size: int, context_dim: int, num_heads: int, config: MaskGiTUViT_v2Config):
        super().__init__()

        self.config = config
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = self.hidden_size // num_heads

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"self.hidden_size: {self.hidden_size} must be divisible by self.num_heads: {self.num_heads}"
            )

        # sqrt(head_dim); the attention scores are divided by this value.
        head_dim_tensor = torch.tensor(self.head_dim, dtype=torch.float32)
        self.scale_attn = torch.sqrt(head_dim_tensor).to(torch.get_default_dtype())

        use_bias = self.config.use_bias
        self.query = nn.Linear(self.hidden_size, self.hidden_size, bias=use_bias)

        self.key = nn.Linear(context_dim, self.hidden_size, bias=use_bias)
        self.value = nn.Linear(context_dim, self.hidden_size, bias=use_bias)

        self.out = nn.Linear(self.hidden_size, self.hidden_size, bias=use_bias)
        self.dropout = nn.Dropout(self.config.attention_dropout)

        self.use_memory_efficient_attention_xformers = False
        self.xformers_attention_op = None

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
        """Enable or disable the xformers kernel; raises ImportError if xformers is missing."""
        wants_xformers = use_memory_efficient_attention_xformers
        if wants_xformers and not is_xformers_available:
            raise ImportError("Please install xformers to use memory efficient attention")
        self.use_memory_efficient_attention_xformers = wants_xformers
        self.xformers_attention_op = attention_op

    def forward(self, hidden_states, context):
        batch, q_seq_len, _ = hidden_states.shape
        kv_seq_len = context.shape[1]
        per_head = (self.num_heads, self.head_dim)

        # Project, then split the feature dimension into heads: (B, T, nh, hs)
        query = self.query(hidden_states).view(batch, q_seq_len, *per_head)
        key = self.key(context).view(batch, kv_seq_len, *per_head)
        value = self.value(context).view(batch, kv_seq_len, *per_head)

        if not self.use_memory_efficient_attention_xformers:
            attn_output = self.attention(query, key, value)
        else:
            # Dropout is passed to the xformers kernel and only applied while training.
            dropout_p = self.config.attention_dropout if self.training else 0.0
            attn_output = xops.memory_efficient_attention(
                query,
                key,
                value,
                op=self.xformers_attention_op,
                p=dropout_p,
            )
            attn_output = attn_output.view(batch, q_seq_len, self.hidden_size)

        return self.out(attn_output)

    def attention(self, query, key, value, attention_mask=None):
        """Plain PyTorch scaled dot-product attention on `(B, T, nh, hs)` inputs."""
        batch, seq_len = query.shape[:2]
        kv_seq_len = key.shape[1]
        flat_batch = batch * self.num_heads

        # (B, T, nh, hs) -> (B, nh, T, hs)
        query = query.transpose(1, 2).contiguous()
        key = key.transpose(1, 2).contiguous()
        value = value.transpose(1, 2).contiguous()

        # Scaled QK^T, computed as a batched matmul over the flattened (B * nh) dimension.
        scores = torch.baddbmm(
            input=torch.zeros(flat_batch, seq_len, kv_seq_len, dtype=query.dtype, device=query.device),
            batch1=query.view(flat_batch, seq_len, self.head_dim),
            batch2=key.view(flat_batch, kv_seq_len, self.head_dim).transpose(1, 2),
            alpha=1 / self.scale_attn,
        )
        scores = scores.view(batch, self.num_heads, seq_len, kv_seq_len)

        # Apply the attention mask
        if attention_mask is not None:
            scores = torch.masked_fill(scores, attention_mask, torch.finfo(query.dtype).min)

        probs = self.dropout(F.softmax(scores, dim=-1))

        # (B, nh, T, T_kv) x (B, nh, T_kv, hs) -> (B, nh, T, hs)
        context_per_head = torch.matmul(probs, value)

        # re-assemble all head outputs side by side
        return context_per_head.transpose(1, 2).contiguous().view(batch, seq_len, self.hidden_size)


def FeedForward(config: MaskGiTUViT_v2Config):
    """Build the feed-forward block selected by `config.use_fused_mlp`.

    Returns a `FusedGeLUFeedForward` when the flag is truthy and a `GLUFeedForward`
    otherwise; `config` is forwarded unchanged to the chosen constructor.
    """
    feed_forward_cls = FusedGeLUFeedForward if config.use_fused_mlp else GLUFeedForward
    return feed_forward_cls(config)


class GLUFeedForward(nn.Module):
    """Pre-normed, adaLN-modulated feed-forward block with a GELU-gated linear unit."""

    def __init__(self, config: MaskGiTUViT_v2Config):
        super().__init__()
        hidden_size = config.hidden_size
        inner_size = config.intermediate_size
        use_bias = config.use_bias

        self.pre_mlp_layer_norm = LayerNorm(hidden_size, config)
        self.adaLN_modulation = AdaLNModulation(hidden_size, config)
        # Two parallel input projections: wi_0 feeds the GELU gate, wi_1 stays linear
        self.wi_0 = nn.Linear(hidden_size, inner_size, bias=use_bias)
        self.wi_1 = nn.Linear(hidden_size, inner_size, bias=use_bias)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.wo = nn.Linear(inner_size, hidden_size, bias=use_bias)

    def forward(self, hidden_states, cond_embeds, residual=None):
        """Run the block.

        Args:
            hidden_states: input features passed to the pre-MLP norm.
            cond_embeds: conditioning embedding consumed by the adaLN modulation.
            residual: optional residual stream forwarded to the pre-MLP norm.

        Returns:
            Tuple `(output, residual)` where `residual` is whatever the pre-MLP norm
            returned as its second value.
        """
        normed, residual = self.pre_mlp_layer_norm(hidden_states, residual=residual)
        modulated = self.adaLN_modulation(normed, cond_embeds)

        # GLU: the GELU-activated branch gates the linear branch element-wise
        gated = F.gelu(self.wi_0(modulated)) * self.wi_1(modulated)

        return self.wo(self.dropout(gated)), residual


class FusedGeLUFeedForward(nn.Module):
    """Feed-forward block that runs its two linear layers through `fused_mlp_func`."""

    def __init__(self, config: MaskGiTUViT_v2Config):
        super().__init__()
        hidden_size = config.hidden_size
        inner_size = config.intermediate_size
        use_bias = config.use_bias

        self.pre_mlp_layer_norm = LayerNorm(hidden_size, config)
        self.adaLN_modulation = AdaLNModulation(hidden_size, config)
        self.wi_0 = nn.Linear(hidden_size, inner_size, bias=use_bias)
        # NOTE(review): this dropout module is created but not applied in forward()
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.wo = nn.Linear(inner_size, hidden_size, bias=use_bias)

    def forward(self, hidden_states, cond_embeds, residual=None):
        """Run the block through the fused MLP kernel.

        Args:
            hidden_states: input features passed to the pre-MLP norm.
            cond_embeds: conditioning embedding consumed by the adaLN modulation.
            residual: optional residual stream forwarded to the pre-MLP norm.

        Returns:
            Tuple `(output, residual)` where `residual` is whatever the pre-MLP norm
            returned as its second value.

        Raises:
            ImportError: if `fused_mlp_func` could not be imported from flash_attn.
        """
        if fused_mlp_func is None:
            raise ImportError("Please install flash_attn to use fused mlp")

        normed, residual = self.pre_mlp_layer_norm(hidden_states, residual=residual)
        modulated = self.adaLN_modulation(normed, cond_embeds)

        # Under autocast the effective dtype is the autocast one, not the input's
        if torch.is_autocast_enabled():
            compute_dtype = torch.get_autocast_gpu_dtype()
        else:
            compute_dtype = modulated.dtype
        cuda_version = tuple(int(part) for part in torch.version.cuda.split("."))

        # Choose the `heuristic` argument forwarded to fused_mlp_func from the device
        # capability, the CUDA toolkit version and the compute dtype.
        if torch.cuda.get_device_capability("cuda") == (9, 0):
            heuristic = -1
        elif cuda_version >= (11, 8):
            heuristic = 0
        else:
            heuristic = 1 if compute_dtype == torch.float16 else -1

        output = fused_mlp_func(
            modulated,
            self.wi_0.weight,
            self.wo.weight,
            self.wi_0.bias,
            self.wo.bias,
            activation="gelu_approx",
            save_pre_act=self.training,
            return_residual=False,
            checkpoint_lvl=0,
            heuristic=heuristic,
        )

        return output, residual


# misc blocks


class ConvMlmLayer(nn.Module):
    """Maps token features to codebook logits with two 1x1 convolutions and a norm between them."""

    def __init__(self, config: MaskGiTUViT_v2Config):
        super().__init__()
        self.config = config
        in_channels, use_bias = config.in_channels, config.use_bias

        self.conv1 = nn.Conv2d(config.block_out_channels[0], in_channels, kernel_size=1, bias=use_bias)
        self.layer_norm = Norm2D(in_channels, config)
        self.conv2 = nn.Conv2d(in_channels, config.codebook_size, kernel_size=1, bias=use_bias)

    def forward(self, hidden_states):
        """Return logits of shape (batch, seq_length, codebook_size).

        Args:
            hidden_states: tensor of shape (batch, seq_length, hidden_size). The sequence is
                viewed as a square grid with side `int(seq_length**0.5)`, so `seq_length`
                has to be a perfect square for the reshape to succeed.
        """
        batch_size, seq_length, hidden_size = hidden_states.shape
        side = int(seq_length**0.5)

        # (B, T, C) -> (B, side, side, C) -> channels-first (B, C, side, side) for Conv2d
        feature_map = hidden_states.view(batch_size, side, side, hidden_size)
        feature_map = feature_map.permute(0, 3, 1, 2)

        feature_map = self.layer_norm(self.conv1(feature_map))

        # Back to channels-last, then flatten the grid into a token sequence again
        logits = self.conv2(feature_map).permute(0, 2, 3, 1)
        return logits.view(batch_size, -1, self.config.codebook_size)


class AdaLNModulation(nn.Module):
    """Scale-and-shift modulation of features driven by a conditioning embedding."""

    def __init__(self, hidden_size: int, config: MaskGiTUViT_v2Config):
        super().__init__()
        # One projection yields both halves: `hidden_size` scales and `hidden_size` shifts
        self.mapper = nn.Linear(config.hidden_size, hidden_size * 2, bias=config.use_bias)

    def forward(self, hidden_states, cond_embeds):
        """Return `hidden_states * (1 + scale) + shift`.

        Args:
            hidden_states: features to modulate. Inputs with more than 3 dims get the
                modulation broadcast over two trailing axes (presumably channels-first
                feature maps -- TODO confirm); otherwise an axis is inserted at position 1.
            cond_embeds: conditioning embedding; SiLU is applied before the projection and
                the result is split in two along dim 1.
        """
        projected = self.mapper(F.silu(cond_embeds))
        scale, shift = torch.chunk(projected, 2, dim=1)

        if hidden_states.dim() > 3:
            # (B, C) -> (B, C, 1, 1)
            scale = scale.unsqueeze(2).unsqueeze(3)
            shift = shift.unsqueeze(2).unsqueeze(3)
        else:
            # (B, C) -> (B, 1, C)
            scale = scale.unsqueeze(1)
            shift = shift.unsqueeze(1)

        return hidden_states * (1 + scale) + shift
