in muse/modeling_transformer.py [0:0]
def forward(self, x, x_skip=None, cond_embeds=None, encoder_hidden_states=None, **kwargs):
    # Optionally reduce spatial resolution before running the residual blocks.
    if self.add_downsample:
        x = self.downsample(x)

    # Collect the output of every res block so callers can use them as skip connections.
    output_states = ()
    for i, res_block in enumerate(self.res_blocks):
        if self.training and self.gradient_checkpointing:
            # Wrap the module so torch.utils.checkpoint can re-run its forward
            # with purely positional inputs during recomputation.
            def create_custom_forward(module):
                def custom_forward(*inputs):
                    return module(*inputs)

                return custom_forward

            # cond_embeds is forwarded positionally here as well (assuming the
            # res_block signature is (x, x_skip, cond_embeds)) so conditioning is
            # not silently dropped when gradient checkpointing is enabled.
            x = torch.utils.checkpoint.checkpoint(
                create_custom_forward(res_block), x, x_skip, cond_embeds
            )
            if self.has_attention:
                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(self.attention_blocks[i]), x, encoder_hidden_states
                )
        else:
            x = res_block(x, x_skip, cond_embeds=cond_embeds)
            if self.has_attention:
                x = self.attention_blocks[i](x, encoder_hidden_states)

        output_states += (x,)

    return x, output_states
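
# ---------------------------------------------------------------------------
# Minimal self-contained sketch of the gradient-checkpointing idiom used above.
# ToyResBlock and its shapes are illustrative only and are not part of muse;
# the point is that torch.utils.checkpoint re-runs the wrapped forward during
# the backward pass instead of storing intermediate activations.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn


class ToyResBlock(nn.Module):
    # Illustrative stand-in for a res block that accepts an optional skip input.
    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x, x_skip=None):
        if x_skip is not None:
            x = x + x_skip
        return torch.relu(self.conv(x))


def create_custom_forward(module):
    # Same wrapper as in the forward above: checkpoint calls it with positional inputs.
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


block = ToyResBlock(8)
x = torch.randn(2, 8, 16, 16, requires_grad=True)
# Activations inside the block are recomputed on backward instead of kept in memory.
out = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, None, use_reentrant=False)
out.sum().backward()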