def forward()

Location: muse/modeling_transformer.py


    def forward(self, x, x_skip, **kwargs):
        """Run the decoder-side residual stack with skip connections, then optionally upsample.

        For each residual block, the most recent skip tensor is popped off the
        end of ``x_skip`` and concatenated to ``x`` along the channel axis
        (dim=1) before the block is applied. Under training with gradient
        checkpointing enabled, each block runs through
        ``torch.utils.checkpoint.checkpoint`` to trade compute for memory.

        Args:
            x: input feature map — assumed NCHW, channel concat is on dim 1.
            x_skip: sequence of skip tensors, consumed last-to-first; must
                provide at least one tensor per residual block.
            **kwargs: accepted and ignored (keeps the call signature uniform
                with sibling blocks).

        Returns:
            The transformed feature map; spatially doubled (nearest-neighbor
            interpolation followed by ``self.upsample_conv``) when
            ``self.add_upsample`` is set.
        """
        skips = list(x_skip)
        for block in self.res_blocks:
            # Consume the skip states in reverse order of production.
            skip = skips.pop()
            x = torch.cat([x, skip], dim=1)

            if self.training and self.gradient_checkpointing:
                # Bind the current block as a default arg so the checkpointed
                # callable does not capture the loop variable late.
                def run_block(*tensors, _block=block):
                    return _block(*tensors)

                x = torch.utils.checkpoint.checkpoint(run_block, x)
            else:
                x = block(x)

        if self.add_upsample:
            # NOTE(review): presumably a workaround for a large-batch issue in
            # nearest interpolation on non-contiguous tensors — confirm upstream.
            if x.shape[0] >= 64:
                x = x.contiguous()
            x = F.interpolate(x, scale_factor=2.0, mode="nearest")
            x = self.upsample_conv(x)

        return x