def forward()

in muse/modeling_transformer.py
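
The forward pass optionally downsamples the input, then runs each residual
block (and its paired attention block over encoder_hidden_states, when the
block has attention), checkpointing both during training when gradient
checkpointing is enabled. Each block's output is accumulated in output_states
and returned alongside the final hidden state.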


    def forward(self, x, x_skip=None, cond_embeds=None, encoder_hidden_states=None, **kwargs):
        if self.add_downsample:
            x = self.downsample(x)

        # Closure factory: torch.utils.checkpoint feeds its inputs positionally,
        # so each module is wrapped in a closure that forwards them unchanged.
        # Defined once here rather than being re-created on every loop iteration.
        def create_custom_forward(module):
            def custom_forward(*inputs):
                return module(*inputs)

            return custom_forward

        output_states = ()
        for i, res_block in enumerate(self.res_blocks):
            if self.training and self.gradient_checkpointing:
                # Recompute this block's activations during the backward pass to
                # save memory. cond_embeds is forwarded positionally (assuming it
                # is the third positional argument of res_block.forward), so the
                # checkpointed path applies the same conditioning as the direct
                # path below instead of silently dropping it.
                x = torch.utils.checkpoint.checkpoint(create_custom_forward(res_block), x, x_skip, cond_embeds)
                if self.has_attention:
                    x = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(self.attention_blocks[i]), x, encoder_hidden_states
                    )
            else:
                x = res_block(x, x_skip, cond_embeds=cond_embeds)
                if self.has_attention:
                    x = self.attention_blocks[i](x, encoder_hidden_states)

            # Keep every intermediate output so callers can use them as skip states.
            output_states += (x,)
        return x, output_states
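
Below is a minimal, self-contained sketch (not from the muse repo) of the same
checkpointing pattern: wrapping a module in a closure so torch.utils.checkpoint
can re-run its forward during the backward pass, trading extra compute for
activation memory. TinyResBlock and the tensor shapes are hypothetical, chosen
only for illustration.

    import torch
    import torch.nn as nn
    import torch.utils.checkpoint

    class TinyResBlock(nn.Module):
        """Hypothetical stand-in for one of the res_blocks above."""

        def __init__(self, dim):
            super().__init__()
            self.proj = nn.Linear(dim, dim)

        def forward(self, x, x_skip=None):
            h = torch.relu(self.proj(x))
            return h + (x_skip if x_skip is not None else x)

    def create_custom_forward(module):
        # checkpoint() feeds inputs positionally; the closure passes them through.
        def custom_forward(*inputs):
            return module(*inputs)

        return custom_forward

    block = TinyResBlock(8)
    x = torch.randn(2, 8, requires_grad=True)

    # use_reentrant=False selects the non-reentrant implementation recommended
    # on recent PyTorch releases; activations are recomputed, not stored.
    y = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
    y.sum().backward()
    print(x.grad.shape)  # torch.Size([2, 8])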