muse/modeling_transformer_v2.py
def forward(self, hidden_states, cond_embeds, residual=None):
    if fused_mlp_func is None:
        raise ImportError("Please install flash_attn to use fused mlp")

    # Pre-norm with a fused residual add: returns the normalized activations
    # and the updated residual stream.
    hidden_states, residual = self.pre_mlp_layer_norm(hidden_states, residual=residual)

    # Adaptive LayerNorm modulation conditioned on the conditioning embeddings.
    hidden_states = self.adaLN_modulation(hidden_states, cond_embeds)

    # Under autocast, the fused kernel runs in the autocast dtype rather than
    # the input dtype, so base the heuristic choice on that.
    dtype = hidden_states.dtype if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype()

    # Pick the flash_attn fused-MLP heuristic: -1 keeps the GEMM and GELU as
    # separate kernels; non-negative values select an algo for the fused
    # GEMM + GELU path (0 for CUDA >= 11.8, 1 for fp16 on older toolkits).
    cuda_ver = tuple(map(int, torch.version.cuda.split(".")))
    if torch.cuda.get_device_capability("cuda") == (9, 0):
        # sm90 (Hopper): fall back to the unfused activation kernel.
        heuristic = -1
    elif cuda_ver >= (11, 8):
        heuristic = 0
    elif dtype == torch.float16:
        heuristic = 1
    else:
        heuristic = -1

    # Fused wi_0 -> GELU (tanh approximation) -> wo in a single call.
    # Pre-activations are saved only in training mode, for the backward pass.
    hidden_states = fused_mlp_func(
        hidden_states,
        self.wi_0.weight,
        self.wo.weight,
        self.wi_0.bias,
        self.wo.bias,
        activation="gelu_approx",
        save_pre_act=self.training,
        return_residual=False,
        checkpoint_lvl=0,
        heuristic=heuristic,
    )
    return hidden_states, residual
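
The `fused_mlp_func is None` guard implies a module-level import with a fallback. A minimal sketch of what that guard assumes, using the fused-dense op that flash_attn ships (the exact import site in the file may differ):

    # Sketch of the assumed module-level import guard: fused_mlp_func comes
    # from flash_attn's fused-dense ops and is None when flash_attn is absent.
    try:
        from flash_attn.ops.fused_dense import fused_mlp_func
    except ImportError:
        fused_mlp_func = None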
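
For reference, a sketch of the unfused computation the fused call replaces, assuming `activation="gelu_approx"` corresponds to PyTorch's tanh-approximated GELU (the helper name `mlp_reference` is hypothetical):

    import torch.nn.functional as F

    def mlp_reference(hidden_states, wi_0, wo):
        # Plain-PyTorch equivalent of fused_mlp_func: second linear applied
        # to the tanh-approximated GELU of the first linear.
        up = F.linear(hidden_states, wi_0.weight, wi_0.bias)
        return F.linear(F.gelu(up, approximate="tanh"), wo.weight, wo.bias)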