in optimum/neuron/models/training/llama/modeling_llama.py [0:0]
def __init__(self, config: LlamaConfig, trn_config: TrainingNeuronConfig, layer_idx: int):
    super().__init__()
    self.config = config
    self.layer_idx = layer_idx
    self.hidden_size = config.hidden_size
    self.num_heads = config.num_attention_heads
    self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
    self.num_key_value_heads = config.num_key_value_heads
    self.num_key_value_groups = self.num_heads // self.num_key_value_heads
    self.scaling = self.head_dim**-0.5
    self.attention_dropout = config.attention_dropout
    self.is_causal = True
    if (self.hidden_size % self.num_heads) != 0:
        raise ValueError(
            f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
            f" and `num_heads`: {self.num_heads})."
        )
    self.trn_config = trn_config
    init_method = partial(_init_normal, config.initializer_range)
    tp_size = get_tensor_model_parallel_size()

    # Fall back to the GQA-aware QKV projection when the key/value heads cannot be
    # sharded evenly across the tensor-parallel ranks.
    self.qkv_linear = (self.num_key_value_heads < tp_size) or (self.num_key_value_heads % tp_size != 0)
    if self.qkv_linear:
        if trn_config.kv_size_multiplier is None:
            self.kv_size_multiplier = trn_config.auto_kv_size_multiplier(self.num_key_value_heads)
        else:
            self.kv_size_multiplier = trn_config.kv_size_multiplier
    else:
        self.kv_size_multiplier = 1

    self.specs = ModelWeightTransformationSpecs()
    if self.qkv_linear:
        # Key/value heads are replicated `kv_size_multiplier` times so that every
        # tensor-parallel rank owns at least one of them.
        self.qkv_proj = GQAQKVColumnParallelLinear(
            self.hidden_size,
            [self.num_heads * self.head_dim, self.num_key_value_heads * self.head_dim],
            bias=False,
            gather_output=False,
            init_method=init_method,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            kv_size_multiplier=self.kv_size_multiplier,
            fuse_qkv=trn_config.fuse_qkv,
            dtype=self.config.torch_dtype,
        )
        # Record how the original q/k/v/o weights map onto the sharded GQA projection,
        # so checkpoints can be transformed between layouts.
        gqa_qkv_specs = GQAQKVColumnParallelLinearSpec(
            gqa_qkv_projection_name="qkv_proj",
            query_projection_name="q_proj",
            key_projection_name="k_proj",
            value_projection_name="v_proj",
            output_projection_name="o_proj",
            num_attention_heads=self.num_heads,
            num_key_value_heads=self.num_key_value_heads,
            kv_size_multiplier=self.kv_size_multiplier,
            q_output_size_per_partition=self.qkv_proj.q_output_size_per_partition,
            kv_output_size_per_partition=self.qkv_proj.kv_output_size_per_partition,
            fuse_qkv=trn_config.fuse_qkv,
            bias=False,
        )
        self.specs.add_spec(gqa_qkv_specs)
    elif trn_config.fuse_qkv and self.num_heads == self.num_key_value_heads:
        # No GQA (query and key/value head counts match): fuse q/k/v into a single
        # column-parallel projection.
        self.qkv_proj = ColumnParallelLinear(
            self.hidden_size,
            3 * self.num_heads * self.head_dim,
            stride=3,
            bias=False,
            gather_output=False,
            init_method=init_method,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            sequence_dimension=0,
            dtype=self.config.torch_dtype,
        )
        self.specs.add_spec(
            FusedLinearsSpec(
                fused_linear_name="qkv_proj",
                linear_names=["q_proj", "k_proj", "v_proj"],
                bias=False,
                fuse_axis="column",
                original_dims=[self.num_heads * self.head_dim] * 3,
            )
        )
        self.split_size = self.num_heads * self.head_dim // tp_size
    else:
        # Regular case: separate column-parallel q/k/v projections.
        self.q_proj = ColumnParallelLinear(
            self.hidden_size,
            self.num_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=init_method,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            sequence_dimension=0,
            dtype=self.config.torch_dtype,
        )
        self.k_proj = ColumnParallelLinear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=init_method,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            sequence_dimension=0,
            dtype=self.config.torch_dtype,
        )
        self.v_proj = ColumnParallelLinear(
            self.hidden_size,
            self.num_key_value_heads * self.head_dim,
            bias=False,
            gather_output=False,
            init_method=init_method,
            sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
            sequence_dimension=0,
            dtype=self.config.torch_dtype,
        )

    # The output projection is row-parallel: it consumes the sharded attention output.
    self.o_proj = RowParallelLinear(
        self.num_heads * self.head_dim,
        self.hidden_size,
        bias=False,
        input_is_parallel=True,
        init_method=init_method,
        sequence_parallel_enabled=trn_config.sequence_parallel_enabled,
        sequence_dimension=0,
        dtype=self.config.torch_dtype,
    )

    # From here on, head counts are per tensor-parallel rank.
    self.num_heads = neuronx_dist_utils.divide(config.num_attention_heads, tp_size)
    self.num_key_value_heads = neuronx_dist_utils.divide(
        config.num_key_value_heads * self.kv_size_multiplier, tp_size
    )
    self.num_key_value_groups = self.num_heads // self.num_key_value_heads
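
For intuition, here is a minimal standalone sketch (not part of the module) of how the `qkv_linear` test and the trailing `divide` calls above play out; the head counts, TP size, and kv_size_multiplier below are illustrative assumptions, not values taken from the source.

# Illustrative config: 32 query heads, 8 KV heads, sharded over 32 tensor-parallel ranks.
num_attention_heads = 32
num_key_value_heads = 8
tp_size = 32

# Same condition as `self.qkv_linear` above: the GQA-aware projection is needed
# because 8 KV heads cannot be spread over 32 ranks without replication.
qkv_linear = (num_key_value_heads < tp_size) or (num_key_value_heads % tp_size != 0)
assert qkv_linear

# Assuming a kv_size_multiplier of 4, each rank ends up with one query head and
# one replicated KV head, so there is one KV group per rank.
kv_size_multiplier = 4
num_heads_per_rank = num_attention_heads // tp_size                           # 1
num_kv_heads_per_rank = num_key_value_heads * kv_size_multiplier // tp_size   # 1
num_key_value_groups = num_heads_per_rank // num_kv_heads_per_rank            # 1
print(qkv_linear, num_heads_per_rank, num_kv_heads_per_rank, num_key_value_groups)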