muse/modeling_transformer_v2.py

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is heavily inspired by the original implementation from
# https://github.com/lucidrains/muse-maskgit-pytorch

import dataclasses
import math
import numbers
import warnings
from dataclasses import dataclass
from typing import Callable, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint

from .modeling_utils import ConfigMixin, ModelMixin
from .sampling import cosine_schedule, mask_by_random_topk

try:
    import xformers.ops as xops

    is_xformers_available = True
except ImportError:
    is_xformers_available = False

try:
    from flash_attn.ops.rms_norm import dropout_add_rms_norm
except ImportError:
    dropout_add_rms_norm = None

try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
except ImportError:
    dropout_add_layer_norm = None

try:
    from flash_attn.ops.fused_dense import fused_mlp_func
except ImportError:
    fused_mlp_func = None

warnings.simplefilter("once", UserWarning)


def sinusoidal_encode(features, embedding_dim, max_positions=10000):
    half_dim = embedding_dim // 2
    emb = math.log(max_positions) / half_dim
    emb = (
        torch.arange(0, half_dim, device=features.device, dtype=torch.float32)
        .mul(-emb)
        .exp()
    )
    emb = features[:, None] * emb[None, :]
    emb = torch.cat([emb.cos(), emb.sin()], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = nn.functional.pad(emb, (0, 1), mode="constant")

    return emb

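# Example (illustrative sketch, not part of the original module): shape check for
# `sinusoidal_encode`. For a 1-D tensor of N conditioning values and an embedding
# dimension D, the result is an (N, D) tensor of concatenated cos/sin features.
#
#   feats = torch.tensor([512.0, 512.0, 0.0, 0.0, 6.0])
#   emb = sinusoidal_encode(feats, 256)
#   assert emb.shape == (5, 256)
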
@dataclass
class MaskGiTUViT_v2Config:
    # global config
    hidden_size: int = 1024
    use_bias: bool = False
    hidden_dropout: float = 0.0

    # conditioning dimensions
    cond_embed_dim: int = 768
    micro_cond_encode_dim: int = 256
    micro_cond_embed_dim: int = 1280
    encoder_hidden_size: int = 768

    # num tokens
    vocab_size: int = 8256  # codebook_size + 1 (for the mask token) rounded
    mask_token_id: int = 8255
    codebook_size: int = 8192

    # `DownsampleBlock` and `UpsampleBlock`
    in_channels: int = 768
    block_out_channels: Tuple[int] = (768,)
    num_res_blocks: int = 3
    force_down_up_sample: bool = False
    block_num_heads: int = 12

    # `TransformerLayer`
    num_hidden_layers: int = 22
    num_attention_heads: int = 16

    # `Attention`
    attention_dropout: float = 0.0

    # `FeedForward`
    intermediate_size: int = 2816
    use_fused_mlp: bool = False

    # `Norm`
    norm_type: str = "rmsnorm"
    layer_norm_eps: float = 1e-6
    ln_elementwise_affine: bool = True
    use_fused_residual_norm: bool = False

    # Legacy: kept for compatibility with pipeline
    add_cond_embeds: bool = True
    add_micro_cond_embeds: bool = True


def config_from_legacy_kwargs(**kwargs):
    if "block_num_heads" in kwargs:
        if isinstance(kwargs["block_num_heads"], (tuple, list)):
            assert len(kwargs["block_num_heads"]) == 1
            kwargs["block_num_heads"] = kwargs["block_num_heads"][0]
        elif isinstance(kwargs["block_num_heads"], int):
            ...
        else:
            assert False

    config = {}

    # select only values that are expected to be in the config
    for field in dataclasses.fields(MaskGiTUViT_v2Config):
        if field.name in kwargs:
            config[field.name] = kwargs[field.name]

    # set default config values
    config = MaskGiTUViT_v2Config(**config)
    config.block_out_channels = list(config.block_out_channels)

    return config

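# Example (illustrative sketch, not part of the original module): building a config
# from legacy keyword arguments. Keys that are not dataclass fields are dropped and a
# single-element `block_num_heads` tuple is unpacked to an int.
#
#   config = config_from_legacy_kwargs(block_num_heads=(12,), hidden_size=1024, unused_key=1)
#   assert config.block_num_heads == 12
#   assert config.block_out_channels == [768]
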
class MaskGiTUViT_v2(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    def __init__(self, **kwargs):
        super().__init__()

        config = config_from_legacy_kwargs(**kwargs)
        self.register_to_config(**dataclasses.asdict(config))
        self.register_to_config(mask_token_id=self.config.vocab_size - 1)

        # TODO: Allow enabling fused norm using a function (like we do for xformers attention)
        if self.config.use_fused_residual_norm and dropout_add_layer_norm is None:
            warnings.warn(
                "Cannot use fused layer norm. Please install flash_attn. Falling back to unfused layer norm",
                UserWarning,
            )
            self.register_to_config(use_fused_residual_norm=False)

        assert len(self.config.block_out_channels) == 1

        # Legacy: kept for compatibility with pipeline
        self.output_size = self.config.codebook_size

        self.encoder_proj = nn.Linear(
            self.config.encoder_hidden_size, self.config.hidden_size, bias=self.config.use_bias
        )
        self.encoder_proj_layer_norm = Norm(self.config.hidden_size, self.config)

        self.embed = ConvEmbed(self.config)

        self.cond_embed = nn.Sequential(
            nn.Linear(
                self.config.micro_cond_embed_dim + self.config.cond_embed_dim,
                self.config.hidden_size,
                bias=self.config.use_bias,
            ),
            nn.SiLU(),
            nn.Linear(self.config.hidden_size, self.config.hidden_size, bias=self.config.use_bias),
        )

        self.down_blocks = nn.ModuleList([DownsampleBlock(self.config.block_out_channels[0], self.config)])

        self.project_to_hidden_norm = Norm(self.config.block_out_channels[-1], self.config)
        self.project_to_hidden = nn.Linear(
            self.config.block_out_channels[-1], self.config.hidden_size, bias=self.config.use_bias
        )

        self.transformer_layers = nn.ModuleList(
            [TransformerLayer(self.config) for _ in range(self.config.num_hidden_layers)]
        )

        self.project_from_hidden_norm = Norm(self.config.hidden_size, self.config)
        self.project_from_hidden = nn.Linear(
            self.config.hidden_size, self.config.block_out_channels[-1], bias=self.config.use_bias
        )

        self.up_blocks = nn.ModuleList([UpsampleBlock(self.config.block_out_channels[0], self.config)])

        self.mlm_layer = ConvMlmLayer(self.config)

        self.gradient_checkpointing = False

        # --- WEIGHT INIT ---
        self.apply(self._init_weights)  # General init

        # inputs
        nn.init.xavier_uniform_(self.embed.conv.weight, 0.02)
        nn.init.normal_(self.embed.embeddings.weight, std=np.sqrt(1 / self.config.vocab_size))

        # output
        nn.init.constant_(self.mlm_layer.conv1.weight, 0)
        self.mlm_layer.conv2.weight.data = self.embed.embeddings.weight.data[
            : self.config.codebook_size, :, None, None
        ].clone()

        # init AdaLNModulation.mapper layers to 0
        for m in self.modules():
            if isinstance(m, AdaLNModulation):
                nn.init.constant_(m.mapper.weight, 0)
                if hasattr(m.mapper, "bias") and m.mapper.bias is not None:
                    nn.init.constant_(m.mapper.bias, 0)

    def _init_weights(self, module):
        """
        Initialize the weights according to the original implementation.
        https://github.com/google-research/maskgit/blob/main/maskgit/nets/maskgit_transformer.py#L37
        """
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            nn.init.trunc_normal_(module.weight, std=0.02)
        elif isinstance(module, (LayerNorm, RMSNorm)):
            if hasattr(module, "weight") and module.weight is not None:
                module.weight.data.fill_(1.0)
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()

    def forward(
        self,
        input_ids,
        encoder_hidden_states,
        cond_embeds,
        micro_conds,
        labels=None,
        label_smoothing=0.0,
        loss_weight=None,
    ):
        encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
        encoder_hidden_states, _ = self.encoder_proj_layer_norm(encoder_hidden_states)

        micro_cond_embeds = sinusoidal_encode(micro_conds.flatten(), self.config.micro_cond_encode_dim)
        micro_cond_embeds = micro_cond_embeds.reshape((input_ids.shape[0], -1))

        cond_embeds = torch.cat([cond_embeds, micro_cond_embeds], dim=1)
        cond_embeds = cond_embeds.to(dtype=self.dtype)
        cond_embeds = self.cond_embed(cond_embeds).to(encoder_hidden_states.dtype)

        hidden_states = self.embed(input_ids)

        hidden_states = self.down_blocks[0](
            hidden_states, cond_embeds=cond_embeds, encoder_hidden_states=encoder_hidden_states
        )

        batch_size, channels, height, width = hidden_states.shape
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)

        hidden_states, _ = self.project_to_hidden_norm(hidden_states)
        hidden_states = self.project_to_hidden(hidden_states)

        transformer_residual = None

        for layer in self.transformer_layers:
            if self.training and self.gradient_checkpointing:
                layer_ = lambda *args: checkpoint(layer, *args)
            else:
                layer_ = layer

            hidden_states, transformer_residual = layer_(
                hidden_states,
                encoder_hidden_states,
                cond_embeds,
                transformer_residual,
            )

        hidden_states = hidden_states + transformer_residual

        hidden_states, _ = self.project_from_hidden_norm(hidden_states)
        hidden_states = self.project_from_hidden(hidden_states)

        hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)

        assert len(self.up_blocks) == 1
        hidden_states = self.up_blocks[0](
            hidden_states, cond_embeds=cond_embeds, encoder_hidden_states=encoder_hidden_states
        )

        batch_size, channels, height, width = hidden_states.shape
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_size, height * width, channels)

        logits = self.mlm_layer(hidden_states)

        if labels is not None:
            reduction = "none" if loss_weight is not None else "mean"
            loss = F.cross_entropy(
                logits.view(-1, self.config.codebook_size),
                labels.view(-1),
                ignore_index=-100,
                label_smoothing=label_smoothing,
                reduction=reduction,
            )
            if loss_weight is not None:
                loss_weight = loss_weight.view(-1)
                loss = ((loss * loss_weight).sum(dim=-1) / loss_weight.sum(dim=-1)).mean()
            return logits, loss

        return logits

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = value

        if isinstance(module, (DownsampleBlock, UpsampleBlock)):
            module.gradient_checkpointing = value

    # Legacy: kept for compatibility with pipeline
    def generate(self):
        assert False

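    # Example (illustrative sketch, not part of the original module): expected input
    # shapes for `forward`, assuming 256 latent tokens (a 16x16 grid), text-encoder
    # features of width `encoder_hidden_size`, and five micro-conditioning values so
    # that 5 * micro_cond_encode_dim == micro_cond_embed_dim.
    #
    #   model = MaskGiTUViT_v2()
    #   input_ids = torch.full((2, 256), model.config.mask_token_id, dtype=torch.long)
    #   encoder_hidden_states = torch.randn(2, 77, model.config.encoder_hidden_size)
    #   cond_embeds = torch.randn(2, model.config.cond_embed_dim)
    #   micro_conds = torch.tensor([[512.0, 512.0, 0.0, 0.0, 6.0]] * 2)
    #   logits = model(input_ids, encoder_hidden_states, cond_embeds, micro_conds)
    #   assert logits.shape == (2, 256, model.config.codebook_size)
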
    def generate2(
        self,
        encoder_hidden_states: torch.FloatTensor,
        cond_embeds: torch.FloatTensor,
        micro_conds: torch.FloatTensor,
        empty_embeds: torch.FloatTensor,
        empty_cond_embeds: torch.FloatTensor,
        input_ids: torch.LongTensor = None,
        negative_embeds: torch.FloatTensor = None,
        negative_cond_embeds: torch.FloatTensor = None,
        temperature=1.0,
        timesteps=18,  # ideal number of steps is 18 in maskgit paper
        guidance_scale=0,
        guidance_schedule=None,
        noise_schedule=cosine_schedule,
        generator: torch.Generator = None,
        return_intermediate=False,
        seq_len=None,
        use_tqdm=None,
        # Legacy: kept for compatibility with pipeline
        topk_filter_thres=None,
        noise_type=None,
        predict_all_tokens=None,
    ):
        batch_size = encoder_hidden_states.shape[0]

        if seq_len is None:
            seq_len = 256

        shape = (batch_size, seq_len)

        if isinstance(temperature, tuple):
            temperatures = torch.linspace(temperature[0], temperature[1], timesteps)
        else:
            temperatures = torch.linspace(temperature, 0.01, timesteps)

        if input_ids is None:
            input_ids = torch.ones(shape, dtype=torch.long, device=self.device) * self.config.mask_token_id

        if return_intermediate:
            intermediate = []

        if guidance_schedule == "linear":
            guidance_scales = torch.linspace(0, guidance_scale, timesteps)
        elif guidance_schedule == "cosine":
            guidance_scales = []
            for step in range(timesteps):
                ratio = 1.0 * (step + 1) / timesteps
                scale = cosine_schedule(torch.tensor(1 - ratio)) * guidance_scale
                guidance_scales.append(scale.floor())
            guidance_scales = torch.tensor(guidance_scales)
        else:
            guidance_scales = torch.ones(timesteps) * guidance_scale

        if micro_conds.shape[0] == 1:
            micro_conds = micro_conds.repeat(batch_size, 1).to(input_ids.device)

        if guidance_scale > 0:
            # encoder_hidden_states
            if negative_embeds is None:
                uncond_encoder_states = empty_embeds
            else:
                uncond_encoder_states = negative_embeds

            if uncond_encoder_states.shape[0] == 1:
                uncond_encoder_states = uncond_encoder_states.expand(batch_size, -1, -1)

            encoder_hidden_states = torch.cat([encoder_hidden_states, uncond_encoder_states])

            # cond_embeds
            if negative_cond_embeds is None:
                uncond_embeds = empty_cond_embeds
            else:
                uncond_embeds = negative_cond_embeds

            if uncond_embeds.shape[0] == 1:
                uncond_embeds = uncond_embeds.expand(batch_size, -1)

            cond_embeds = torch.cat([cond_embeds, uncond_embeds])

            # micro_conds
            micro_conds = torch.cat([micro_conds, micro_conds], dim=0)

        if use_tqdm:
            from tqdm.auto import tqdm

            timesteps_iter = tqdm(range(timesteps))
        else:
            timesteps_iter = range(timesteps)

        for step in timesteps_iter:
            if guidance_scale > 0:
                model_input = torch.cat([input_ids] * 2)
            else:
                # no classifier-free guidance: run only the conditional batch
                model_input = input_ids

            model_output = self(
                model_input,
                micro_conds=micro_conds,
                cond_embeds=cond_embeds,
                encoder_hidden_states=encoder_hidden_states,
            )

            if guidance_scale > 0:
                cond_logits, uncond_logits = model_output.chunk(2)

                cond_logits = cond_logits[..., : self.config.codebook_size]
                uncond_logits = uncond_logits[..., : self.config.codebook_size]

                logits = uncond_logits + guidance_scales[step] * (cond_logits - uncond_logits)
            else:
                logits = model_output
                logits = logits[..., : self.config.codebook_size]

            probs = logits.softmax(dim=-1)

            sampled = probs.reshape(-1, logits.size(-1))
            sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1])

            if return_intermediate:
                intermediate.append(sampled_ids)

            unknown_map = input_ids == self.config.mask_token_id
            sampled_ids = torch.where(unknown_map, sampled_ids, input_ids)

            # Defines the mask ratio for the next round. The number to mask out is
            # determined by mask_ratio * unknown_number_in_the_beginning.
            ratio = 1.0 * (step + 1) / timesteps
            mask_ratio = noise_schedule(torch.tensor(ratio))

            # Gets mask lens for each sample in the batch according to the mask ratio.
            mask_len = (seq_len * mask_ratio).floor().unsqueeze(0).to(logits.device)

            # Keeps at least one prediction in this round and also masks out at least
            # one token for the next iteration.
            mask_len = torch.max(
                torch.tensor([1], device=logits.device),
                torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len),
            )

            selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
            selected_probs = selected_probs.squeeze(-1)

            # Ignores the tokens given in the input by overwriting their confidence.
            selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)

            temperature = temperatures[step]

            masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)

            # Masks tokens with lower confidence.
            input_ids = torch.where(masking, self.config.mask_token_id, sampled_ids)

        if return_intermediate:
            return sampled_ids, intermediate

        return sampled_ids

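    # Example (illustrative sketch, not part of the original module): a minimal call to
    # `generate2` with classifier-free guidance, assuming precomputed text-encoder
    # features and an "empty prompt" embedding for the unconditional branch.
    #
    #   sampled_ids = model.generate2(
    #       encoder_hidden_states=encoder_hidden_states,  # (B, 77, encoder_hidden_size)
    #       cond_embeds=cond_embeds,                      # (B, cond_embed_dim) pooled text features
    #       micro_conds=micro_conds,                      # (B, 5) size / crop / aesthetic conditioning
    #       empty_embeds=empty_embeds,
    #       empty_cond_embeds=empty_cond_embeds,
    #       guidance_scale=8.0,
    #       timesteps=12,
    #   )
    #   # sampled_ids has shape (B, 256) and indexes the VQ codebook; decode it with the VQ-VAE.
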
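# Example (illustrative sketch, not part of the original module): the embedding and
# down/up blocks treat the flat token sequence as a square grid, so a 256-token
# sequence becomes a 16x16 feature map. With the default force_down_up_sample=False,
# DownsampleBlock and UpsampleBlock keep that spatial resolution unchanged.
#
#   config = MaskGiTUViT_v2Config()
#   embed = ConvEmbed(config)
#   ids = torch.zeros(1, 256, dtype=torch.long)
#   feats = embed(ids)
#   assert feats.shape == (1, config.block_out_channels[0], 16, 16)
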
# embedding blocks


class ConvEmbed(nn.Module):
    def __init__(self, config: MaskGiTUViT_v2Config):
        super().__init__()
        self.embeddings = nn.Embedding(config.vocab_size, config.in_channels)
        self.layer_norm = Norm(config.in_channels, config)
        self.conv = nn.Conv2d(config.in_channels, config.block_out_channels[0], kernel_size=1, bias=config.use_bias)

    def forward(self, input_ids):
        batch_size, seq_length = input_ids.shape
        height, width = int(seq_length**0.5), int(seq_length**0.5)

        input_ids = input_ids.view(-1, height, width)
        embeddings = self.embeddings(input_ids)
        embeddings, _ = self.layer_norm(embeddings)
        embeddings = embeddings.permute(0, 3, 1, 2)
        embeddings = self.conv(embeddings)

        return embeddings


# down/upsample blocks


class DownsampleBlock(nn.Module):
    def __init__(self, channels, config: MaskGiTUViT_v2Config):
        super().__init__()

        if config.force_down_up_sample:
            self.downsample = nn.Sequential(
                Norm2D(channels, config),
                nn.Conv2d(channels, channels, kernel_size=2, stride=2, bias=config.use_bias),
            )
        else:
            self.downsample = None

        self.res_blocks = nn.ModuleList([ResBlock(channels, config) for _ in range(config.num_res_blocks)])

        self.attention_blocks = nn.ModuleList(
            [AttentionBlock2D(channels, config) for _ in range(config.num_res_blocks)]
        )

        self.gradient_checkpointing = False

    def forward(self, x, cond_embeds, encoder_hidden_states):
        if self.downsample is not None:
            x = self.downsample(x)

        for res_block, attention_block in zip(self.res_blocks, self.attention_blocks):
            if self.training and self.gradient_checkpointing:
                res_block_ = lambda *args: checkpoint(res_block, *args)
                attention_block_ = lambda *args: checkpoint(attention_block, *args)
            else:
                res_block_ = res_block
                attention_block_ = attention_block

            x = res_block_(x, cond_embeds)
            x = attention_block_(x, encoder_hidden_states)

        return x


class UpsampleBlock(nn.Module):
    def __init__(
        self,
        channels: int,
        config: MaskGiTUViT_v2Config,
    ):
        super().__init__()

        self.res_blocks = nn.ModuleList([ResBlock(channels, config) for i in range(config.num_res_blocks)])

        self.attention_blocks = nn.ModuleList(
            [AttentionBlock2D(channels, config) for _ in range(config.num_res_blocks)]
        )

        if config.force_down_up_sample:
            self.upsample = nn.Sequential(
                Norm2D(channels, config),
                nn.ConvTranspose2d(channels, channels, kernel_size=2, stride=2, bias=config.use_bias),
            )
        else:
            self.upsample = None

        self.gradient_checkpointing = False

    def forward(self, x, cond_embeds, encoder_hidden_states):
        for res_block, attention_block in zip(self.res_blocks, self.attention_blocks):
            if self.training and self.gradient_checkpointing:
                res_block_ = lambda *args: checkpoint(res_block, *args)
                attention_block_ = lambda *args: checkpoint(attention_block, *args)
            else:
                res_block_ = res_block
                attention_block_ = attention_block

            x = res_block_(x, cond_embeds)
            x = attention_block_(x, encoder_hidden_states)

        if self.upsample is not None:
            x = self.upsample(x)

        return x


class ResBlock(nn.Module):
    def __init__(
        self,
        channels,
        config: MaskGiTUViT_v2Config,
        res_ffn_factor=4,
    ):
        super().__init__()
        self.depthwise = nn.Conv2d(
            channels,
            channels,
            kernel_size=3,
            padding=1,
            groups=channels,
            bias=config.use_bias,
        )
        self.norm = Norm2D(channels, config)
        self.channelwise = nn.Sequential(
            nn.Linear(channels, int(channels * res_ffn_factor), bias=config.use_bias),
            nn.GELU(),
            GlobalResponseNorm(int(channels * res_ffn_factor)),
            nn.Dropout(config.hidden_dropout),
            nn.Linear(int(channels * res_ffn_factor), channels, bias=config.use_bias),
        )
        self.adaLN_modulation = AdaLNModulation(channels, config)

    def forward(self, x, cond_embeds):
        x_res = x
        x = self.norm(self.depthwise(x)).permute(0, 2, 3, 1)
        x = self.channelwise(x).permute(0, 3, 1, 2)
        x = x + x_res
        x = self.adaLN_modulation(x, cond_embeds)
        return x


# norm blocks


class Norm2D(nn.Module):
    def __init__(self, dim, config: MaskGiTUViT_v2Config):
        super().__init__()
        self.norm = Norm(dim, config)

    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        x, _ = self.norm(x)
        x = x.permute(0, 3, 1, 2)
        return x


def Norm(dim, config: MaskGiTUViT_v2Config):
    if config.norm_type == "layernorm":
        return LayerNorm(dim, config)
    elif config.norm_type == "rmsnorm":
        return RMSNorm(dim, config)
    else:
        assert False


class RMSNorm(nn.Module):
    def __init__(self, dim, config: MaskGiTUViT_v2Config):
        super().__init__()
        self.config = config

        if isinstance(dim, numbers.Integral):
            dim = (dim,)

        self.dim = torch.Size(dim)

        if self.config.ln_elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.weight = None

    def forward(self, input, residual=None):
        if self.config.use_fused_residual_norm:
            if dropout_add_rms_norm is None:
                raise ImportError("Please install flash_attn to use fused rms norm")

            return dropout_add_rms_norm(
                input, residual, self.weight, None, dropout_p=0.0, epsilon=self.config.layer_norm_eps, prenorm=True
            )
        else:
            return unfused_rms_norm(input, residual, self.weight, self.config.layer_norm_eps)


def unfused_rms_norm(input, residual, weight, eps):
    if residual is not None:
        input = input + residual

    prenorm_residual = input

    input_dtype = input.dtype
    variance = input.to(torch.float32).pow(2).mean(-1, keepdim=True)
    input = input * torch.rsqrt(variance + eps)

    if weight is not None:
        # convert into half-precision if necessary
        if weight.dtype in [torch.float16, torch.bfloat16]:
            input = input.to(weight.dtype)
        input = input * weight
    else:
        input = input.to(input_dtype)

    return input, prenorm_residual

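# Example (illustrative sketch, not part of the original module): `unfused_rms_norm`
# returns both the normalized activations and the pre-norm residual, so callers can
# thread the residual stream through consecutive pre-norm blocks.
#
#   x = torch.randn(2, 4, 8)
#   res = torch.randn(2, 4, 8)
#   out, prenorm = unfused_rms_norm(x, res, weight=None, eps=1e-6)
#   assert torch.allclose(prenorm, x + res)
#   assert out.shape == prenorm.shape
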
class LayerNorm(nn.Module):
    def __init__(self, dim, config: MaskGiTUViT_v2Config):
        super().__init__()
        self.config = config

        if isinstance(dim, numbers.Integral):
            dim = (dim,)

        self.dim = torch.Size(dim)

        if self.config.ln_elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
            self.bias = nn.Parameter(torch.zeros(dim)) if self.config.use_bias else None
        else:
            self.weight = None
            self.bias = None

    def forward(self, input, residual=None):
        if self.config.use_fused_residual_norm:
            if dropout_add_layer_norm is None:
                raise ImportError("Please install flash_attn to use fused layer norm")

            return dropout_add_layer_norm(
                x0=input,
                residual=residual,
                weight=self.weight,
                bias=self.bias,
                epsilon=self.config.layer_norm_eps,
                dropout_p=0.0,
                prenorm=True,
            )
        else:
            return unfused_layer_norm(input, residual, self.dim, self.weight, self.bias, self.config.layer_norm_eps)


def unfused_layer_norm(input, residual, dim, weight, bias, eps):
    if residual is not None:
        input = input + residual

    prenorm_residual = input

    input = F.layer_norm(input, dim, weight, bias, eps)

    return input, prenorm_residual


class GlobalResponseNorm(nn.Module):
    # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x


# attention/transformer blocks


class TransformerLayer(nn.Module):
    def __init__(self, config: MaskGiTUViT_v2Config):
        super().__init__()

        self.attn_layer_norm = Norm(config.hidden_size, config)
        self.self_attn_adaLN_modulation = AdaLNModulation(config.hidden_size, config)
        self.attention = Attention(config.hidden_size, config.hidden_size, config.num_attention_heads, config)

        self.crossattn_layer_norm = Norm(config.hidden_size, config)
        self.crossattention = Attention(config.hidden_size, config.hidden_size, config.num_attention_heads, config)
        self.cross_attn_adaLN_modulation = AdaLNModulation(config.hidden_size, config)

        self.ffn = FeedForward(config)

    def forward(self, hidden_states, encoder_hidden_states, cond_embeds, residual=None):
        hidden_states, residual = self.attn_layer_norm(hidden_states, residual=residual)
        hidden_states = self.self_attn_adaLN_modulation(hidden_states, cond_embeds)
        hidden_states = self.attention(hidden_states, hidden_states)

        hidden_states, residual = self.crossattn_layer_norm(hidden_states, residual=residual)
        hidden_states = self.cross_attn_adaLN_modulation(hidden_states, cond_embeds)
        hidden_states = self.crossattention(hidden_states, encoder_hidden_states)

        hidden_states, residual = self.ffn(hidden_states, cond_embeds=cond_embeds, residual=residual)

        return hidden_states, residual


class AttentionBlock2D(nn.Module):
    def __init__(self, hidden_size: int, config: MaskGiTUViT_v2Config):
        super().__init__()

        if config.hidden_size != hidden_size:
            self.kv_mapper = nn.Linear(config.hidden_size, hidden_size, bias=config.use_bias)
        else:
            self.kv_mapper = None

        encoder_hidden_size = hidden_size

        # NOTE: this is actually a cross attention layer, but keeping the naming from v1 to
        # keep the state dicts compatible
        self.attn_layer_norm = Norm(hidden_size, config)
        self.attention = Attention(hidden_size, encoder_hidden_size, config.block_num_heads, config)

        self.crossattn_layer_norm = Norm(hidden_size, config)
        self.crossattention = Attention(hidden_size, encoder_hidden_size, config.block_num_heads, config)

    def forward(self, hidden_states, encoder_hidden_states):
        batch_size, channels, height, width = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, channels, height * width).permute(0, 2, 1)

        if self.kv_mapper is not None:
            encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))

        # NOTE: This is actually a cross attention layer
        hidden_states, residual = self.attn_layer_norm(hidden_states)
        hidden_states = self.attention(hidden_states, encoder_hidden_states)

        hidden_states, residual = self.crossattn_layer_norm(hidden_states, residual)
        hidden_states = self.crossattention(hidden_states, encoder_hidden_states)
        hidden_states = hidden_states + residual

        hidden_states = hidden_states.permute(0, 2, 1).view(batch_size, channels, height, width)

        return hidden_states

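# Example (illustrative sketch, not part of the original module): `TransformerLayer`
# follows the pre-norm, fused-residual convention: each norm returns
# (normalized, residual), and the residual stream is only folded back in once,
# after the last layer (see `MaskGiTUViT_v2.forward`).
#
#   config = MaskGiTUViT_v2Config()
#   layer = TransformerLayer(config)
#   x = torch.randn(1, 256, config.hidden_size)
#   enc = torch.randn(1, 77, config.hidden_size)
#   cond = torch.randn(1, config.hidden_size)
#   hidden, residual = layer(x, enc, cond, residual=None)
#   out = hidden + residual  # what the parent model does after the final layer
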
class Attention(nn.Module):
    def __init__(self, hidden_size: int, context_dim: int, num_heads: int, config: MaskGiTUViT_v2Config):
        super().__init__()
        self.config = config
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = self.hidden_size // num_heads

        if self.hidden_size % self.num_heads != 0:
            raise ValueError(
                f"self.hidden_size: {self.hidden_size} must be divisible by self.num_heads: {self.num_heads}"
            )

        self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())

        self.query = nn.Linear(self.hidden_size, self.hidden_size, bias=self.config.use_bias)
        self.key = nn.Linear(context_dim, self.hidden_size, bias=self.config.use_bias)
        self.value = nn.Linear(context_dim, self.hidden_size, bias=self.config.use_bias)
        self.out = nn.Linear(self.hidden_size, self.hidden_size, bias=self.config.use_bias)

        self.dropout = nn.Dropout(self.config.attention_dropout)

        self.use_memory_efficient_attention_xformers = False
        self.xformers_attention_op = None

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
        if use_memory_efficient_attention_xformers and not is_xformers_available:
            raise ImportError("Please install xformers to use memory efficient attention")
        self.use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        self.xformers_attention_op = attention_op

    def forward(self, hidden_states, context):
        batch, q_seq_len, _ = hidden_states.shape
        kv_seq_len = context.shape[1]

        query = self.query(hidden_states)
        key = self.key(context)
        value = self.value(context)

        query = query.view(batch, q_seq_len, self.num_heads, self.head_dim)  # (B, T, nh, hs)
        key = key.view(batch, kv_seq_len, self.num_heads, self.head_dim)  # (B, T, nh, hs)
        value = value.view(batch, kv_seq_len, self.num_heads, self.head_dim)  # (B, T, nh, hs)

        if self.use_memory_efficient_attention_xformers:
            attn_output = xops.memory_efficient_attention(
                query,
                key,
                value,
                op=self.xformers_attention_op,
                p=self.config.attention_dropout if self.training else 0.0,
            )
            attn_output = attn_output.view(batch, q_seq_len, self.hidden_size)
        else:
            attn_output = self.attention(query, key, value)

        attn_output = self.out(attn_output)
        return attn_output

    def attention(self, query, key, value, attention_mask=None):
        batch, seq_len = query.shape[:2]
        kv_seq_len = key.shape[1]
        query, key, value = map(lambda t: t.transpose(1, 2).contiguous(), (query, key, value))  # (B, nh, T, hs)

        attn_weights = torch.baddbmm(
            input=torch.zeros(batch * self.num_heads, seq_len, kv_seq_len, dtype=query.dtype, device=query.device),
            batch1=query.view(batch * self.num_heads, seq_len, self.head_dim),
            batch2=key.view(batch * self.num_heads, kv_seq_len, self.head_dim).transpose(1, 2),
            alpha=1 / self.scale_attn,
        )
        attn_weights = attn_weights.view(batch, self.num_heads, seq_len, kv_seq_len)  # -1 is kv_seq_len

        # Apply the attention mask
        if attention_mask is not None:
            attn_weights = torch.masked_fill(attn_weights, attention_mask, torch.finfo(query.dtype).min)

        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)

        attn_output = torch.matmul(attn_weights, value)  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

        # re-assemble all head outputs side by side
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch, seq_len, self.hidden_size)

        return attn_output


def FeedForward(config: MaskGiTUViT_v2Config):
    if config.use_fused_mlp:
        return FusedGeLUFeedForward(config)
    else:
        return GLUFeedForward(config)

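# Example (illustrative sketch, not part of the original module): the math path of
# `Attention.attention` computes softmax(Q K^T / sqrt(head_dim)) V per head; a quick
# shape check with a small cross-attention configuration:
#
#   config = MaskGiTUViT_v2Config()
#   attn = Attention(64, 64, num_heads=4, config=config).eval()
#   q_in = torch.randn(1, 10, 64)
#   ctx = torch.randn(1, 7, 64)
#   with torch.no_grad():
#       out = attn(q_in, ctx)
#   assert out.shape == (1, 10, 64)
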
class GLUFeedForward(nn.Module):
    def __init__(self, config: MaskGiTUViT_v2Config):
        super().__init__()
        self.pre_mlp_layer_norm = LayerNorm(config.hidden_size, config)
        self.adaLN_modulation = AdaLNModulation(config.hidden_size, config)

        self.wi_0 = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.use_bias)
        self.wi_1 = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.use_bias)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.use_bias)

    def forward(self, hidden_states, cond_embeds, residual=None):
        hidden_states, residual = self.pre_mlp_layer_norm(hidden_states, residual=residual)

        hidden_states = self.adaLN_modulation(hidden_states, cond_embeds)

        hidden_gelu = F.gelu(self.wi_0(hidden_states))
        hidden_linear = self.wi_1(hidden_states)
        hidden_states = hidden_gelu * hidden_linear

        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)

        return hidden_states, residual


class FusedGeLUFeedForward(nn.Module):
    def __init__(self, config: MaskGiTUViT_v2Config):
        super().__init__()
        self.pre_mlp_layer_norm = LayerNorm(config.hidden_size, config)
        self.adaLN_modulation = AdaLNModulation(config.hidden_size, config)

        self.wi_0 = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.use_bias)
        self.dropout = nn.Dropout(config.hidden_dropout)
        self.wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.use_bias)

    def forward(self, hidden_states, cond_embeds, residual=None):
        if fused_mlp_func is None:
            raise ImportError("Please install flash_attn to use fused mlp")

        hidden_states, residual = self.pre_mlp_layer_norm(hidden_states, residual=residual)

        hidden_states = self.adaLN_modulation(hidden_states, cond_embeds)

        dtype = hidden_states.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
        cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
        if torch.cuda.get_device_capability("cuda") == (9, 0):
            heuristic = -1
        elif cuda_ver >= (11, 8):
            heuristic = 0
        elif dtype == torch.float16:
            heuristic = 1
        else:
            heuristic = -1

        hidden_states = fused_mlp_func(
            hidden_states,
            self.wi_0.weight,
            self.wo.weight,
            self.wi_0.bias,
            self.wo.bias,
            activation="gelu_approx",
            save_pre_act=self.training,
            return_residual=False,
            checkpoint_lvl=0,
            heuristic=heuristic,
        )

        return hidden_states, residual


# misc blocks


class ConvMlmLayer(nn.Module):
    def __init__(self, config: MaskGiTUViT_v2Config):
        super().__init__()
        self.config = config
        self.conv1 = nn.Conv2d(
            self.config.block_out_channels[0], self.config.in_channels, kernel_size=1, bias=self.config.use_bias
        )
        self.layer_norm = Norm2D(self.config.in_channels, config)
        self.conv2 = nn.Conv2d(
            self.config.in_channels, self.config.codebook_size, kernel_size=1, bias=self.config.use_bias
        )

    def forward(self, hidden_states):
        batch_size, seq_length, hidden_size = hidden_states.shape
        resolution = int(seq_length**0.5)
        hidden_states = hidden_states.view(batch_size, resolution, resolution, hidden_size).permute(0, 3, 1, 2)
        hidden_states = self.conv1(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        logits = self.conv2(hidden_states)
        logits = logits.permute(0, 2, 3, 1).view(batch_size, -1, self.config.codebook_size)
        return logits

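# Example (illustrative sketch, not part of the original module): `ConvMlmLayer` turns
# the (B, seq_len, channels) token features back into per-token logits over the VQ
# codebook; in `MaskGiTUViT_v2.__init__` its conv2 weight is initialized from the input
# token embedding table (a weight-tying-style init, not a shared parameter).
#
#   config = MaskGiTUViT_v2Config()
#   mlm = ConvMlmLayer(config)
#   feats = torch.randn(1, 256, config.block_out_channels[0])
#   logits = mlm(feats)
#   assert logits.shape == (1, 256, config.codebook_size)
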
class AdaLNModulation(nn.Module):
    def __init__(self, hidden_size: int, config: MaskGiTUViT_v2Config):
        super().__init__()
        self.mapper = nn.Linear(config.hidden_size, hidden_size * 2, bias=config.use_bias)

    def forward(self, hidden_states, cond_embeds):
        cond_embeds = F.silu(cond_embeds)
        scale, shift = self.mapper(cond_embeds).chunk(2, dim=1)
        if hidden_states.dim() > 3:
            scale, shift = scale[:, :, None, None], shift[:, :, None, None]
        else:
            scale, shift = scale[:, None], shift[:, None]
        return hidden_states * (1 + scale) + shift

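# Example (illustrative sketch, not part of the original module): `AdaLNModulation`
# maps a pooled conditioning vector to per-channel (scale, shift) pairs and applies
# x * (1 + scale) + shift, broadcasting over either the token axis (B, T, C inputs)
# or the spatial axes (B, C, H, W inputs).
#
#   config = MaskGiTUViT_v2Config()
#   mod = AdaLNModulation(hidden_size=config.hidden_size, config=config)
#   cond = torch.randn(2, config.hidden_size)
#   tokens = torch.randn(2, 256, config.hidden_size)
#   assert mod(tokens, cond).shape == tokens.shape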