# muse/modeling_movq.py
# Taken from https://github.com/ai-forever/Kandinsky-2/blob/main/kandinsky2/vqgan/movq_modules.py
# pytorch_diffusion + derived encoder decoder
import math
from typing import Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from .modeling_utils import ConfigMixin, ModelMixin, register_to_config
try:
import xformers.ops as xops
is_xformers_available = True
except ImportError:
is_xformers_available = False
class SpatialNorm(nn.Module):
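    """
    Spatially-conditioned normalization (the MoVQ trick): `f` is normalized with
    `norm_layer`, then scaled and shifted per position by 1x1 convolutions of the
    quantized latent `zq`, which is nearest-upsampled to `f`'s spatial size.

    Illustrative doctest (the shapes below are assumptions chosen for the example,
    not requirements of the class):

        >>> norm = SpatialNorm(64, 4, num_groups=32, eps=1e-6, affine=True)
        >>> out = norm(torch.randn(1, 64, 32, 32), torch.randn(1, 4, 8, 8))
        >>> tuple(out.shape)
        (1, 64, 32, 32)
    """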
def __init__(
self,
f_channels,
zq_channels,
norm_layer=nn.GroupNorm,
freeze_norm_layer=False,
add_conv=False,
**norm_layer_params,
):
super().__init__()
self.norm_layer = norm_layer(num_channels=f_channels, **norm_layer_params)
if freeze_norm_layer:
            for p in self.norm_layer.parameters():
p.requires_grad = False
self.add_conv = add_conv
if self.add_conv:
self.conv = nn.Conv2d(zq_channels, zq_channels, kernel_size=3, stride=1, padding=1)
self.conv_y = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
self.conv_b = nn.Conv2d(zq_channels, f_channels, kernel_size=1, stride=1, padding=0)
def forward(self, f, zq):
f_size = f.shape[-2:]
zq = F.interpolate(zq, size=f_size, mode="nearest")
if self.add_conv:
zq = self.conv(zq)
norm_f = self.norm_layer(f)
new_f = norm_f * self.conv_y(zq) + self.conv_b(zq)
return new_f
def Normalize(in_channels, zq_ch, add_conv):
return SpatialNorm(
in_channels,
zq_ch,
norm_layer=nn.GroupNorm,
freeze_norm_layer=False,
add_conv=add_conv,
num_groups=32,
eps=1e-6,
affine=True,
)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
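    """
    Halves the spatial resolution, either with a stride-2 conv (asymmetrically padded
    by hand, since torch convs only pad symmetrically) or with average pooling.

    Illustrative doctest (the channel count is an arbitrary choice for the example):

        >>> down = Downsample(8, with_conv=True)
        >>> tuple(down(torch.randn(1, 8, 32, 32)).shape)  # (32 + 1 - 3) // 2 + 1 = 16
        (1, 8, 16, 16)
    """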
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x):
if self.with_conv:
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
x = F.avg_pool2d(x, kernel_size=2, stride=2)
return x
class ResnetBlock(nn.Module):
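    """
    Pre-activation residual block. When `zq_ch` is given, the GroupNorms are replaced
    by SpatialNorm so the block can be conditioned on the quantized latent (as in the
    MoVQ decoder); otherwise it behaves like a plain VQGAN residual block.

    Illustrative doctest (shapes are assumptions chosen for the example):

        >>> block = ResnetBlock(64, 128, zq_ch=4)
        >>> out = block(torch.randn(1, 64, 16, 16), zq=torch.randn(1, 4, 8, 8))
        >>> tuple(out.shape)
        (1, 128, 16, 16)
    """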
def __init__(
self,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
zq_ch=None,
add_conv=False,
):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
if zq_ch:
self.norm1 = Normalize(in_channels, zq_ch, add_conv=add_conv)
else:
self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if zq_ch:
self.norm2 = Normalize(out_channels, zq_ch, add_conv=add_conv)
else:
self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, hidden_states, zq=None):
residual = hidden_states
if zq is not None:
hidden_states = self.norm1(hidden_states, zq)
else:
hidden_states = self.norm1(hidden_states)
hidden_states = F.silu(hidden_states)
hidden_states = self.conv1(hidden_states)
if zq is not None:
hidden_states = self.norm2(hidden_states, zq)
else:
hidden_states = self.norm2(hidden_states)
hidden_states = F.silu(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
residual = self.conv_shortcut(residual)
else:
residual = self.nin_shortcut(residual)
return hidden_states + residual
class AttnBlock(nn.Module):
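    """
    Single-head self-attention over the flattened spatial positions (H*W tokens of
    dimension C), with q/k/v as Linear layers and a 1/sqrt(C) scale. Optionally
    conditioned on `zq` through SpatialNorm and optionally backed by xformers.

    Illustrative doctest (shapes are assumptions chosen for the example):

        >>> attn = AttnBlock(64)
        >>> tuple(attn(torch.randn(1, 64, 8, 8)).shape)
        (1, 64, 8, 8)
    """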
def __init__(self, in_channels, zq_ch=None, add_conv=False):
super().__init__()
self.in_channels = in_channels
if zq_ch:
self.norm = Normalize(in_channels, zq_ch, add_conv=add_conv)
else:
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.q = nn.Linear(in_channels, in_channels)
self.k = nn.Linear(in_channels, in_channels)
self.v = nn.Linear(in_channels, in_channels)
self.proj_out = nn.Linear(in_channels, in_channels)
self.use_memory_efficient_attention_xformers = False
self.xformers_attention_op = None
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
if use_memory_efficient_attention_xformers and not is_xformers_available:
raise ImportError("Please install xformers to use memory efficient attention")
self.use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
self.xformers_attention_op = attention_op
def forward(self, hidden_states, zq=None):
residual = hidden_states
batch, channel, height, width = hidden_states.shape
if zq is not None:
hidden_states = self.norm(hidden_states, zq)
else:
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
        scale = 1.0 / math.sqrt(channel)
query = self.q(hidden_states)
key = self.k(hidden_states)
value = self.v(hidden_states)
if self.use_memory_efficient_attention_xformers:
            # memory-efficient attention; xformers applies its default 1/sqrt(d)
            # scaling internally, which equals the `scale` computed above
hidden_states = xops.memory_efficient_attention(
query, key, value, attn_bias=None, op=self.xformers_attention_op
)
else:
attention_scores = torch.baddbmm(
torch.empty(
query.shape[0],
query.shape[1],
key.shape[1],
dtype=query.dtype,
device=query.device,
),
query,
key.transpose(-1, -2),
beta=0,
alpha=scale,
)
attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.transpose(-1, -2).view(batch, channel, height, width)
return hidden_states + residual
class UpsamplingBlock(nn.Module):
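    """
    One decoder resolution level: `num_res_blocks + 1` zq-conditioned residual blocks
    (with attention at the resolutions listed in `config.attn_resolutions`), followed
    by a 2x upsample on every level except the highest-resolution one (block_idx == 0).
    """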
def __init__(self, config, curr_res: int, block_idx: int, zq_ch: int):
super().__init__()
self.config = config
self.block_idx = block_idx
self.curr_res = curr_res
if self.block_idx == self.config.num_resolutions - 1:
block_in = self.config.hidden_channels * self.config.channel_mult[-1]
else:
block_in = self.config.hidden_channels * self.config.channel_mult[self.block_idx + 1]
block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]
res_blocks = []
attn_blocks = []
for _ in range(self.config.num_res_blocks + 1):
res_blocks.append(ResnetBlock(block_in, block_out, zq_ch=zq_ch, dropout=self.config.dropout))
block_in = block_out
if self.curr_res in self.config.attn_resolutions:
attn_blocks.append(AttnBlock(block_in, zq_ch=zq_ch))
self.block = nn.ModuleList(res_blocks)
self.attn = nn.ModuleList(attn_blocks)
self.upsample = None
if self.block_idx != 0:
self.upsample = Upsample(block_in, self.config.resample_with_conv)
def forward(self, hidden_states, zq):
for i, res_block in enumerate(self.block):
hidden_states = res_block(hidden_states, zq)
            if len(self.attn) > 0:
hidden_states = self.attn[i](hidden_states, zq)
if self.upsample is not None:
hidden_states = self.upsample(hidden_states)
return hidden_states
class DownsamplingBlock(nn.Module):
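    """
    One encoder resolution level: `num_res_blocks` residual blocks plus an optional
    2x downsample (skipped on the last level). Channel widths come from `channel_mult`;
    e.g. with hidden_channels=128 and channel_mult=(1, 2, 2, 4) the levels map
    128->128, 128->256, 256->256 and 256->512.
    """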
def __init__(self, config, curr_res: int, block_idx: int):
super().__init__()
self.config = config
self.curr_res = curr_res
self.block_idx = block_idx
in_channel_mult = (1,) + tuple(self.config.channel_mult)
block_in = self.config.hidden_channels * in_channel_mult[self.block_idx]
block_out = self.config.hidden_channels * self.config.channel_mult[self.block_idx]
res_blocks = nn.ModuleList()
attn_blocks = nn.ModuleList()
for _ in range(self.config.num_res_blocks):
res_blocks.append(ResnetBlock(block_in, block_out, dropout=self.config.dropout))
block_in = block_out
if self.curr_res in self.config.attn_resolutions:
attn_blocks.append(AttnBlock(block_in))
self.block = res_blocks
self.attn = attn_blocks
self.downsample = None
if self.block_idx != self.config.num_resolutions - 1:
self.downsample = Downsample(block_in, self.config.resample_with_conv)
def forward(self, hidden_states):
for i, res_block in enumerate(self.block):
hidden_states = res_block(hidden_states)
            if len(self.attn) > 0:
hidden_states = self.attn[i](hidden_states)
if self.downsample is not None:
hidden_states = self.downsample(hidden_states)
return hidden_states
class MidBlock(nn.Module):
def __init__(self, config, in_channels: int, zq_ch=None, dropout: float = 0.0):
super().__init__()
self.config = config
self.in_channels = in_channels
self.dropout = dropout
self.block_1 = ResnetBlock(
self.in_channels,
self.in_channels,
dropout=self.dropout,
zq_ch=zq_ch,
)
self.attn_1 = AttnBlock(self.in_channels, zq_ch=zq_ch)
self.block_2 = ResnetBlock(
self.in_channels,
self.in_channels,
dropout=self.dropout,
zq_ch=zq_ch,
)
def forward(self, hidden_states, zq=None):
hidden_states = self.block_1(hidden_states, zq)
hidden_states = self.attn_1(hidden_states, zq)
hidden_states = self.block_2(hidden_states, zq)
return hidden_states
class Encoder(nn.Module):
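    """
    VQGAN-style encoder: conv_in, `num_resolutions` downsampling levels, a mid block,
    and a conv_out to `z_channels`. With the default config (resolution=256,
    len(channel_mult)=4) the output latent grid is 256 / 2**3 = 32x32.
    """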
def __init__(self, config):
super().__init__()
self.config = config
# downsampling
self.conv_in = nn.Conv2d(
self.config.num_channels,
self.config.hidden_channels,
kernel_size=3,
stride=1,
padding=1,
)
curr_res = self.config.resolution
downsample_blocks = []
for i_level in range(self.config.num_resolutions):
downsample_blocks.append(DownsamplingBlock(self.config, curr_res, block_idx=i_level))
if i_level != self.config.num_resolutions - 1:
curr_res = curr_res // 2
self.down = nn.ModuleList(downsample_blocks)
# middle
mid_channels = self.config.hidden_channels * self.config.channel_mult[-1]
self.mid = MidBlock(config, mid_channels, dropout=self.config.dropout)
# end
self.norm_out = nn.GroupNorm(num_groups=32, num_channels=mid_channels, eps=1e-6, affine=True)
self.conv_out = nn.Conv2d(
mid_channels,
self.config.z_channels,
kernel_size=3,
stride=1,
padding=1,
)
def forward(self, pixel_values):
# downsampling
hidden_states = self.conv_in(pixel_values)
for block in self.down:
hidden_states = block(hidden_states)
# middle
hidden_states = self.mid(hidden_states)
# end
hidden_states = self.norm_out(hidden_states)
hidden_states = F.silu(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class MoVQDecoder(nn.Module):
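    """
    Mirror of the encoder, but every normalization is a SpatialNorm conditioned on
    `zq`; MOVQ.decode passes the quantized latent itself as `zq`, so the codebook
    entries modulate the decoder at every scale.
    """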
def __init__(self, config):
super().__init__()
self.config = config
        # compute block_in and curr_res at the lowest resolution
block_in = self.config.hidden_channels * self.config.channel_mult[self.config.num_resolutions - 1]
curr_res = self.config.resolution // 2 ** (self.config.num_resolutions - 1)
self.z_shape = (1, self.config.z_channels, curr_res, curr_res)
# z to block_in
self.conv_in = nn.Conv2d(
self.config.z_channels,
block_in,
kernel_size=3,
stride=1,
padding=1,
)
# middle
self.mid = MidBlock(config, block_in, zq_ch=self.config.quantized_embed_dim, dropout=self.config.dropout)
# upsampling
upsample_blocks = []
for i_level in reversed(range(self.config.num_resolutions)):
upsample_blocks.append(
UpsamplingBlock(self.config, curr_res, block_idx=i_level, zq_ch=self.config.quantized_embed_dim)
)
if i_level != 0:
curr_res = curr_res * 2
self.up = nn.ModuleList(list(reversed(upsample_blocks))) # reverse to get consistent order
# end
block_out = self.config.hidden_channels * self.config.channel_mult[0]
self.norm_out = Normalize(block_out, self.config.quantized_embed_dim, False)
self.conv_out = nn.Conv2d(
block_out,
self.config.num_channels,
kernel_size=3,
stride=1,
padding=1,
)
def forward(self, hidden_states, zq):
# z to block_in
hidden_states = self.conv_in(hidden_states)
# middle
hidden_states = self.mid(hidden_states, zq)
# upsampling
for block in reversed(self.up):
hidden_states = block(hidden_states, zq)
# end
hidden_states = self.norm_out(hidden_states, zq)
hidden_states = F.silu(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class VectorQuantizer(nn.Module):
"""
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
"""
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def __init__(self, num_embeddings, embedding_dim, commitment_cost, legacy=True):
r"""
Args:
num_embeddings: number of vectors in the quantized space.
embedding_dim: dimensionaity of the tensors in the quantized space.
Inputs to the modules must be in this format as well.
commitment_cost: scalar which controls the weighting of the loss terms
(see equation 4 in the paper https://arxiv.org/abs/1711.00937 - this variable is Beta).
"""
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.commitment_cost = commitment_cost
self.legacy = legacy
self.embedding = nn.Embedding(num_embeddings, embedding_dim)
self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)
def forward(self, hidden_states, return_loss=False):
        # permute z -> (batch, height, width, channel); compute_distances flattens it
hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()
distances = self.compute_distances(hidden_states)
        min_encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.num_embeddings).to(hidden_states)
min_encodings.scatter_(1, min_encoding_indices, 1)
# get quantized latent vectors
z_q = torch.matmul(min_encodings, self.embedding.weight).view(hidden_states.shape)
# reshape to (batch, num_tokens)
min_encoding_indices = min_encoding_indices.reshape(hidden_states.shape[0], -1)
# compute loss for embedding
loss = None
if return_loss:
if not self.legacy:
                loss = self.commitment_cost * torch.mean((z_q.detach() - hidden_states) ** 2) + torch.mean(
                    (z_q - hidden_states.detach()) ** 2
                )
            else:
                loss = torch.mean((z_q.detach() - hidden_states) ** 2) + self.commitment_cost * torch.mean(
                    (z_q - hidden_states.detach()) ** 2
                )
# preserve gradients
z_q = hidden_states + (z_q - hidden_states).detach()
# reshape back to match original input shape
z_q = z_q.permute(0, 3, 1, 2).contiguous()
return z_q, min_encoding_indices, loss
def compute_distances(self, hidden_states):
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
        hidden_states_flattened = hidden_states.reshape((-1, self.embedding_dim))
        return torch.cdist(hidden_states_flattened, self.embedding.weight)
def get_codebook_entry(self, indices):
        # indices are expected to be of shape (batch, num_tokens), where num_tokens
        # forms a square latent grid (int(sqrt(num_tokens)) ** 2 == num_tokens)
# get quantized latent vectors
batch, num_tokens = indices.shape
z_q = self.embedding(indices)
z_q = z_q.reshape(batch, int(math.sqrt(num_tokens)), int(math.sqrt(num_tokens)), -1).permute(0, 3, 1, 2)
return z_q
def get_soft_code(self, hidden_states, temp=1.0, stochastic=False):
hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous() # (batch, height, width, channel)
distances = self.compute_distances(hidden_states) # (batch * height * width, num_embeddings)
soft_code = F.softmax(-distances / temp, dim=-1) # (batch * height * width, num_embeddings)
if stochastic:
code = torch.multinomial(soft_code, 1) # (batch * height * width, 1)
else:
code = distances.argmin(dim=-1) # (batch * height * width)
code = code.reshape(hidden_states.shape[0], -1) # (batch, height * width)
batch, num_tokens = code.shape
soft_code = soft_code.reshape(batch, num_tokens, -1) # (batch, height * width, num_embeddings)
return soft_code, code
def get_code(self, hidden_states):
# reshape z -> (batch, height, width, channel)
hidden_states = hidden_states.permute(0, 2, 3, 1).contiguous()
distances = self.compute_distances(hidden_states)
        indices = torch.argmin(distances, dim=1).unsqueeze(1)
indices = indices.reshape(hidden_states.shape[0], -1)
return indices
class MOVQ(ModelMixin, ConfigMixin):
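    """
    MoVQ autoencoder: Encoder -> 1x1 quant_conv -> VectorQuantizer -> 1x1
    post_quant_conv -> MoVQDecoder, with the quantized latent also fed to the decoder
    as its spatial conditioning signal. See the __main__ sketch at the bottom of this
    file for example shapes.
    """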
@register_to_config
def __init__(
self,
resolution: int = 256,
num_channels=3,
out_channels=3,
hidden_channels=128,
channel_mult=(1, 2, 2, 4),
num_res_blocks=2,
attn_resolutions=(32,),
z_channels=4,
double_z=False,
num_embeddings=16384,
quantized_embed_dim=4,
dropout=0.0,
resample_with_conv: bool = True,
commitment_cost: float = 0.25,
):
super().__init__()
self.config.num_resolutions = len(channel_mult)
self.config.reduction_factor = 2 ** (self.config.num_resolutions - 1)
self.config.latent_size = resolution // self.config.reduction_factor
self.encoder = Encoder(self.config)
self.decoder = MoVQDecoder(self.config)
self.quantize = VectorQuantizer(num_embeddings, quantized_embed_dim, commitment_cost=commitment_cost)
self.quant_conv = torch.nn.Conv2d(z_channels, quantized_embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(quantized_embed_dim, z_channels, 1)
def encode(self, pixel_values, return_loss=False):
hidden_states = self.encoder(pixel_values)
hidden_states = self.quant_conv(hidden_states)
quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states, return_loss)
output = (quantized_states, codebook_indices)
if return_loss:
output = output + (codebook_loss,)
return output
def decode(self, quant):
quant2 = self.post_quant_conv(quant)
dec = self.decoder(quant2, quant)
return dec
def decode_code(self, codebook_indices):
quantized_states = self.quantize.get_codebook_entry(codebook_indices)
reconstructed_pixel_values = self.decode(quantized_states)
return reconstructed_pixel_values
def get_code(self, pixel_values):
hidden_states = self.encoder(pixel_values)
hidden_states = self.quant_conv(hidden_states)
codebook_indices = self.quantize.get_code(hidden_states)
return codebook_indices
def forward(self, pixel_values, return_loss=False):
hidden_states = self.encoder(pixel_values)
hidden_states = self.quant_conv(hidden_states)
quantized_states, codebook_indices, codebook_loss = self.quantize(hidden_states, return_loss)
reconstructed_pixel_values = self.decode(quantized_states)
output = (reconstructed_pixel_values, codebook_indices)
if return_loss:
output = output + (codebook_loss,)
return output
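# Illustrative smoke test (a sketch, not part of the original module): round-trips a
# random image through the default MOVQ and checks that the straight-through
# estimator in VectorQuantizer passes gradients back through quantization. With
# resolution=256 and len(channel_mult)=4 the reduction factor is 2**3 = 8, so the
# latent grid is 32x32 (1024 tokens).
if __name__ == "__main__":
    model = MOVQ()
    pixel_values = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        quantized_states, codebook_indices = model.encode(pixel_values)
        reconstruction = model.decode_code(codebook_indices)
    print(quantized_states.shape)  # torch.Size([1, 4, 32, 32])
    print(codebook_indices.shape)  # torch.Size([1, 1024])
    print(reconstruction.shape)    # torch.Size([1, 3, 256, 256])
    # Straight-through check on the quantizer alone: argmin/scatter are not
    # differentiable, but z still receives a gradient via z + (z_q - z).detach().
    vq = VectorQuantizer(num_embeddings=16384, embedding_dim=4, commitment_cost=0.25)
    z = torch.randn(2, 4, 32, 32, requires_grad=True)
    z_q, _, loss = vq(z, return_loss=True)
    (z_q.sum() + loss).backward()
    assert z.grad is not None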