in muse/modeling_transformer_v2.py
# assumes `from torch.utils.checkpoint import checkpoint` at module top
def forward(self, x, cond_embeds, encoder_hidden_states):
    if self.downsample is not None:
        x = self.downsample(x)
    for res_block, attention_block in zip(self.res_blocks, self.attention_blocks):
        if self.training and self.gradient_checkpointing:
            # Wrap each block so its activations are recomputed during
            # backward instead of stored. Capturing the loop variables in
            # these lambdas is safe here because they are called immediately
            # below, within the same iteration.
            res_block_ = lambda *args: checkpoint(res_block, *args)
            attention_block_ = lambda *args: checkpoint(attention_block, *args)
        else:
            res_block_ = res_block
            attention_block_ = attention_block
        x = res_block_(x, cond_embeds)  # residual block, conditioned on cond_embeds
        x = attention_block_(x, encoder_hidden_states)  # attention over encoder states
    return x
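
For context, here is a minimal standalone sketch (not from the muse repo, and assuming a recent PyTorch that supports the `use_reentrant` argument) of the trade-off `checkpoint` makes for one such block: the forward output is identical, but intermediate activations are freed after the forward pass and recomputed during backward, which is why the wrapper above is only applied when both `self.training` and `self.gradient_checkpointing` are set.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# Stand-in for one res/attention block; any callable module works.
block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
x = torch.randn(8, 64, requires_grad=True)

# Plain call: intermediate activations are kept for backward.
y = block(x)

# Checkpointed call: same output, but intermediates are recomputed
# during backward, trading extra compute for lower peak memory.
y_ckpt = checkpoint(block, x, use_reentrant=False)

torch.testing.assert_close(y, y_ckpt)
```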