in optimum/neuron/models/inference/llama/modeling_llama.py [0:0]
def _kernel_enabled_mlp(self, x, fused_rmsnorm, rmsnorm, residual):
    fused_residual = residual is not None
    logger.debug(
        f"MLP: kernel, fused_residual={fused_residual}, fused_rmsnorm={fused_rmsnorm}, logical_nc_config={self.logical_nc_config}"
    )

    # Choose which kernel to call
    if fused_residual:
        assert not self.sequence_parallel_enabled, (
            "MLP kernel cannot have both fused residual add and sequence parallel RMSnorm!"
        )
        # Using fused residual add
        _mlp_fwd_call = nki_jit()(mlp_fused_add_isa_kernel)
    else:
        _mlp_fwd_call = nki_jit()(mlp_isa_kernel)

    if self.sequence_parallel_enabled:
        x = gather_from_sequence_parallel_region(x, self.sequence_dimension)

    # Build output tensor
    output_tensor_seqlen = x.shape[1]
    if fused_residual:
        # seqlen dim is doubled to store the residual add output
        output_tensor_seqlen *= 2

    output_tensor = torch.zeros(
        size=(
            x.shape[0],  # batch size
            output_tensor_seqlen,
            self.hidden_size,  # hidden size
        ),
        dtype=x.dtype,
        device=x.device,
    )

    # Grab weights
    # All weights of the layers are stored in (out, in) shape
    # Unsqueeze so that the shape of the RMS gamma weight is [1, hidden] instead of [hidden]
    ln_w = rmsnorm.weight.unsqueeze(0)
    gate_w = self.gate_proj.weight.data
    up_w = self.up_proj.weight.data
    down_w = self.down_proj.weight.data

    grid = (nc(self.logical_nc_config),)

    if fused_residual:
        _mlp_fwd_call[grid](
            x,  # attn_output
            residual,  # hidden
            ln_w,  # ln_w
            gate_w,  # gate_w
            up_w,  # up_w
            down_w,  # down_w
            output_tensor,  # out
            fused_rmsnorm=fused_rmsnorm,
            eps=self.rms_norm_eps,
            kernel_name="MLP",
            store_add=True,
        )
        original_seqlen = x.shape[1]
        residual = output_tensor[:, original_seqlen:, :]
        output_tensor = output_tensor[:, :original_seqlen, :]
    else:
        _mlp_fwd_call[grid](
            x,  # hidden
            # It is fine to pass gamma in as a dummy even when fused rmsnorm is not used
            ln_w,
            gate_w,
            up_w,
            down_w,
            output_tensor,  # out
            # Run RMSNorm inside the kernel only if NOT using SP rmsnorm
            fused_rmsnorm=fused_rmsnorm,
            eps=self.rms_norm_eps,
            kernel_name="MLP",
        )
        residual = None

    # All-reduce or reduce-scatter, depending on whether SP is enabled
    if self.sequence_parallel_enabled:
        output_tensor = reduce_scatter_to_sequence_parallel_region(output_tensor, self.sequence_dimension)
    else:
        output_tensor = reduce_from_tensor_model_parallel_region(output_tensor)

    logger.debug(f"MLP output shape {output_tensor.shape}")
    return (output_tensor, residual)
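
A note on the fused-residual path: the kernel packs two results into the single `output_tensor` buffer by doubling its sequence dimension, and the caller splits them back apart afterwards (the `original_seqlen` slicing above). The snippet below is a minimal, self-contained sketch of that packing/unpacking convention only; the shapes and names are illustrative and not part of the module.

    import torch

    # Hypothetical sizes for illustration only
    batch, seqlen, hidden = 2, 8, 16

    # "out" buffer allocated with a doubled sequence dimension, as in the code above
    packed = torch.zeros(batch, 2 * seqlen, hidden)

    # After the kernel call, the first half holds the MLP output and the
    # second half holds the residual-add output
    mlp_out = packed[:, :seqlen, :]
    residual_out = packed[:, seqlen:, :]
    assert mlp_out.shape == residual_out.shape == (batch, seqlen, hidden)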