def forward()

in muse/modeling_transformer_v2.py


    def forward(self, hidden_states, cond_embeds, residual=None):
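        # fused_mlp_func is expected to be resolved at module import time and
        # left as None when flash_attn is unavailable, so fail early here.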
        if fused_mlp_func is None:
            raise ImportError("Please install flash_attn to use the fused MLP")

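        # Fused pre-norm that also folds in the residual: returns the
        # normalized hidden states and the updated residual stream.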
        hidden_states, residual = self.pre_mlp_layer_norm(hidden_states, residual=residual)

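        # Adaptive LayerNorm-style modulation of the hidden states,
        # conditioned on cond_embeds.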
        hidden_states = self.adaLN_modulation(hidden_states, cond_embeds)

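        # Under autocast the matmuls run in the autocast dtype rather than
        # hidden_states.dtype, so resolve the effective compute dtype first.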
        dtype = hidden_states.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()
        cuda_ver = tuple(map(int, torch.version.cuda.split(".")))

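        # Mirror flash_attn's "auto" heuristic for fused GEMM + GELU:
        # -1 skips the fused GEMM+GELU kernel (cuBLASLt lacks it on Hopper),
        # 0 and 1 select cuBLASLt algorithms; CUDA >= 11.8 prefers 0, while
        # older CUDA uses 1 for fp16 and -1 otherwise.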
        if torch.cuda.get_device_capability("cuda") == (9, 0):
            heuristic = -1
        elif cuda_ver >= (11, 8):
            heuristic = 0
        elif dtype == torch.float16:
            heuristic = 1
        else:
            heuristic = -1

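        # Single fused kernel for wi_0 -> tanh-approximate GELU -> wo.
        # save_pre_act caches pre-activations for backward only in training;
        # checkpoint_lvl=0 stores activations instead of recomputing them.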
        hidden_states = fused_mlp_func(
            hidden_states,
            self.wi_0.weight,
            self.wo.weight,
            self.wi_0.bias,
            self.wo.bias,
            activation="gelu_approx",
            save_pre_act=self.training,
            return_residual=False,
            checkpoint_lvl=0,
            heuristic=heuristic,
        )

        return hidden_states, residual
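
For context, the "fused_mlp_func is None" guard implies a module-level import
that degrades gracefully when flash_attn is missing. A minimal sketch of that
guard, assuming the symbol comes from flash_attn.ops.fused_dense (its usual
location in flash_attn; the actual module may structure this differently):

    try:
        from flash_attn.ops.fused_dense import fused_mlp_func
    except ImportError:
        fused_mlp_func = None

Since the fused call hides the actual computation, here is an unfused
reference of what it computes with activation="gelu_approx" (a readability
sketch, not the kernel's implementation; it omits the pre-activation caching
that save_pre_act and checkpoint_lvl control):

    import torch.nn.functional as F

    def unfused_mlp(hidden_states, wi_0, wo):
        # wi_0 and wo are the same nn.Linear modules used by the fused path.
        hidden_states = F.linear(hidden_states, wi_0.weight, wi_0.bias)
        hidden_states = F.gelu(hidden_states, approximate="tanh")  # "gelu_approx"
        return F.linear(hidden_states, wo.weight, wo.bias)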