in muse/modeling_taming_vqgan.py [0:0]
def __init__(self, config):
    """Build the encoder: input conv, downsampling stack, mid block, output head.

    Args:
        config: model configuration carrying `num_channels`, `hidden_channels`,
            `resolution`, `num_resolutions`, `channel_mult`, `no_attn_mid_block`,
            `dropout`, and `z_channels`.
    """
    super().__init__()
    self.config = config
    cfg = config  # local alias; same object as self.config

    # Initial 3x3 convolution lifting the image channels to the base hidden width.
    self.conv_in = nn.Conv2d(
        cfg.num_channels,
        cfg.hidden_channels,
        kernel_size=3,
        stride=1,
        padding=1,
    )

    # Downsampling stages: the tracked resolution halves after every stage
    # except the final one.
    blocks = []
    resolution = cfg.resolution
    for level in range(cfg.num_resolutions):
        blocks.append(DownsamplingBlock(cfg, resolution, block_idx=level))
        is_last = level == cfg.num_resolutions - 1
        if not is_last:
            resolution = resolution // 2
    self.down = nn.ModuleList(blocks)

    # Middle block operating at the deepest channel width.
    mid_channels = cfg.hidden_channels * cfg.channel_mult[-1]
    self.mid = MidBlock(cfg, mid_channels, cfg.no_attn_mid_block, cfg.dropout)

    # Output head: group-normalize, then project to the latent channel count.
    self.norm_out = nn.GroupNorm(num_groups=32, num_channels=mid_channels, eps=1e-6, affine=True)
    self.conv_out = nn.Conv2d(
        mid_channels,
        cfg.z_channels,
        kernel_size=3,
        stride=1,
        padding=1,
    )