in optimum/neuron/models/inference/llama/modeling_llama.py
def __init__(self, config: LlamaConfig, neuron_config: NxDNeuronConfig):
    super().__init__()
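    # Capture the model hyper-parameters and Neuron runtime options used by the layers below.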
    self.tp_degree = neuron_config.tp_degree
    self.hidden_size = config.hidden_size
    self.intermediate_size = config.intermediate_size
    self.act_fn = ACT2FN[config.hidden_act]
    self.sequence_parallel_enabled = getattr(neuron_config, "sequence_parallel_enabled", False)
    self.sequence_dimension = 1 if self.sequence_parallel_enabled else None
    self.rms_norm_eps = config.rms_norm_eps
    self.mlp_kernel_enabled = neuron_config.mlp_kernel_enabled
    self.logical_nc_config = neuron_config.logical_nc_config
    mlp_bias = getattr(config, "mlp_bias", False)
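
    # gate_proj and up_proj are column-parallel: each tensor-parallel rank holds a shard of
    # the intermediate dimension, and gather_output=False keeps their outputs sharded so they
    # can be fed directly into the row-parallel down_proj.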
    self.gate_proj = ColumnParallelLinear(
        self.hidden_size,
        self.intermediate_size,
        bias=mlp_bias,
        gather_output=False,
        dtype=neuron_config.torch_dtype,
        pad=True,
        sequence_parallel_enabled=False,
        sequence_dimension=None,
    )
    self.up_proj = ColumnParallelLinear(
        self.hidden_size,
        self.intermediate_size,
        bias=mlp_bias,
        gather_output=False,
        dtype=neuron_config.torch_dtype,
        pad=True,
        sequence_parallel_enabled=False,
        sequence_dimension=None,
    )
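
    # down_proj is row-parallel: it consumes the sharded intermediate activations
    # (input_is_parallel=True) and reduces the partial outputs across tensor-parallel ranks;
    # when sequence parallelism is enabled, the result is returned sharded along
    # sequence dimension 1.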
    self.down_proj = RowParallelLinear(
        self.intermediate_size,
        self.hidden_size,
        bias=mlp_bias,
        input_is_parallel=True,
        dtype=neuron_config.torch_dtype,
        pad=True,
        sequence_parallel_enabled=self.sequence_parallel_enabled,
        sequence_dimension=self.sequence_dimension,
        reduce_dtype=neuron_config.rpl_reduce_dtype,
    )

    if self.mlp_kernel_enabled:
        # Transpose the weights to the layout expected by kernels
        self.gate_proj.weight = transpose_parallel_linear_layer(self.gate_proj.weight)
        self.up_proj.weight = transpose_parallel_linear_layer(self.up_proj.weight)
        self.down_proj.weight = transpose_parallel_linear_layer(self.down_proj.weight)
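
# For reference: in the standard LLaMA MLP these three projections combine as
#     down_proj(act_fn(gate_proj(x)) * up_proj(x))
# i.e. a SwiGLU-style gated MLP. The forward pass of this module is not shown here and
# may instead dispatch to fused Neuron kernels when mlp_kernel_enabled is set.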