in muse/modeling_movq.py [0:0]
def __init__(self, config, curr_res: int, block_idx: int, zq_ch: int):
    """Assemble the residual, attention and upsampling layers for one level.

    Args:
        config: Settings object. This constructor reads ``num_resolutions``,
            ``hidden_channels``, ``channel_mult``, ``num_res_blocks``,
            ``dropout``, ``attn_resolutions`` and ``resample_with_conv``.
        curr_res: Resolution tag of this level; attention blocks are only
            created when it is contained in ``config.attn_resolutions``.
        block_idx: Position of this level, used to index
            ``config.channel_mult``. Level ``0`` gets no upsampler.
        zq_ch: Forwarded unchanged as ``zq_ch`` to every ``ResnetBlock`` and
            ``AttnBlock`` (presumably the channel count of the quantized
            latent used for conditioning -- confirm against those classes).
    """
    super().__init__()

    self.config = config
    self.block_idx = block_idx
    self.curr_res = curr_res

    # Input width comes from the next entry of channel_mult; the final level
    # has no "next" entry, so it falls back to the last multiplier.
    is_last_level = block_idx == config.num_resolutions - 1
    in_mult = config.channel_mult[-1] if is_last_level else config.channel_mult[block_idx + 1]
    in_channels = config.hidden_channels * in_mult
    out_channels = config.hidden_channels * config.channel_mult[block_idx]

    resnets, attentions = [], []
    # Layers are created interleaved (resnet, then its attention) so that the
    # construction order of the submodules stays the same as before.
    for _ in range(config.num_res_blocks + 1):
        resnets.append(ResnetBlock(in_channels, out_channels, zq_ch=zq_ch, dropout=config.dropout))
        # Only the first resnet changes the width; later ones map out -> out.
        in_channels = out_channels
        if curr_res in config.attn_resolutions:
            attentions.append(AttnBlock(in_channels, zq_ch=zq_ch))

    self.block = nn.ModuleList(resnets)
    self.attn = nn.ModuleList(attentions)

    # Every level except index 0 ends with an upsampling layer.
    self.upsample = Upsample(in_channels, config.resample_with_conv) if block_idx != 0 else None