def forward()

in muse/modeling_transformer_v2.py [0:0]


    def forward(self, x, cond_embeds, encoder_hidden_states):
        """Run the input through the optional downsampler and every block pair.

        Each iteration applies one entry of ``self.res_blocks`` (called with
        the running activations and ``cond_embeds``) followed by the matching
        entry of ``self.attention_blocks`` (called with the running
        activations and ``encoder_hidden_states``).

        Args:
            x: Input activations; replaced by ``self.downsample(x)`` first
                when a downsampler is configured.
            cond_embeds: Second positional argument for every residual block.
            encoder_hidden_states: Second positional argument for every
                attention block.

        Returns:
            The output of the last attention block, or the (possibly
            downsampled) input when there are no block pairs.
        """
        hidden = x if self.downsample is None else self.downsample(x)

        # zip pairs the blocks positionally and stops at the shorter sequence.
        for resnet, attn in zip(self.res_blocks, self.attention_blocks):
            if self.training and self.gradient_checkpointing:
                # Only taken while training with checkpointing enabled.
                # NOTE(review): `checkpoint` is presumably
                # torch.utils.checkpoint.checkpoint -- confirm at the import.
                hidden = checkpoint(resnet, hidden, cond_embeds)
                hidden = checkpoint(attn, hidden, encoder_hidden_states)
            else:
                hidden = resnet(hidden, cond_embeds)
                hidden = attn(hidden, encoder_hidden_states)

        return hidden