muse/modeling_transformer.py (1,173 lines of code) (raw):

# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This file is heavily inspired by the original implementation from https://github.com/lucidrains/muse-maskgit-pytorch

import math
from functools import partial
from typing import Callable, Optional

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from torch.utils.checkpoint import checkpoint
from tqdm import tqdm

from .modeling_transformer_v2 import MaskGiTUViT_v2
from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
from .sampling import cosine_schedule, gumbel_sample, mask_by_random_topk, top_k

try:
    import xformers.ops as xops

    is_xformers_available = True
except ImportError:
    is_xformers_available = False


MaskGiTUViT = MaskGiTUViT_v2


# classifier free guidance functions


def uniform(shape, min=0, max=1, device=None):
    return torch.zeros(shape, device=device).float().uniform_(0, 1)


def prob_mask_like(shape, prob, device=None):
    if prob == 1:
        return torch.ones(shape, device=device, dtype=torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device=device, dtype=torch.bool)
    else:
        return uniform(shape, device=device) < prob


def make_attention_mask(
    query_input: torch.Tensor, key_input: torch.Tensor, pairwise_fn: Callable = torch.mul
) -> torch.Tensor:
    # [batch, len_q, len_kv]
    mask = pairwise_fn(
        # [batch, len_q] -> [batch, len_q, 1]
        torch.unsqueeze(query_input, axis=-1),
        # [batch, len_q] -> [batch, 1, len_kv]
        torch.unsqueeze(key_input, axis=-2),
    )
    # [batch, 1, len_q, len_kv]. This creates the head dim.
    mask = torch.unsqueeze(mask, axis=-3)
    return (1.0 - mask).type(torch.bool)

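
# Illustrative sketch (not part of the original module): how the helpers above are
# typically combined. `prob_mask_like` draws a per-sample keep/drop mask for
# classifier-free-guidance condition dropout, and `make_attention_mask` turns padding
# masks into a [batch, 1, len_q, len_kv] boolean mask where True marks positions to
# ignore. All shapes below are made up for the example.
def _example_cfg_and_attention_mask():
    cond = torch.randn(4, 77, 1024)  # e.g. text-encoder hidden states
    keep = prob_mask_like((4, 1, 1), 1.0 - 0.1, cond.device)  # 10% condition dropout
    cond = cond * keep

    query_pad = torch.ones(4, 16, dtype=torch.long)  # all query positions valid
    key_pad = torch.ones(4, 77, dtype=torch.long)
    key_pad[:, 60:] = 0  # trailing key tokens are padding
    mask = make_attention_mask(query_pad, key_pad)  # True where attention is masked out
    assert mask.shape == (4, 1, 16, 77)
    return cond, mask
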

try:
    from apex.normalization import FusedRMSNorm as RMSNorm  # noqa
except Exception:

    class RMSNorm(nn.Module):
        def __init__(self, normalized_shape, eps=1e-6, elementwise_affine=True):
            super().__init__()
            self.elementwise_affine = elementwise_affine
            if elementwise_affine:
                self.weight = nn.Parameter(torch.ones(normalized_shape))
            self.variance_epsilon = eps

        def forward(self, input):
            input_dtype = input.dtype
            variance = input.to(torch.float32).pow(2).mean(-1, keepdim=True)
            input = input * torch.rsqrt(variance + self.variance_epsilon)

            if self.elementwise_affine:
                # convert into half-precision if necessary
                if self.weight.dtype in [torch.float16, torch.bfloat16]:
                    input = input.to(self.weight.dtype)
                input = input * self.weight
            else:
                input = input.to(input_dtype)

            return input


def sinusoidal_enocde(features, embedding_dim, max_positions=10000):
    half_dim = embedding_dim // 2
    emb = math.log(max_positions) / half_dim
    emb = (
        torch.arange(
            0,
            half_dim,
            device=features.device,
            dtype=torch.float32,
        )
        .mul(-emb)
        .exp()
    )
    emb = features[:, None] * emb[None, :]
    emb = torch.cat([emb.cos(), emb.sin()], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = nn.functional.pad(emb, (0, 1), mode="constant")
    return emb


# layer norm without bias
class LayerNorm(nn.Module):
    def __init__(self, dim, eps=1e-5, use_bias=False, elementwise_affine=True):
        super().__init__()
        self.dim = dim
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim))
            self.bias = nn.Parameter(torch.zeros(dim)) if use_bias else None
        else:
            self.weight = None
            self.bias = None
        self.eps = eps

    def forward(self, x):
        return F.layer_norm(x, (self.dim,), self.weight, self.bias, self.eps)


class AdaLNModulation(nn.Module):
    def __init__(self, cond_embed_dim, hidden_size, use_bias=False):
        super().__init__()
        self.mapper = nn.Linear(cond_embed_dim, hidden_size * 2, bias=use_bias)

    def forward(self, hidden_states, cond_embeds):
        cond_embeds = F.silu(cond_embeds)
        scale, shift = self.mapper(cond_embeds).chunk(2, dim=1)
        if hidden_states.dim() > 3:
            scale, shift = scale[:, :, None, None], shift[:, :, None, None]
        else:
            scale, shift = scale[:, None], shift[:, None]
        return hidden_states * (1 + scale) + shift

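
# Illustrative sketch (not part of the original module): the sinusoidal encoder maps a
# batch of scalar conditioning values to [batch, embedding_dim] features, which
# AdaLNModulation then turns into a per-channel scale/shift. Sizes are made up.
def _example_sinusoidal_and_adaln():
    cond_values = torch.tensor([0.0, 0.25, 1.0])
    emb = sinusoidal_enocde(cond_values, embedding_dim=256)
    assert emb.shape == (3, 256)

    adaln = AdaLNModulation(cond_embed_dim=256, hidden_size=64)
    tokens = torch.randn(3, 16, 64)  # (batch, seq, hidden)
    modulated = adaln(tokens, emb)  # scale/shift broadcast over the sequence dim
    assert modulated.shape == tokens.shape
    return modulated
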

class Attention(nn.Module):
    def __init__(self, hidden_size, num_heads, encoder_hidden_size=None, attention_dropout=0.0, use_bias=False):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.attention_dropout = attention_dropout
        if self.head_dim * self.num_heads != self.hidden_size:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.hidden_size} and"
                f" `num_heads`: {self.num_heads})."
            )

        self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())

        self.query = nn.Linear(self.hidden_size, self.hidden_size, bias=use_bias)
        kv_hidden_size = self.hidden_size if encoder_hidden_size is None else encoder_hidden_size
        self.key = nn.Linear(kv_hidden_size, self.hidden_size, bias=use_bias)
        self.value = nn.Linear(kv_hidden_size, self.hidden_size, bias=use_bias)

        self.out = nn.Linear(self.hidden_size, self.hidden_size, bias=use_bias)
        self.dropout = nn.Dropout(attention_dropout)

        self.use_memory_efficient_attention_xformers = False
        self.xformers_attention_op = None

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
        if use_memory_efficient_attention_xformers and not is_xformers_available:
            raise ImportError("Please install xformers to use memory efficient attention")
        self.use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
        self.xformers_attention_op = attention_op

    def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_mask=None):
        if encoder_attention_mask is not None and self.use_memory_efficient_attention_xformers:
            raise ValueError("Memory efficient attention does not yet support encoder attention mask")

        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch, q_seq_len, _ = hidden_states.shape
        kv_seq_len = q_seq_len if encoder_hidden_states is None else encoder_hidden_states.shape[1]

        query = self.query(hidden_states)
        key = self.key(context)
        value = self.value(context)

        query = query.view(batch, q_seq_len, self.num_heads, self.head_dim)  # (B, T, nh, hs)
        key = key.view(batch, kv_seq_len, self.num_heads, self.head_dim)  # (B, T, nh, hs)
        value = value.view(batch, kv_seq_len, self.num_heads, self.head_dim)  # (B, T, nh, hs)

        if self.use_memory_efficient_attention_xformers:
            attn_output = xops.memory_efficient_attention(
                query, key, value, op=self.xformers_attention_op, p=self.attention_dropout if self.training else 0.0
            )
            attn_output = attn_output.view(batch, q_seq_len, self.hidden_size)
        else:
            attention_mask = None
            if encoder_attention_mask is not None:
                src_attn_mask = torch.ones(batch, q_seq_len, dtype=torch.long, device=query.device)
                attention_mask = make_attention_mask(src_attn_mask, encoder_attention_mask)
            attn_output = self.attention(query, key, value, attention_mask)

        attn_output = self.out(attn_output)
        return attn_output

    def attention(self, query, key, value, attention_mask=None):
        batch, seq_len = query.shape[:2]
        kv_seq_len = key.shape[1]
        query, key, value = map(lambda t: t.transpose(1, 2).contiguous(), (query, key, value))  # (B, nh, T, hs)

        attn_weights = torch.baddbmm(
            input=torch.zeros(batch * self.num_heads, seq_len, kv_seq_len, dtype=query.dtype, device=query.device),
            batch1=query.view(batch * self.num_heads, seq_len, self.head_dim),
            batch2=key.view(batch * self.num_heads, kv_seq_len, self.head_dim).transpose(1, 2),
            alpha=1 / self.scale_attn,
        )
        attn_weights = attn_weights.view(batch, self.num_heads, seq_len, kv_seq_len)  # -1 is kv_seq_len
        # Apply the attention mask
        if attention_mask is not None:
            attn_weights = torch.masked_fill(attn_weights, attention_mask, torch.finfo(query.dtype).min)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_weights = self.dropout(attn_weights)
        attn_output = torch.matmul(attn_weights, value)  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

        # re-assemble all head outputs side by side
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch, seq_len, self.hidden_size)
        return attn_output

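
# Illustrative sketch (not part of the original module): self-attention keeps the query
# sequence length, while cross-attention attends from image tokens to encoder states of
# a different width and length. Shapes only; dropout and xformers stay at their defaults.
def _example_attention_shapes():
    self_attn = Attention(hidden_size=64, num_heads=4)
    cross_attn = Attention(hidden_size=64, num_heads=4, encoder_hidden_size=128)

    tokens = torch.randn(2, 16, 64)
    text = torch.randn(2, 77, 128)
    text_mask = torch.ones(2, 77, dtype=torch.long)

    assert self_attn(tokens).shape == (2, 16, 64)
    out = cross_attn(tokens, encoder_hidden_states=text, encoder_attention_mask=text_mask)
    assert out.shape == (2, 16, 64)
    return out
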

# U-ViT blocks
# Adapted from https://github.com/dome272/Paella/blob/main/src_distributed/modules.py
class AttentionBlock2D(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        encoder_hidden_size,
        attention_dropout=0.0,
        norm_type="layernorm",
        layer_norm_eps=1e-6,
        ln_elementwise_affine=True,
        use_bias=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        norm_cls = partial(LayerNorm, use_bias=use_bias) if norm_type == "layernorm" else RMSNorm

        self.attn_layer_norm = norm_cls(self.hidden_size, eps=layer_norm_eps, elementwise_affine=ln_elementwise_affine)
        self.attention = Attention(hidden_size, num_heads, attention_dropout=attention_dropout, use_bias=use_bias)

        self.crossattn_layer_norm = norm_cls(hidden_size, eps=layer_norm_eps, elementwise_affine=ln_elementwise_affine)
        self.crossattention = Attention(hidden_size, num_heads, attention_dropout=attention_dropout, use_bias=use_bias)

        if encoder_hidden_size != hidden_size:
            self.kv_mapper = nn.Linear(encoder_hidden_size, hidden_size, bias=use_bias)
        else:
            self.kv_mapper = None

    def forward(self, hidden_states, encoder_hidden_states, encoder_attention_mask=None):
        # hidden_states -> (bs, hidden_size, height, width)
        # reshape to (bs, height * width, hidden_size)
        batch_size, channels, height, width = hidden_states.shape
        hidden_states = hidden_states.view(batch_size, channels, height * width).permute(0, 2, 1)

        # map encoder hidden states to hidden size of current layer
        if self.kv_mapper is not None:
            encoder_hidden_states = self.kv_mapper(F.silu(encoder_hidden_states))

        # self attention
        residual = hidden_states
        hidden_states = self.attn_layer_norm(hidden_states)
        hidden_states = self.attention(hidden_states, encoder_hidden_states, encoder_attention_mask)
        hidden_states = hidden_states + residual

        # cross attention
        residual = hidden_states
        hidden_states = self.crossattn_layer_norm(hidden_states)
        hidden_states = self.crossattention(hidden_states, encoder_hidden_states, encoder_attention_mask)
        hidden_states = hidden_states + residual

        # reshape back to (bs, hidden_size, height, width)
        hidden_states = hidden_states.permute(0, 2, 1).view(batch_size, channels, height, width)
        return hidden_states


class Norm2D(nn.Module):
    def __init__(self, dim, eps=1e-5, use_bias=False, norm_type="layernorm", elementwise_affine=True):
        super().__init__()
        if norm_type == "layernorm":
            self.norm = LayerNorm(dim, eps, use_bias, elementwise_affine=elementwise_affine)
        elif norm_type == "rmsnorm":
            self.norm = RMSNorm(dim, eps, elementwise_affine=elementwise_affine)

    def forward(self, x):
        return self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)


class GlobalResponseNorm(nn.Module):
    "Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105"

    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
        self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))

    def forward(self, x):
        Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
        Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
        return self.gamma * (x * Nx) + self.beta + x


class ResBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        skip_channels=None,
        kernel_size=3,
        dropout=0.0,
        norm_type="layernorm",
        ln_elementwise_affine=True,
        add_cond_embeds=False,
        cond_embed_dim=None,
        use_bias=False,
        res_ffn_factor=4,
        **kwargs,
    ):
        super().__init__()
        self.depthwise = nn.Conv2d(
            in_channels + skip_channels,
            in_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            groups=in_channels,
            bias=use_bias,
        )
        self.norm = Norm2D(
            in_channels, eps=1e-6, norm_type=norm_type, use_bias=use_bias, elementwise_affine=ln_elementwise_affine
        )
        self.channelwise = nn.Sequential(
            nn.Linear(in_channels, int(in_channels * res_ffn_factor), bias=use_bias),
            nn.GELU(),
            GlobalResponseNorm(int(in_channels * res_ffn_factor)),
            nn.Dropout(dropout),
            nn.Linear(int(in_channels * res_ffn_factor), in_channels, bias=use_bias),
        )

        if add_cond_embeds:
            self.adaLN_modulation = AdaLNModulation(
                cond_embed_dim=cond_embed_dim, hidden_size=in_channels, use_bias=use_bias
            )

    def forward(self, x, x_skip=None, cond_embeds=None):
        x_res = x
        if x_skip is not None:
            x = torch.cat([x, x_skip], dim=1)
        x = self.norm(self.depthwise(x)).permute(0, 2, 3, 1)
        x = self.channelwise(x).permute(0, 3, 1, 2)
        x = x + x_res
        if cond_embeds is not None:
            x = self.adaLN_modulation(x, cond_embeds)
        return x


class ResnetBlockVanilla(nn.Module):
    def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, use_bias=False, **kwargs):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=use_bias)

        self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=use_bias)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=use_bias
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=use_bias
                )

    def forward(self, hidden_states, **kwargs):
        residual = hidden_states
        hidden_states = self.norm1(hidden_states)
        hidden_states = F.silu(hidden_states)
        hidden_states = self.conv1(hidden_states)

        hidden_states = self.norm2(hidden_states)
        hidden_states = F.silu(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                residual = self.conv_shortcut(residual)
            else:
                residual = self.nin_shortcut(residual)

        return residual + hidden_states


class DownsampleBlock(nn.Module):
    def __init__(
        self,
        input_channels,
        output_channels=None,
        skip_channels=None,
        num_res_blocks=4,
        kernel_size=3,
        res_ffn_factor=4,
        dropout=0.0,
        norm_type="layernorm",
        ln_elementwise_affine=True,
        add_downsample=True,
        add_cond_embeds=False,
        cond_embed_dim=None,
        has_attention=False,
        num_heads=None,
        encoder_hidden_size=None,
        use_bias=False,
        **kwargs,
    ):
        super().__init__()
        self.add_downsample = add_downsample
        self.has_attention = has_attention

        if add_downsample:
            self.downsample = nn.Sequential(
                Norm2D(
                    input_channels,
                    eps=1e-6,
                    use_bias=use_bias,
                    norm_type=norm_type,
                    elementwise_affine=ln_elementwise_affine,
                ),
                nn.Conv2d(input_channels, output_channels, kernel_size=2, stride=2, bias=use_bias),
            )
            self.input_channels = output_channels
        else:
            self.input_channels = input_channels

        self.res_blocks = nn.ModuleList(
            [
                ResBlock(
                    self.input_channels,
                    skip_channels=skip_channels,
                    kernel_size=kernel_size,
                    dropout=dropout,
                    norm_type=norm_type,
                    ln_elementwise_affine=ln_elementwise_affine,
                    add_cond_embeds=add_cond_embeds,
                    cond_embed_dim=cond_embed_dim,
                    use_bias=use_bias,
                    res_ffn_factor=res_ffn_factor,
                )
                for _ in range(num_res_blocks)
            ]
        )

        if has_attention:
            self.attention_blocks = nn.ModuleList(
                [
                    AttentionBlock2D(
                        hidden_size=self.input_channels,
                        num_heads=num_heads,
                        encoder_hidden_size=encoder_hidden_size,
                        attention_dropout=dropout,
                        norm_type=norm_type,
                        ln_elementwise_affine=ln_elementwise_affine,
                        use_bias=use_bias,
                    )
                    for _ in range(num_res_blocks)
                ]
            )

        self.gradient_checkpointing = False

    def forward(self, x, x_skip=None, cond_embeds=None, encoder_hidden_states=None, **kwargs):
        if self.add_downsample:
            x = self.downsample(x)

        output_states = ()
        for i, res_block in enumerate(self.res_blocks):
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                x = torch.utils.checkpoint.checkpoint(create_custom_forward(res_block), x, x_skip)
                if self.has_attention:
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(self.attention_blocks[i]), x, encoder_hidden_states
                    )
            else:
                x = res_block(x, x_skip, cond_embeds=cond_embeds)
                if self.has_attention:
                    x = self.attention_blocks[i](x, encoder_hidden_states)

            output_states += (x,)
        return x, output_states


class UpsampleBlock(nn.Module):
    def __init__(
        self,
        input_channels,
        output_channels=None,
        skip_channels=None,
        num_res_blocks=4,
        kernel_size=3,
        res_ffn_factor=4,
        dropout=0.0,
        norm_type="layernorm",
        ln_elementwise_affine=True,
        add_upsample=True,
        add_cond_embeds=False,
        cond_embed_dim=None,
        has_attention=False,
        num_heads=None,
        encoder_hidden_size=None,
        use_bias=False,
        **kwargs,
    ):
        super().__init__()
        self.add_upsample = add_upsample
        self.has_attention = has_attention
        self.input_channels = input_channels
        self.output_channels = output_channels if output_channels is not None else input_channels

        self.res_blocks = nn.ModuleList(
            [
                ResBlock(
                    self.input_channels,
                    skip_channels=skip_channels if i == 0 else 0,
                    kernel_size=kernel_size,
                    dropout=dropout,
                    norm_type=norm_type,
                    ln_elementwise_affine=ln_elementwise_affine,
                    add_cond_embeds=add_cond_embeds,
                    cond_embed_dim=cond_embed_dim,
                    use_bias=use_bias,
                    res_ffn_factor=res_ffn_factor,
                )
                for i in range(num_res_blocks)
            ]
        )

        if has_attention:
            self.attention_blocks = nn.ModuleList(
                [
                    AttentionBlock2D(
                        hidden_size=self.input_channels,
                        num_heads=num_heads,
                        encoder_hidden_size=encoder_hidden_size,
                        attention_dropout=dropout,
                        norm_type=norm_type,
                        ln_elementwise_affine=ln_elementwise_affine,
                        use_bias=use_bias,
                    )
                    for _ in range(num_res_blocks)
                ]
            )

        if add_upsample:
            self.upsample = nn.Sequential(
                Norm2D(
                    self.input_channels,
                    eps=1e-6,
                    norm_type=norm_type,
                    use_bias=use_bias,
                    elementwise_affine=ln_elementwise_affine,
                ),
                nn.ConvTranspose2d(self.input_channels, self.output_channels, kernel_size=2, stride=2, bias=use_bias),
            )

        self.gradient_checkpointing = False

    def forward(self, x, x_skip=None, cond_embeds=None, encoder_hidden_states=None, **kwargs):
        for i, res_block in enumerate(self.res_blocks):
            x_res = x_skip[0] if i == 0 and x_skip is not None else None

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                x = torch.utils.checkpoint.checkpoint(create_custom_forward(res_block), x, x_res)
                if self.has_attention:
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(self.attention_blocks[i]), x, encoder_hidden_states
                    )
            else:
                x = res_block(x, x_res, cond_embeds=cond_embeds)
                if self.has_attention:
                    x = self.attention_blocks[i](x, encoder_hidden_states)

        if self.add_upsample:
            x = self.upsample(x)
        return x


class DownsampleBlockVanilla(nn.Module):
    def __init__(
        self,
        input_channels,
        output_channels=None,
        num_res_blocks=4,
        dropout=0.0,
        add_downsample=True,
        use_bias=False,
        **kwargs,
    ):
        super().__init__()
        self.add_downsample = add_downsample

        res_blocks = []
        for i in range(num_res_blocks):
            in_channels = input_channels if i == 0 else output_channels
            res_blocks.append(
                ResnetBlockVanilla(
                    in_channels=in_channels, out_channels=output_channels, dropout=dropout, use_bias=use_bias
                )
            )
        self.res_blocks = nn.ModuleList(res_blocks)

        if add_downsample:
            self.downsample_conv = nn.Conv2d(output_channels, output_channels, 3, stride=2, bias=use_bias)

        self.gradient_checkpointing = False

    def forward(self, x, **kwargs):
        output_states = ()
        for res_block in self.res_blocks:
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                x = torch.utils.checkpoint.checkpoint(create_custom_forward(res_block), x)
            else:
                x = res_block(x)

            output_states = output_states + (x,)

        if self.add_downsample:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.downsample_conv(x)
            output_states = output_states + (x,)

        return x, output_states


class UpsampleBlockVanilla(nn.Module):
    def __init__(
        self,
        input_channels,
        output_channels,
        skip_channels=None,
        num_res_blocks=4,
        dropout=0.0,
        add_upsample=True,
        use_bias=False,
        **kwargs,
    ):
        super().__init__()
        self.add_upsample = add_upsample
        res_blocks = []
        for i in range(num_res_blocks):
            res_skip_channels = input_channels if (i == num_res_blocks - 1) else output_channels
            resnet_in_channels = skip_channels if i == 0 else output_channels

            res_blocks.append(
                ResnetBlockVanilla(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=output_channels,
                    dropout=dropout,
                )
            )
        self.res_blocks = nn.ModuleList(res_blocks)

        if add_upsample:
            self.upsample_conv = nn.Conv2d(output_channels, output_channels, 3, padding=1)

        self.gradient_checkpointing = False

    def forward(self, x, x_skip, **kwargs):
        for res_block in self.res_blocks:
            # pop res hidden states
            res_hidden_states = x_skip[-1]
            x_skip = x_skip[:-1]
            x = torch.cat([x, res_hidden_states], dim=1)
            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                x = torch.utils.checkpoint.checkpoint(create_custom_forward(res_block), x)
            else:
                x = res_block(x)

        if self.add_upsample:
            if x.shape[0] >= 64:
                x = x.contiguous()
            x = F.interpolate(x, scale_factor=2.0, mode="nearest")
            x = self.upsample_conv(x)

        return x


# End U-ViT blocks

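
# Illustrative sketch (not part of the original module): how a matched pair of U-ViT
# down/up blocks can be wired. The downsample block halves the spatial size and returns
# its per-resblock hidden states; the upsample block consumes the first of those as its
# skip input and doubles the spatial size again. Channel sizes here are arbitrary.
def _example_uvit_skip_wiring():
    down = DownsampleBlock(
        input_channels=32, output_channels=64, skip_channels=0, num_res_blocks=2, add_downsample=True
    )
    up = UpsampleBlock(
        input_channels=64, output_channels=32, skip_channels=64, num_res_blocks=2, add_upsample=True
    )

    x = torch.randn(1, 32, 16, 16)
    x, down_states = down(x)  # x: (1, 64, 8, 8), len(down_states) == 2
    x = up(x, x_skip=down_states)  # back to (1, 32, 16, 16)
    assert x.shape == (1, 32, 16, 16)
    return x
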

# Normformer style GLU FeedForward
class FeedForward(nn.Module):
    def __init__(
        self,
        hidden_size,
        intermediate_size,
        hidden_dropout=0.0,
        norm_type="layernorm",
        layer_norm_eps=1e-5,
        ln_elementwise_affine=True,
        use_normformer=True,
        add_cond_embeds=False,
        cond_embed_dim=None,
        use_bias=False,
        ffn_type="glu",  # glu or vanilla
    ):
        super().__init__()
        self.use_normformer = use_normformer
        self.ffn_type = ffn_type
        self.pre_mlp_layer_norm = LayerNorm(
            hidden_size, eps=layer_norm_eps, use_bias=use_bias, elementwise_affine=ln_elementwise_affine
        )
        self.wi_0 = nn.Linear(hidden_size, intermediate_size, bias=use_bias)
        if ffn_type == "glu":
            self.wi_1 = nn.Linear(hidden_size, intermediate_size, bias=use_bias)
        if use_normformer:
            norm_cls = partial(LayerNorm, use_bias=use_bias) if norm_type == "layernorm" else RMSNorm
            self.mid_mlp_layer_norm = norm_cls(
                intermediate_size, eps=layer_norm_eps, elementwise_affine=ln_elementwise_affine
            )
        self.wo = nn.Linear(intermediate_size, hidden_size, bias=use_bias)
        self.dropout = nn.Dropout(hidden_dropout)
        if add_cond_embeds:
            self.adaLN_modulation = AdaLNModulation(
                cond_embed_dim=cond_embed_dim, hidden_size=hidden_size, use_bias=use_bias
            )

    def forward(self, hidden_states: torch.FloatTensor, cond_embeds=None) -> torch.FloatTensor:
        hidden_states = self.pre_mlp_layer_norm(hidden_states)
        if cond_embeds is not None:
            hidden_states = self.adaLN_modulation(hidden_states, cond_embeds)

        hidden_gelu = F.gelu(self.wi_0(hidden_states))
        if self.ffn_type == "glu":
            hidden_linear = self.wi_1(hidden_states)
            hidden_states = hidden_gelu * hidden_linear
        else:
            hidden_states = hidden_gelu

        if self.use_normformer:
            hidden_states = self.mid_mlp_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.wo(hidden_states)
        return hidden_states


# PreLN Transformer layer
class TransformerLayer(nn.Module):
    def __init__(
        self,
        hidden_size,
        intermediate_size,
        num_attention_heads,
        encoder_hidden_size=1024,
        add_cross_attention=False,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        norm_type="layernorm",
        layer_norm_eps=1e-5,
        ln_elementwise_affine=True,
        use_normformer=True,
        add_cond_embeds=False,
        cond_embed_dim=None,
        ffn_type="glu",
        use_bias=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.use_normformer = use_normformer

        norm_cls = partial(LayerNorm, use_bias=use_bias) if norm_type == "layernorm" else RMSNorm

        self.attn_layer_norm = norm_cls(self.hidden_size, eps=layer_norm_eps, elementwise_affine=ln_elementwise_affine)
        self.attention = Attention(
            self.hidden_size, self.num_attention_heads, attention_dropout=attention_dropout, use_bias=use_bias
        )
        if use_normformer:
            self.post_attn_layer_norm = norm_cls(
                self.hidden_size, eps=layer_norm_eps, elementwise_affine=ln_elementwise_affine
            )
        self.ffn = FeedForward(
            self.hidden_size,
            self.intermediate_size,
            hidden_dropout,
            norm_type,
            layer_norm_eps,
            ln_elementwise_affine,
            use_normformer,
            add_cond_embeds,
            cond_embed_dim,
            use_bias,
            ffn_type,
        )

        if add_cross_attention:
            self.crossattn_layer_norm = norm_cls(
                self.hidden_size, eps=layer_norm_eps, elementwise_affine=ln_elementwise_affine
            )
            self.crossattention = Attention(
                self.hidden_size, self.num_attention_heads, encoder_hidden_size, attention_dropout, use_bias
            )
            if use_normformer:
                self.post_crossattn_layer_norm = norm_cls(
                    self.hidden_size, eps=layer_norm_eps, elementwise_affine=ln_elementwise_affine
                )

        if add_cond_embeds:
            self.self_attn_adaLN_modulation = AdaLNModulation(
                cond_embed_dim=cond_embed_dim, hidden_size=hidden_size, use_bias=use_bias
            )
            if add_cross_attention:
                self.cross_attn_adaLN_modulation = AdaLNModulation(
                    cond_embed_dim=cond_embed_dim,
                    hidden_size=hidden_size,
                    use_bias=use_bias,
                )

    def forward(self, hidden_states, encoder_hidden_states=None, encoder_attention_mask=None, cond_embeds=None):
        residual = hidden_states

        hidden_states = self.attn_layer_norm(hidden_states)
        if cond_embeds is not None:
            hidden_states = self.self_attn_adaLN_modulation(hidden_states, cond_embeds)
        attention_output = self.attention(hidden_states)
        if self.use_normformer:
            attention_output = self.post_attn_layer_norm(attention_output)
        hidden_states = residual + attention_output

        if encoder_hidden_states is not None:
            residual = hidden_states
            # TODO: should norm be applied to encoder_hidden_states as well?
            hidden_states = self.crossattn_layer_norm(hidden_states)
            if cond_embeds is not None:
                hidden_states = self.cross_attn_adaLN_modulation(hidden_states, cond_embeds)
            attention_output = self.crossattention(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
            )
            if self.use_normformer:
                attention_output = self.post_crossattn_layer_norm(attention_output)
            hidden_states = residual + attention_output

        residual = hidden_states
        hidden_states = self.ffn(hidden_states, cond_embeds=cond_embeds)
        hidden_states = residual + hidden_states
        return hidden_states


class Embed(nn.Module):
    def __init__(
        self,
        vocab_size,
        embedding_size,
        hidden_size,
        hidden_dropout=0.0,
        max_position_embeddings=512,
        norm_type="layernorm",
        layer_norm_eps=1e-5,
        use_bias=False,
        layer_norm_embedddings=False,
        use_embeddings_project=False,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.hidden_dropout = hidden_dropout
        self.max_position_embeddings = max_position_embeddings
        self.layer_norm_embedddings = layer_norm_embedddings
        self.use_embeddings_project = use_embeddings_project

        self.word_embeddings = nn.Embedding(self.vocab_size, self.embedding_size)
        self.position_embeddings = nn.Embedding(self.max_position_embeddings, self.embedding_size)
        self.dropout = nn.Dropout(self.hidden_dropout)

        if layer_norm_embedddings:
            norm_cls = partial(LayerNorm, use_bias=use_bias) if norm_type == "layernorm" else RMSNorm
            self.embeddings_ln = norm_cls(self.embedding_size, eps=layer_norm_eps)

        if use_embeddings_project:
            self.embedding_hidden_mapping = nn.Linear(self.embedding_size, self.hidden_size, bias=use_bias)

    def forward(self, input_ids):
        seq_length = input_ids.shape[-1]
        position_ids = torch.arange(seq_length)[None, :].to(input_ids.device)

        word_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        input_embeddings = word_embeddings + position_embeddings

        if self.layer_norm_embedddings:
            input_embeddings = self.embeddings_ln(input_embeddings)

        if self.use_embeddings_project:
            input_embeddings = self.embedding_hidden_mapping(input_embeddings)

        input_embeddings = self.dropout(input_embeddings)
        return input_embeddings


class MlmLayer(nn.Module):
    def __init__(
        self,
        hidden_size,
        vocab_size,
        norm_type="layernorm",
        layer_norm_eps=1e-5,
        use_mlm_layernorm=True,
        use_bias=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.use_mlm_layernorm = use_mlm_layernorm
        self.mlm_dense = nn.Linear(self.hidden_size, self.hidden_size, bias=use_bias)
        if use_mlm_layernorm:
            norm_cls = partial(LayerNorm, use_bias=use_bias) if norm_type == "layernorm" else RMSNorm
            self.mlm_ln = norm_cls(self.hidden_size, eps=layer_norm_eps)
        self.to_logits = nn.Linear(self.hidden_size, vocab_size, bias=use_bias)

    def forward(self, hidden_states):
        hidden_states = self.mlm_dense(hidden_states)
        hidden_states = F.gelu(hidden_states)
        if self.use_mlm_layernorm:
            hidden_states = self.mlm_ln(hidden_states)
        logits = self.to_logits(hidden_states)
        return logits

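
# Illustrative sketch (not part of the original module): the plain (non U-ViT) pipeline
# is Embed -> N x TransformerLayer -> MlmLayer. Vocabulary and sizes are made up.
def _example_transformer_stack():
    embed = Embed(vocab_size=1025, embedding_size=64, hidden_size=64, max_position_embeddings=16)
    layer = TransformerLayer(
        hidden_size=64, intermediate_size=256, num_attention_heads=4, encoder_hidden_size=128, add_cross_attention=True
    )
    mlm = MlmLayer(hidden_size=64, vocab_size=1024)

    input_ids = torch.randint(0, 1025, (2, 16))
    text = torch.randn(2, 77, 128)

    hidden = embed(input_ids)  # (2, 16, 64)
    hidden = layer(hidden, encoder_hidden_states=text)
    logits = mlm(hidden)  # (2, 16, 1024)
    assert logits.shape == (2, 16, 1024)
    return logits
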

class ConvEmbed(nn.Module):
    def __init__(
        self,
        vocab_size,
        embedding_size,
        hidden_size,
        patch_size=2,
        max_position_embeddings=256,
        norm_type="layernorm",
        ln_elementwise_affine=True,
        layer_norm_embedddings=False,
        layer_norm_eps=1e-5,
        use_position_embeddings=True,
        use_bias=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.max_position_embeddings = max_position_embeddings
        self.use_position_embeddings = use_position_embeddings
        self.layer_norm_embedddings = layer_norm_embedddings

        self.embeddings = nn.Embedding(vocab_size, embedding_size)
        norm_cls = partial(LayerNorm, use_bias=use_bias) if norm_type == "layernorm" else RMSNorm
        self.layer_norm = norm_cls(embedding_size, eps=layer_norm_eps, elementwise_affine=ln_elementwise_affine)
        if patch_size > 1:
            self.pixel_unshuffle = nn.PixelUnshuffle(patch_size)
        self.conv = nn.Conv2d(embedding_size * (patch_size**2), hidden_size, kernel_size=1, bias=use_bias)
        if use_position_embeddings:
            self.position_embeddings = nn.Embedding(self.max_position_embeddings, hidden_size)
        if self.layer_norm_embedddings:
            self.embeddings_ln = Norm2D(
                hidden_size, eps=layer_norm_eps, norm_type=norm_type, elementwise_affine=ln_elementwise_affine
            )

    def forward(self, input_ids):
        batch_size, seq_length = input_ids.shape
        height, width = int(seq_length**0.5), int(seq_length**0.5)
        input_ids = input_ids.view(-1, height, width)
        embeddings = self.embeddings(input_ids)
        embeddings = self.layer_norm(embeddings)
        embeddings = embeddings.permute(0, 3, 1, 2)
        if self.patch_size > 1:
            embeddings = self.pixel_unshuffle(embeddings)
        embeddings = self.conv(embeddings)
        if self.use_position_embeddings:
            embeddings = embeddings.permute(0, 2, 3, 1).view(batch_size, -1, self.hidden_size)
            position_ids = torch.arange(embeddings.shape[1])[None, :].to(input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            embeddings = embeddings + position_embeddings
        if self.layer_norm_embedddings:
            embeddings = self.embeddings_ln(embeddings)
        return embeddings


class ConvMlmLayer(nn.Module):
    def __init__(
        self,
        vocab_size,
        embedding_size,
        hidden_size,
        patch_size=2,
        norm_type="layernorm",
        ln_elementwise_affine=True,
        layer_norm_eps=1e-5,
        use_bias=False,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.patch_size = patch_size
        self.conv1 = nn.Conv2d(hidden_size, embedding_size * (patch_size**2), kernel_size=1, bias=use_bias)
        if patch_size > 1:
            self.pixel_shuffle = nn.PixelShuffle(patch_size)
        self.layer_norm = Norm2D(
            embedding_size,
            norm_type=norm_type,
            eps=layer_norm_eps,
            use_bias=use_bias,
            elementwise_affine=ln_elementwise_affine,
        )
        self.conv2 = nn.Conv2d(embedding_size, vocab_size, kernel_size=1, bias=use_bias)

    def forward(self, hidden_states):
        batch_size, seq_length, hidden_size = hidden_states.shape
        height, width = int(seq_length**0.5), int(seq_length**0.5)
        hidden_states = hidden_states.view(batch_size, height, width, hidden_size).permute(0, 3, 1, 2)
        hidden_states = self.conv1(hidden_states)
        if self.patch_size > 1:
            hidden_states = self.pixel_shuffle(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        logits = self.conv2(hidden_states)
        logits = logits.permute(0, 2, 3, 1).view(batch_size, -1, self.vocab_size)
        return logits

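
# Illustrative sketch (not part of the original module): with patch_size=2, ConvEmbed
# pixel-unshuffles a 16x16 token grid into an 8x8 grid of hidden states, and
# ConvMlmLayer shuffles it back so the logits line up with the original 256 token ids.
def _example_conv_embed_roundtrip():
    vocab_size, embedding_size, hidden_size = 1025, 32, 64
    conv_embed = ConvEmbed(vocab_size, embedding_size, hidden_size, patch_size=2, max_position_embeddings=64)
    conv_mlm = ConvMlmLayer(vocab_size=1024, embedding_size=embedding_size, hidden_size=hidden_size, patch_size=2)

    input_ids = torch.randint(0, 1025, (2, 256))  # 16x16 grid of VQ tokens
    hidden = conv_embed(input_ids)  # (2, 64, hidden_size) after unshuffling
    logits = conv_mlm(hidden)  # (2, 256, 1024) back at full resolution
    assert hidden.shape == (2, 64, hidden_size)
    assert logits.shape == (2, 256, 1024)
    return logits
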

class MaskGitTransformer(ModelMixin, ConfigMixin):
    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        vocab_size,  # codebook_size + 1 (for the mask token), for class-conditioned generation it'll be codebook_size + num_classes + 1
        hidden_size=768,
        embedding_size=None,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_dropout=0.1,
        attention_dropout=0.1,
        max_position_embeddings=256,  # for class-conditioned generation it'll be 256 + 1 (for the class token)
        add_cross_attention=False,
        encoder_hidden_size=1024,  # T5-large
        project_encoder_hidden_states=False,
        initializer_range=0.02,
        norm_type="layernorm",  # or rmsnorm
        layer_norm_eps=1e-5,
        use_normformer=True,
        use_encoder_layernorm=True,
        use_mlm_layer=True,
        use_mlm_layernorm=True,
        use_bias=False,
        codebook_size=1024,
        num_vq_tokens=256,
        num_classes=None,  # set for class-conditioned generation
        use_codebook_size_for_output=False,
        use_conv_in_out=False,
        patch_size=1,
        **kwargs,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.max_position_embeddings = max_position_embeddings
        self.initializer_range = initializer_range
        self.embedding_size = embedding_size or hidden_size
        self.register_to_config(mask_token_id=vocab_size - 1)

        norm_cls = partial(LayerNorm, use_bias=use_bias) if norm_type == "layernorm" else RMSNorm

        if use_conv_in_out:
            self.embed = ConvEmbed(
                vocab_size,
                embedding_size,
                hidden_size,
                patch_size=patch_size,
                norm_type=norm_type,
                layer_norm_eps=layer_norm_eps,
                use_bias=use_bias,
            )
        else:
            self.embed = Embed(
                self.vocab_size,
                self.hidden_size,
                self.hidden_size,
                self.hidden_dropout,
                self.max_position_embeddings,
                norm_type=norm_type,
                layer_norm_eps=layer_norm_eps,
                use_bias=use_bias,
            )

        if add_cross_attention is not None and project_encoder_hidden_states:
            # Cross attention
            self.encoder_proj = nn.Linear(encoder_hidden_size, hidden_size, bias=use_bias)
            self.encoder_proj_layer_norm = norm_cls(hidden_size, eps=layer_norm_eps)
            encoder_hidden_size = hidden_size

        self.transformer_layers = nn.ModuleList(
            [
                TransformerLayer(
                    hidden_size=self.hidden_size,
                    intermediate_size=self.intermediate_size,
                    num_attention_heads=self.num_attention_heads,
                    encoder_hidden_size=encoder_hidden_size,
                    add_cross_attention=add_cross_attention,
                    hidden_dropout=self.hidden_dropout,
                    attention_dropout=self.attention_dropout,
                    norm_type=norm_type,
                    layer_norm_eps=layer_norm_eps,
                    use_normformer=use_normformer,
                    use_bias=use_bias,
                )
                for _ in range(self.num_hidden_layers)
            ]
        )
        if use_encoder_layernorm:
            self.encoder_layer_norm = norm_cls(self.hidden_size, eps=layer_norm_eps)

        self.output_size = codebook_size if use_codebook_size_for_output else self.vocab_size
        if use_mlm_layer:
            if use_conv_in_out:
                self.mlm_layer = ConvMlmLayer(
                    self.output_size,
                    embedding_size,
                    hidden_size,
                    patch_size=patch_size,
                    norm_type=norm_type,
                    layer_norm_eps=layer_norm_eps,
                    use_bias=use_bias,
                )
            else:
                self.mlm_layer = MlmLayer(
                    self.hidden_size, self.output_size, norm_type, layer_norm_eps, use_mlm_layernorm, use_bias
                )
        else:
            self.to_logits = nn.Linear(self.hidden_size, self.output_size, bias=use_bias)

        self.gradient_checkpointing = False

        self.apply(self._init_weights)

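    # Illustrative usage sketch (comment only, not part of the original class): a small
    # text-conditioned configuration. `vocab_size` is codebook_size + 1 so the last id is
    # the mask token; all sizes are scaled down from the released checkpoints.
    #
    #   model = MaskGitTransformer(
    #       vocab_size=1024 + 1,        # codebook + mask token
    #       hidden_size=128,
    #       num_hidden_layers=2,
    #       num_attention_heads=4,
    #       intermediate_size=512,
    #       max_position_embeddings=256,
    #       add_cross_attention=True,
    #       encoder_hidden_size=128,    # width of whichever text encoder is used
    #       codebook_size=1024,
    #       num_vq_tokens=256,
    #   )
    #   assert model.config.mask_token_id == 1024
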
    def _init_weights(self, module):
        """
        Initialize the weights according to the original implementation.
        https://github.com/google-research/maskgit/blob/main/maskgit/nets/maskgit_transformer.py#L37
        """
        # TODO: make this configurable
        if isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            nn.init.trunc_normal_(module.weight, std=self.config.initializer_range)
        elif isinstance(module, (nn.LayerNorm, RMSNorm)):
            if hasattr(module, "weight") and module.weight is not None:
                module.weight.data.fill_(1.0)
            if hasattr(module, "bias") and module.bias is not None:
                module.bias.data.zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        self.gradient_checkpointing = True

    def forward(
        self,
        input_ids,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        label_smoothing=0.0,
        cond_dropout_prob=0.0,
        **kwargs,
    ):
        if self.config.add_cross_attention and encoder_hidden_states is None:
            raise ValueError("If `add_cross_attention` is True, `encoder_hidden_states` should be provided.")

        hidden_states = self.embed(input_ids)

        if encoder_hidden_states is not None and self.config.project_encoder_hidden_states:
            encoder_hidden_states = self.encoder_proj(encoder_hidden_states)
            encoder_hidden_states = self.encoder_proj_layer_norm(encoder_hidden_states)

        # condition dropout for classifier free guidance
        if encoder_hidden_states is not None and self.training and cond_dropout_prob > 0.0:
            batch_size = encoder_hidden_states.shape[0]
            mask = prob_mask_like((batch_size, 1, 1), 1.0 - cond_dropout_prob, encoder_hidden_states.device)
            encoder_hidden_states = encoder_hidden_states * mask

        for layer in self.transformer_layers:
            if self.gradient_checkpointing:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)

                    return custom_forward

                hidden_states = checkpoint(
                    create_custom_forward(layer), hidden_states, encoder_hidden_states, encoder_attention_mask
                )
            else:
                hidden_states = layer(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                )

        if self.config.use_encoder_layernorm:
            hidden_states = self.encoder_layer_norm(hidden_states)

        if self.config.use_mlm_layer:
            logits = self.mlm_layer(hidden_states)
        else:
            logits = self.to_logits(hidden_states)

        if labels is not None:
            loss = F.cross_entropy(
                logits.view(-1, self.output_size), labels.view(-1), ignore_index=-100, label_smoothing=label_smoothing
            )
            return logits, loss
        return logits

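    # Illustrative training sketch (comment only, not part of the original class): one
    # plausible masked-token training step. `vq_token_ids` and `text_embeds` are assumed
    # to come from the VQ-GAN encoder and an external text encoder; the loss is computed
    # only on masked positions by setting every other label to -100.
    #
    #   mask_token_id = model.config.mask_token_id
    #   ratio = cosine_schedule(torch.rand(vq_token_ids.shape[0], 1))
    #   mask = torch.rand(vq_token_ids.shape) < ratio
    #   input_ids = torch.where(mask, mask_token_id, vq_token_ids)
    #   labels = torch.where(mask, vq_token_ids, -100)
    #   logits, loss = model(
    #       input_ids,
    #       encoder_hidden_states=text_embeds,
    #       labels=labels,
    #       cond_dropout_prob=0.1,
    #   )
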
    def generate(
        self,
        input_ids: torch.LongTensor = None,
        class_ids: torch.LongTensor = None,
        encoder_hidden_states: torch.FloatTensor = None,
        temperature=1.0,
        topk_filter_thres=0.9,
        can_remask_prev_masked=False,  # TODO: implement this
        timesteps=18,  # ideal number of steps is 18 in maskgit paper
        guidance_scale=3,
        noise_schedule: Callable = cosine_schedule,
        use_tqdm=True,
    ):
        # begin with all image token ids masked
        mask_token_id = self.config.mask_token_id
        seq_len = self.config.num_vq_tokens

        batch_size = len(class_ids) if class_ids is not None else encoder_hidden_states.shape[0]
        shape = (batch_size, seq_len)

        # shift the class ids by the codebook size
        if class_ids is not None:
            class_ids += self.config.codebook_size

        # initialize with all image tokens masked
        if input_ids is None:
            input_ids = torch.ones(shape, dtype=torch.long, device=self.device) * mask_token_id
        scores = torch.zeros(shape, dtype=torch.float32, device=self.device)

        starting_temperature = temperature

        iterate_over = zip(torch.linspace(0, 1, timesteps, device=self.device), reversed(range(timesteps)))

        if use_tqdm:
            iterate_over = tqdm(iterate_over, total=timesteps)

        for timestep, steps_until_x0 in iterate_over:
            rand_mask_prob = noise_schedule(timestep)
            num_token_masked = max(int((rand_mask_prob * seq_len).item()), 1)

            masked_indices = scores.topk(num_token_masked, dim=-1).indices
            input_ids = input_ids.scatter(1, masked_indices, mask_token_id)

            # prepend class token to input_ids
            if class_ids is not None:
                input_ids = torch.cat([class_ids[:, None], input_ids], dim=1)

            # classifier free guidance
            if encoder_hidden_states is not None and guidance_scale > 0:
                uncond_encoder_states = torch.zeros_like(encoder_hidden_states)
                model_input = torch.cat([input_ids] * 2)
                condition = torch.cat([encoder_hidden_states, uncond_encoder_states])
                cond_logits, uncond_logits = self(model_input, encoder_hidden_states=condition).chunk(2)
                cond_logits = cond_logits[..., : self.config.codebook_size]
                uncond_logits = uncond_logits[..., : self.config.codebook_size]
                logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
            else:
                logits = self(input_ids, encoder_hidden_states=encoder_hidden_states)
                logits = logits[..., : self.config.codebook_size]

            # remove class token
            if class_ids is not None:
                input_ids = input_ids[:, 1:]
                logits = logits[:, 1:]

            filtered_logits = top_k(logits, topk_filter_thres)

            temperature = starting_temperature * (steps_until_x0 / timesteps)  # temperature is annealed

            pred_ids = gumbel_sample(filtered_logits, temperature=temperature, dim=-1)

            is_mask = input_ids == mask_token_id

            input_ids = torch.where(is_mask, pred_ids, input_ids)

            probs_without_temperature = F.softmax(logits, dim=-1)

            scores = 1 - probs_without_temperature.gather(2, pred_ids[..., None])
            scores = rearrange(scores, "... 1 -> ...")  # TODO: use torch

        return input_ids

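    # Illustrative usage sketch (comment only, not part of the original class):
    # text-conditioned sampling with classifier-free guidance. `text_embeds` is assumed
    # to come from the same encoder the model was trained with, shape
    # (batch, text_len, encoder_hidden_size); the returned ids are decoded with the
    # VQ-GAN decoder.
    #
    #   with torch.no_grad():
    #       vq_ids = model.generate(
    #           encoder_hidden_states=text_embeds,
    #           timesteps=18,
    #           guidance_scale=3,
    #           temperature=1.0,
    #       )
    #   # vq_ids has shape (batch, num_vq_tokens)
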
    def generate2(
        self,
        input_ids: torch.LongTensor = None,
        class_ids: torch.LongTensor = None,
        encoder_hidden_states: torch.FloatTensor = None,
        negative_embeds: torch.FloatTensor = None,
        temperature=1.0,
        timesteps=18,  # ideal number of steps is 18 in maskgit paper
        guidance_scale=0,
        noise_schedule=cosine_schedule,
        generator: torch.Generator = None,
        **kwargs,
    ):
        """
        Generate 1:1 similar to the original MaskGit repo
        https://github.com/google-research/maskgit/blob/main/maskgit/libml/parallel_decode.py#L79
        """
        # begin with all image token ids masked
        mask_token_id = self.config.mask_token_id
        seq_len = self.config.num_vq_tokens

        batch_size = len(class_ids) if class_ids is not None else encoder_hidden_states.shape[0]
        shape = (batch_size, seq_len)

        # shift the class ids by the codebook size
        if class_ids is not None:
            class_ids += self.config.codebook_size

        # initialize with all image tokens masked
        if input_ids is None:
            input_ids = torch.ones(shape, dtype=torch.long, device=self.device) * mask_token_id

        # classifier free guidance
        if encoder_hidden_states is not None and guidance_scale > 0:
            if negative_embeds is None:
                uncond_encoder_states = torch.zeros_like(encoder_hidden_states)
            else:
                uncond_encoder_states = negative_embeds
            condition = torch.cat([encoder_hidden_states, uncond_encoder_states])
            model_conds = {"encoder_hidden_states": condition}

        for step in range(timesteps):
            # prepend class token to input_ids
            if class_ids is not None:
                input_ids = torch.cat([class_ids[:, None], input_ids], dim=1)

            if encoder_hidden_states is not None and guidance_scale > 0:
                model_input = torch.cat([input_ids] * 2)
                cond_logits, uncond_logits = self(model_input, **model_conds).chunk(2)
                cond_logits = cond_logits[..., : self.config.codebook_size]
                uncond_logits = uncond_logits[..., : self.config.codebook_size]
                logits = uncond_logits + guidance_scale * (cond_logits - uncond_logits)
            else:
                logits = self(input_ids, encoder_hidden_states=encoder_hidden_states)
                logits = logits[..., : self.config.codebook_size]

            # remove class token
            if class_ids is not None:
                input_ids = input_ids[:, 1:]
                logits = logits[:, 1:]

            # Samples the ids using categorical sampling: [batch_size, seq_length].
            probs = logits.softmax(dim=-1)
            sampled = probs.reshape(-1, logits.size(-1))
            sampled_ids = torch.multinomial(sampled, 1, generator=generator)[:, 0].view(*logits.shape[:-1])

            # Just updates the masked tokens.
            unknown_map = input_ids == mask_token_id
            sampled_ids = torch.where(unknown_map, sampled_ids, input_ids)

            # Defines the mask ratio for the next round. The number to mask out is
            # determined by mask_ratio * unknown_number_in_the_beginning.
            ratio = 1.0 * (step + 1) / timesteps
            mask_ratio = noise_schedule(torch.tensor(ratio))

            # Computes the probabilities of each selected tokens.
            selected_probs = torch.gather(probs, -1, sampled_ids.long()[..., None])
            selected_probs = selected_probs.squeeze(-1)

            # Ignores the tokens given in the input by overwriting their confidence.
            selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)
            # Gets mask lens for each sample in the batch according to the mask ratio.
            mask_len = (seq_len * mask_ratio).floor().unsqueeze(0).to(logits.device)
            # Keeps at least one of prediction in this round and also masks out at least
            # one and for the next iteration
            mask_len = torch.max(
                torch.tensor([1], device=logits.device), torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
            )

            # Adds noise for randomness
            temperature = temperature * (1.0 - ratio)
            masking = mask_by_random_topk(mask_len, selected_probs, temperature, generator=generator)

            # Masks tokens with lower confidence.
            input_ids = torch.where(masking, mask_token_id, sampled_ids)

        return sampled_ids
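

# Illustrative sketch (not part of the original module): class-conditional sampling with
# generate2, assuming a model trained with class ids folded into the vocabulary
# (vocab_size = codebook_size + num_classes + 1) and add_cross_attention=False. The
# class ids below are arbitrary; the generator device should match the model device.
def _example_generate2_class_conditional(model):
    class_ids = torch.tensor([1, 7, 42], device=model.device)
    with torch.no_grad():
        sampled_ids = model.generate2(
            class_ids=class_ids,
            timesteps=18,
            guidance_scale=0,  # no text conditioning, so no classifier-free guidance
            generator=torch.Generator(device=model.device).manual_seed(0),
        )
    return sampled_ids  # (3, num_vq_tokens) codebook indices for the VQ-GAN decoder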