# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Optional, Tuple, Union
import neuronx_distributed as nxd
import torch
import torch_xla.core.xla_model as xm
from huggingface_hub import HfApi, snapshot_download
from neuronx_distributed.operators.argmax import argmax as nxd_argmax
from neuronx_distributed.parallel_layers.layers import SPMDRank
from neuronx_distributed.parallel_layers.mappings import (
_gather_along_dim,
_reduce_scatter_along_dim,
gather_from_sequence_parallel_region,
)
from torch import nn
from transformers import AutoConfig, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast
from ......cache.entries.single_model import SingleModelCacheEntry
from ......cache.hub_cache import hub_neuronx_cache
from ......modeling_decoder import NeuronModelForCausalLM
from ...config import NxDNeuronConfig
from ...pretrained_model import NxDPreTrainedModel
from ...utils.random import set_random_seed
from ..attention import utils as attn_utils
from ..autobucketing import generate_buckets
from ..flashdecode.utils import (
get_cache_size,
mask_util,
turn_2d_mask_to_4d,
)
from ..generation.generation_utils import NxDGenerationMixin
from ..generation.sampling import (
Sampler,
mask_padded_logits,
prepare_sampling_params,
validate_sampling_params,
)
from ..kvcache.kv_cache_manager import (
KVCacheManager,
_slice_kv_cacheline,
)
from .decoder_wrapper import (
CONTEXT_ENCODING_MODEL_TAG,
SPECULATION_MODEL_TAG,
TOKEN_GENERATION_MODEL_TAG,
NxDDecoderWrapper,
)
logger = logging.getLogger("Neuron")
class NxDDecoderModel(nn.Module):
"""
Base model that NeuronXXXModel classes inherit from.
The forward() function will be traced and compiled by NxD.
"""
def __init__(self, config: PretrainedConfig, neuron_config: NxDNeuronConfig):
super().__init__()
self.config = config
self.sampler = None
self.kv_mgr = None
self.neuron_config = neuron_config
self.batch_size = neuron_config.batch_size
self.n_positions = neuron_config.sequence_length
self.vocab_size = config.vocab_size
self.speculation_length = neuron_config.speculation_length
self.padding_side = neuron_config.padding_side
self.max_length = neuron_config.sequence_length
self.sequence_parallel_enabled = neuron_config.sequence_parallel_enabled
self.sequence_dimension = 1 if self.sequence_parallel_enabled else None
self.rank_util = SPMDRank(world_size=neuron_config.tp_degree)
self.num_cores_per_group = neuron_config.num_cores_per_group
if neuron_config.on_device_sampling:
# Instantiate a multinomial Sampler (it can still be used for greedy by passing topk=1)
self.sampler = Sampler(neuron_config, do_sample=True)
self.kv_mgr = KVCacheManager(config, neuron_config, num_kv_head=config.num_key_value_heads)
def initialize_process_group(self, seed: int = 0):
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="xla")
else:
logging.warning("torch.distributed was already initialized, skipping...")
if not nxd.parallel_layers.parallel_state.model_parallel_is_initialized():
nxd.parallel_layers.initialize_model_parallel(
tensor_model_parallel_size=self.neuron_config.tp_degree,
pipeline_model_parallel_size=self.neuron_config.pp_degree,
expert_model_parallel_size=self.neuron_config.ep_degree,
)
else:
logging.warning("NxD was already initialized, skipping...")
# set seed
set_random_seed(seed)
def _is_context_encoding(self, input_ids: torch.Tensor):
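        # More than one input token, and not exactly speculation_length tokens, means this is a prefill / context encoding pass.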
return input_ids.shape[-1] > 1 and input_ids.shape[-1] != self.speculation_length
def _is_for_speculation(self, input_ids: torch.Tensor):
return input_ids.shape[-1] == self.speculation_length
def _create_context_attn_mask(self, attention_mask, **kwargs):
# Block diagonal causal mask for chunked prefill
if self.neuron_config.is_chunked_prefill:
return self._create_chunked_prefill_attn_mask(**kwargs)
# Lower triangle causal mask for classic attention
mask = torch.full((self.n_positions, self.n_positions), True, device=attention_mask.device).tril(diagonal=0)
mask = mask[None, None, :, :].expand(self.batch_size, 1, self.n_positions, self.n_positions)
if self.padding_side == "right":
return mask
else:
expanded_mask = (
attention_mask[:, None, None, :]
.expand(self.batch_size, 1, self.n_positions, self.n_positions)
.to(torch.bool)
)
return torch.logical_and(mask, expanded_mask)
def _create_chunked_prefill_attn_mask(
self,
query_lens: torch.Tensor,
key_lens: torch.Tensor,
max_query_len: int,
max_key_len: int,
**kwargs,
) -> torch.Tensor:
return attn_utils.create_block_diagonal_attn_mask(query_lens, key_lens, max_query_len, max_key_len, **kwargs)
def _create_spec_attn_mask(self, attention_mask):
return (
attention_mask[:, None, None, :]
.expand(self.batch_size, 1, self.speculation_length, self.n_positions)
.to(torch.bool)
)
def _create_simple_attn_mask(self, attention_mask):
return attention_mask[:, None, None, :].expand(self.batch_size, 1, 1, self.n_positions).to(torch.bool)
def create_attn_mask(self, attention_mask, is_for_context_encoding, is_for_speculation, **kwargs):
if is_for_context_encoding:
return self._create_context_attn_mask(attention_mask, **kwargs)
elif is_for_speculation:
return self._create_spec_attn_mask(attention_mask)
else:
return self._create_simple_attn_mask(attention_mask)
def _slice_kv_cache(self, kv_cache, n_positions):
past_key_values = []
for idx in range(len(kv_cache)):
k_cache = _slice_kv_cacheline(self.neuron_config.padding_side, n_positions, kv_cache[idx][0])
v_cache = _slice_kv_cacheline(self.neuron_config.padding_side, n_positions, kv_cache[idx][1])
past_key_values.append([k_cache, v_cache])
return past_key_values
def _is_reorder_needed(self, is_for_context_encoding, is_for_speculation):
return not is_for_context_encoding and not is_for_speculation and self.neuron_config.continuous_batching
def forward(
self,
input_ids,
attention_mask,
position_ids,
seq_ids,
sampling_params,
scatter_index=None,
        # In the llava context encoding model, inputs_embeds is precomputed
inputs_embeds: Optional[torch.FloatTensor] = None,
kv_cache: Optional[torch.Tensor] = None,
):
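        """
        Traced forward pass of the decoder.

        Dispatches between context encoding (prefill), token generation and speculation
        based on the shape of `input_ids`, reads and updates the KV cache through the
        KVCacheManager, and returns the sampled tokens (with on-device sampling) or the
        logits, optionally followed by the gathered logits, then the updated KV cache.
        """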
is_for_context_encoding = self._is_context_encoding(input_ids)
is_for_speculation = self._is_for_speculation(input_ids)
cache_size = (
get_cache_size(self.n_positions, self.num_cores_per_group, is_for_context_encoding)
if self.neuron_config.flash_decoding_enabled
else self.n_positions
)
# It is either for context encoding or for token generation
if is_for_context_encoding:
past_key_values = None
else:
if kv_cache is None:
past_key_values = self.kv_mgr.get_cache(cache_size)
else:
past_key_values = self._slice_kv_cache(kv_cache, cache_size)
# Prepare attention mask(s)
attention_mask = self.create_attn_mask(
attention_mask,
is_for_context_encoding,
is_for_speculation,
)
active_mask = None
if is_for_speculation:
active_mask = torch.full(
(self.speculation_length, self.speculation_length),
True,
device=attention_mask.device,
).tril(diagonal=0)
active_mask = active_mask[None, None, :, :].expand(
self.batch_size, 1, self.speculation_length, self.speculation_length
)
# FD masks
active_mask_2d = None
if self.neuron_config.flash_decoding_enabled and not is_for_context_encoding:
rank_id = self.rank_util.get_rank()
active_mask_2d, attention_mask_2d = mask_util(
pos_ids=position_ids,
rank_id=rank_id,
num_cores_per_group=self.num_cores_per_group,
cache_size=cache_size,
)
active_mask = turn_2d_mask_to_4d(active_mask_2d, n_positions=1, batch_size=self.batch_size)
attention_mask = turn_2d_mask_to_4d(attention_mask_2d, n_positions=cache_size, batch_size=self.batch_size)
hidden_states, past_key_values = self.get_model_output(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
active_mask=active_mask,
inputs_embeds=inputs_embeds,
)
updated_kv_cache = self.kv_mgr.update_cache(
is_for_context_encoding=is_for_context_encoding,
seq_ids=seq_ids,
position_ids=position_ids,
new_key_values=past_key_values,
seq_len=cache_size,
scatter_index=scatter_index,
active_mask=active_mask_2d,
kvcache_buffer=kv_cache,
)
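        # Keep only the hidden state of the last valid token before the lm_head so that logits
        # are not computed for the whole padded sequence during prefill.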
batch_size, num_tokens, hidden_size = hidden_states.shape
if self.padding_side == "left":
index = torch.tensor([num_tokens - 1], device=hidden_states.device)
index = index.unsqueeze(1).expand(batch_size, 1, hidden_size)
hidden_states = torch.gather(hidden_states, dim=1, index=index)
else:
if not (position_ids.shape[-1] == self.speculation_length or position_ids.shape[-1] == 1):
# context encoding
index = torch.max(position_ids, dim=1, keepdim=True).indices
index = index.unsqueeze(1).expand(batch_size, 1, hidden_size)
hidden_states = torch.gather(hidden_states, dim=1, index=index)
logits = self.lm_head(hidden_states)
logits = logits.float()
if hasattr(self.lm_head, "pad_size"):
if self.lm_head.gather_output:
rank_id = torch.tensor(0, device=logits.device, dtype=torch.int32)
world_size = 1
else:
rank_id = self.rank_util.get_rank()
world_size = torch.distributed.get_world_size(group=self.lm_head.tensor_parallel_group)
logits = mask_padded_logits(logits, rank_id, world_size, pad_size=self.lm_head.pad_size)
res = logits
if self.neuron_config.on_device_sampling:
# perform sampling on Neuron to get tokens
            # FIXME: logits[:, -1, :] is not correct for the speculation model; this is a temporary fix.
if is_for_speculation:
res = nxd_argmax(tensor=logits, dim=2, gather_dim=2, keepdim=False)
else:
res = self.sampler(logits[:, -1, :], sampling_params)
outputs = [res]
if self.neuron_config.output_logits:
logits = _gather_along_dim(
logits,
partition_dim=2,
)
outputs += [logits]
outputs += updated_kv_cache
return outputs
def get_input_embeddings(self):
return self.embed_tokens
def set_input_embeddings(self, value):
self.embed_tokens = value
def get_model_output(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
active_mask: Optional[List[torch.FloatTensor]] = None,
        # In the llava context encoding model, inputs_embeds is precomputed
inputs_embeds: Optional[torch.FloatTensor] = None,
):
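        """
        Run the embedding layer, the decoder layers and the final norm.

        The RoPE cos/sin caches computed by the first decoder layer are reused by the
        following layers, and hidden states are reduce-scattered / gathered along the
        sequence dimension when sequence parallelism is enabled.
        """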
batch_size, seq_length = input_ids.shape[:2]
past_key_values_length = 0
if past_key_values is not None:
past_key_values_length = past_key_values[0][0].shape[2]
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device # noqa
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()
        # The NeuronLlamaModel class manages the KV cache, so the attention_mask is generated here and
        # passed through to LlamaModel. We override HF's attention mask generation because HF does not
        # support left-aligned RHS padding. This enables Neuron to achieve higher performance and
        # extensibility.
#
# 4d mask is passed through the layers
# attention_mask = _prepare_4d_causal_attention_mask(
# attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
# )
# embed positions
if self.sequence_parallel_enabled:
# TODO: Replace this with rankid + scatter call once supported
hidden_states = _reduce_scatter_along_dim(
inputs_embeds,
self.sequence_dimension,
xm.REDUCE_MAX,
)
else:
hidden_states = inputs_embeds
# decoder layers
next_decoder_cache = ()
cos_cache = None
sin_cache = None
for idx, decoder_layer in enumerate(self.layers):
past_key_value = past_key_values[idx] if past_key_values is not None else None
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
active_mask=active_mask,
cos_cache=cos_cache,
sin_cache=sin_cache,
)
hidden_states = layer_outputs[0]
next_decoder_cache += (layer_outputs[1],)
cos_cache, sin_cache = layer_outputs[2:]
hidden_states = self.norm(hidden_states)
if self.sequence_parallel_enabled:
hidden_states = gather_from_sequence_parallel_region(
hidden_states,
self.sequence_dimension,
)
return (hidden_states, next_decoder_cache)
class NxDModelForCausalLM(NxDGenerationMixin, NxDPreTrainedModel, NeuronModelForCausalLM):
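    """
    Causal LM running on Neuron devices through up to three traced decoder graphs:
    a context encoding (prefill) model, a token generation model and an optional
    speculation model. `forward()` selects the graph to run from the shape of
    `input_ids` and, in async mode, pipelines the next token generation request.

    Minimal usage sketch (illustrative only; assumes a concrete subclass with
    `_model_cls` set, a hypothetical checkpoint id and a valid `NxDNeuronConfig`):

        neuron_model = MyNxDModelForCausalLM.export(
            "my-org/my-model",
            config=config,
            neuron_config=neuron_config,
        )
        neuron_model.save_pretrained("my-model-neuron")
    """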
_model_cls = None
def __init__(
self,
config: PretrainedConfig,
neuron_config: NxDNeuronConfig,
traced_model: torch.jit.ScriptModule,
context_encoding_model: NxDDecoderWrapper,
token_generation_model: NxDDecoderWrapper = None,
speculation_model: NxDDecoderWrapper = None,
):
self.context_encoding_model = context_encoding_model
self.token_generation_model = token_generation_model
self.speculation_model = speculation_model
# Model wrappers are used by the parent class to assign weights to the model.
model_wrappers = [self.context_encoding_model]
if self.token_generation_model is not None:
model_wrappers.append(self.token_generation_model)
if self.speculation_model is not None:
model_wrappers.append(self.speculation_model)
super().__init__(
config=config, neuron_config=neuron_config, traced_model=traced_model, model_wrappers=model_wrappers
)
self.text_config = self.config.get_text_config()
self.vocab_size = self.text_config.vocab_size
self.padding_side = self.neuron_config.padding_side
self.kv_cache_populated = False
# async related
self.async_mode = self.neuron_config.async_mode
self.next_cpu_inputs = None
self.prior_outputs = None
self.unequal_batching = self.neuron_config.ctx_batch_size != self.neuron_config.tkg_batch_size
if self.async_mode:
os.environ["NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS"] = "2"
self.sampler = None
@staticmethod
def create_context_encoding_wrapper(model_cls, config, neuron_config, **model_init_kwargs):
new_neuron_config = copy.deepcopy(neuron_config)
new_neuron_config.batch_size = neuron_config.ctx_batch_size
new_neuron_config.n_active_tokens = neuron_config.max_context_length
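        # With bucketing enabled, compile several sequence-length buckets between 128 and the
        # maximum context length; otherwise compile a single bucket at the maximum context length.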
if new_neuron_config.enable_bucketing:
buckets = generate_buckets(128, new_neuron_config.max_context_length)
else:
buckets = generate_buckets(
new_neuron_config.max_context_length,
new_neuron_config.max_context_length,
)
return NxDDecoderWrapper(
config=config,
neuron_config=new_neuron_config,
buckets=buckets,
bucket_n_active_tokens=True,
model_cls=model_cls,
tag=CONTEXT_ENCODING_MODEL_TAG,
model_init_kwargs=model_init_kwargs,
)
@staticmethod
def create_token_generation_wrapper(
model_cls, config, neuron_config, enable_wlt_optimization: bool = True, **model_init_kwargs
):
new_neuron_config = copy.deepcopy(neuron_config)
new_neuron_config.batch_size = neuron_config.tkg_batch_size
new_neuron_config.n_active_tokens = 1
        # sequence parallel shouldn't be used in token gen models
        new_neuron_config.sequence_parallel_enabled = False
if new_neuron_config.enable_bucketing:
buckets = generate_buckets(128, neuron_config.sequence_length)
else:
buckets = generate_buckets(neuron_config.sequence_length, neuron_config.sequence_length)
return NxDDecoderWrapper(
config=config,
neuron_config=new_neuron_config,
buckets=buckets,
bucket_n_active_tokens=False,
model_cls=model_cls,
tag=TOKEN_GENERATION_MODEL_TAG,
priority_model_idx=0 if enable_wlt_optimization else None, # to turn on weight layout optimization
model_init_kwargs=model_init_kwargs,
)
@staticmethod
def create_speculation_wrapper(model_cls, config, neuron_config, **model_init_kwargs):
new_neuron_config = copy.deepcopy(neuron_config)
new_neuron_config.batch_size = neuron_config.tkg_batch_size
new_neuron_config.n_active_tokens = neuron_config.speculation_length
new_neuron_config.sequence_parallel_enabled = False
if new_neuron_config.enable_bucketing:
buckets = generate_buckets(128, neuron_config.sequence_length)
else:
buckets = generate_buckets(neuron_config.sequence_length, neuron_config.sequence_length)
return NxDDecoderWrapper(
config=config,
neuron_config=new_neuron_config,
buckets=buckets,
bucket_n_active_tokens=False,
model_cls=model_cls,
tag=SPECULATION_MODEL_TAG,
priority_model_idx=0, # to turn on weight layout optimization
model_init_kwargs=model_init_kwargs,
)
@staticmethod
def create_model_wrappers(model_cls, config, neuron_config, **model_init_kwargs):
context_encoding_model = NxDModelForCausalLM.create_context_encoding_wrapper(
model_cls,
config,
neuron_config,
**model_init_kwargs,
)
token_generation_model = NxDModelForCausalLM.create_token_generation_wrapper(
model_cls,
config,
neuron_config,
**model_init_kwargs,
)
speculation_model = (
NxDModelForCausalLM.create_speculation_wrapper(
model_cls,
config,
neuron_config,
**model_init_kwargs,
)
if neuron_config.speculation_length > 0
else None
)
return context_encoding_model, token_generation_model, speculation_model
def forward(
self,
input_ids: torch.LongTensor,
position_ids: Optional[torch.LongTensor],
seq_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
sampling_params: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, CausalLMOutputWithPast]:
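        """
        Host-side forward pass.

        Infers the attention mask and sequence ids when they are not provided, validates or
        builds the sampling parameters, and delegates to the traced context encoding, token
        generation or speculation model. In async mode, the inputs of the next token
        generation step are prepared ahead of time so the next request can be submitted
        before the current one completes.
        """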
if self.async_mode:
# derive future cpu inputs from current cpu inputs
if position_ids.shape[1] == input_ids.shape[1]:
next_position_ids = torch.amax(position_ids, 1, keepdim=True)
else:
next_position_ids = position_ids
next_position_ids = next_position_ids + 1
next_attention_mask = self._infer_attention_mask(next_position_ids)
self.next_cpu_inputs = {
"attention_mask": next_attention_mask,
"position_ids": next_position_ids,
}
# infer attention_mask from position_ids if not provided
if attention_mask is None:
attention_mask = self._infer_attention_mask(position_ids)
if seq_ids is None:
seq_ids = torch.arange(input_ids.shape[0])
if sampling_params is None:
if self.neuron_config.on_device_sampling:
raise ValueError("The sampling params tensor is required for on-device sampling.")
# Just pass a dummy tensor to the model, it will be ignored
sampling_params = prepare_sampling_params(seq_ids.shape[0])
elif self.neuron_config.on_device_sampling:
validate_sampling_params(sampling_params, self.neuron_config.max_topk)
output_attentions, output_hidden_states, return_dict = self._setup_func_config(
output_attentions, output_hidden_states, return_dict
)
logits_or_next_tokens = self._get_model_outputs(
input_ids,
attention_mask,
position_ids,
seq_ids,
sampling_params,
)
logging.debug("---output---")
logging.debug(
f"{'tokens' if self.neuron_config.on_device_sampling else 'logits'} = %s, ",
logits_or_next_tokens,
)
return self._construct_output(logits_or_next_tokens)
def _setup_func_config(self, output_attentions, output_hidden_states, return_dict):
output_attentions = output_attentions if output_attentions is not None else self.text_config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.text_config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else getattr(self.config, "use_return_dict", None)
return output_attentions, output_hidden_states, return_dict
def _infer_attention_mask(self, position_ids):
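        """Derive an attention mask from position ids, marking every slot up to each
        sequence's current position as valid."""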
assert position_ids is not None, "need to call forward with position_ids if attention_mask is not provided"
batch_size, seq_len = position_ids.shape
if position_ids.shape[-1] == 1:
seq_len = self.neuron_config.sequence_length
position_ids_to_compare = position_ids.expand(batch_size, seq_len) - 1
else:
seq_len = position_ids.shape[-1]
position_ids_to_compare = position_ids
mask = torch.arange(seq_len).view(1, -1).expand(batch_size, seq_len)
attention_mask = (position_ids_to_compare >= mask).to(dtype=position_ids.dtype)
return attention_mask
def _get_async_output(
self,
ranked_async_tensor,
):
outputs = [[async_tensor[0].cpu()] for async_tensor in ranked_async_tensor]
return outputs[0][0]
def _get_model_outputs(
self,
input_ids,
attention_mask,
position_ids,
seq_ids,
sampling_params,
):
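        """
        Route the request to the right traced model.

        Prefill requests (more than one token, starting at position 0) go to the context
        encoding model, draft-length requests go to the speculation model, and everything
        else goes to the token generation model. In async mode, the output of the previous
        request is returned while the next one is submitted, using `prior_outputs` and
        `next_cpu_inputs` to pipeline execution.
        """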
# casting inputs to int32
input_ids = input_ids.to(torch.int32)
attention_mask = attention_mask.to(torch.int32)
position_ids = position_ids.to(torch.int32)
seq_ids = seq_ids.to(torch.int32)
if input_ids.shape[-1] > 1 and not position_ids.min().item():
outputs = self.context_encoding_model(
input_ids,
attention_mask,
position_ids,
seq_ids,
sampling_params,
)
self.kv_cache_populated = True
if self.async_mode:
if not self.unequal_batching:
# for now only cte + tkg flow is supported with async (this will be enforced at config level)
next_outputs = self.token_generation_model(
outputs,
self.next_cpu_inputs["attention_mask"],
self.next_cpu_inputs["position_ids"],
seq_ids,
sampling_params,
)
outputs = self._get_async_output(outputs) # block on cte call
self.prior_outputs = next_outputs
else:
if isinstance(
outputs, list
): # in case the outputs weren't passed through `torch.cat` in model_wrapper.py
outputs = self._get_async_output(outputs) # block on cte call
self.prior_outputs = None
elif input_ids.shape[-1] == self.neuron_config.speculation_length:
outputs = self.speculation_model(
input_ids,
attention_mask,
position_ids,
seq_ids,
sampling_params,
)
else:
if (
self.next_cpu_inputs is not None and self.prior_outputs is not None
            ):  # next_cpu_inputs and prior_outputs are only set in async mode
_input_ids = self.prior_outputs
_attention_mask = self.next_cpu_inputs["attention_mask"]
_position_ids = self.next_cpu_inputs["position_ids"]
else:
_input_ids = input_ids
_attention_mask = attention_mask
_position_ids = position_ids
next_outputs = self.token_generation_model(
_input_ids,
_attention_mask,
_position_ids,
seq_ids,
sampling_params,
)
if self.async_mode:
                if self.prior_outputs is None:  # this means that next_outputs is processing the token to be returned
self.prior_outputs = next_outputs
next_outputs = self.token_generation_model( # submit future token request
next_outputs,
self.next_cpu_inputs["attention_mask"],
self.next_cpu_inputs["position_ids"],
seq_ids,
sampling_params,
)
outputs = self.prior_outputs
if isinstance(outputs, list):
outputs = self._get_async_output(
self.prior_outputs
) # block on prior (sometimes current) token gen request
self.prior_outputs = next_outputs
else:
outputs = next_outputs
return outputs
def _construct_output(self, logits_or_next_tokens):
next_tokens = logits_or_next_tokens
OutputParams = CausalLMOutputWithPast(
logits=None if self.neuron_config.on_device_sampling else logits_or_next_tokens,
hidden_states=logits_or_next_tokens,
attentions=None,
)
OutputParams.tokens = next_tokens
return OutputParams
def reset(self):
# We need to reset the KV cache flag for a new batch of inference.
# When the flag is reset, the subsequent run will invoke the
# context encoding model.
self.kv_cache_populated = False
def get_required_kwargs(self) -> List[str]:
"""The list of required kwargs to the model's forward"""
return []
@classmethod
def get_compiler_args(cls, neuron_config: NxDNeuronConfig) -> str:
tensorizer_options = (
"--enable-ccop-compute-overlap "
f"--cc-pipeline-tiling-factor={neuron_config.cc_pipeline_tiling_factor} "
"--vectorize-dge-dma "
"--vectorize-strided-dma "
)
compiler_args = (
"--auto-cast=none --model-type=transformer "
f"--tensorizer-options='{tensorizer_options}'"
" -O2 "
f" --internal-num-neuroncores-per-sengine={neuron_config.logical_nc_config}"
)
if neuron_config.target:
compiler_args += f" --target {neuron_config.target}"
logging.info(f"neuronx-cc compiler_args are: {compiler_args}")
return compiler_args
# NeuronModelForCausalLM methods
@classmethod
def _from_pretrained(
cls,
model_id: Union[str, "Path"],
config: "PretrainedConfig",
revision: Optional[str] = None,
token: Optional[Union[bool, str]] = None,
cache_dir: Optional[str] = None,
force_download: Optional[bool] = False,
subfolder: Optional[str] = "",
local_files_only: Optional[bool] = False,
trust_remote_code: Optional[bool] = False,
**kwargs,
) -> "NeuronModelForCausalLM":
if len(kwargs) > 0:
logger.warning("Ignoring the following kwargs as they are not supported by neuron: %s", kwargs.keys())
neuron_config = cls.get_neuron_config_cls().from_pretrained(model_id)
context_encoding_model, token_generation_model, speculation_model = cls.create_model_wrappers(
model_cls=cls._model_cls,
config=config,
neuron_config=neuron_config,
)
if not os.path.exists(model_id):
# The model_id is a model hub id: download the model from the hub.
with TemporaryDirectory() as tmpdir:
snapshot_download(
repo_id=model_id,
revision=revision,
cache_dir=cache_dir,
local_dir=tmpdir,
force_download=force_download,
local_files_only=local_files_only,
token=token,
allow_patterns=[cls.COMPILED_MODEL_FILE_NAME],
)
traced_model = torch.jit.load(os.path.join(tmpdir, cls.COMPILED_MODEL_FILE_NAME))
else:
traced_model = torch.jit.load(os.path.join(model_id, cls.COMPILED_MODEL_FILE_NAME))
model = cls(
config=config,
neuron_config=neuron_config,
traced_model=traced_model,
context_encoding_model=context_encoding_model,
token_generation_model=token_generation_model,
speculation_model=speculation_model,
)
model.load_weights(
model_id,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
token=token,
)
return model
@classmethod
def export(
cls,
model_id: str,
config: "PretrainedConfig",
neuron_config: "NxDNeuronConfig",
token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
cache_dir: Optional[str] = None,
force_download: Optional[bool] = False,
subfolder: Optional[str] = "",
local_files_only: Optional[bool] = False,
trust_remote_code: Optional[bool] = False,
load_weights: bool = True,
**kwargs,
) -> "NeuronModelForCausalLM":
if len(kwargs) > 0:
logger.warning("Ignoring the following kwargs as they are not supported by neuron: %s", kwargs.keys())
config = AutoConfig.from_pretrained(
model_id,
token=token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
trust_remote_code=trust_remote_code,
)
# Override torch_dtype in config as it is used by the neuronx_distributed code to cast weights to the correct type
config.torch_dtype = neuron_config.torch_dtype
context_encoding_model, token_generation_model, speculation_model = cls.create_model_wrappers(
model_cls=cls._model_cls,
config=config,
neuron_config=neuron_config,
)
model_wrappers = []
for wrapper in context_encoding_model, token_generation_model, speculation_model:
if wrapper is not None:
model_wrappers.append(wrapper)
# The model NEFF files will be cached locally, but if the model_id corresponds
# to a hub model, we also create a cache entry for it.
cache_entry = (
None
if os.path.exists(model_id)
else SingleModelCacheEntry(model_id, task="text-generation", config=config, neuron_config=neuron_config)
)
with hub_neuronx_cache(entry=cache_entry):
traced_model = NxDPreTrainedModel.compile(
neuron_config=neuron_config,
model_wrappers=model_wrappers,
compiler_args=cls.get_compiler_args(neuron_config),
)
model = cls(
config=config,
neuron_config=neuron_config,
traced_model=traced_model,
context_encoding_model=context_encoding_model,
token_generation_model=token_generation_model,
speculation_model=speculation_model,
)
if load_weights:
model.load_weights(
model_id,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only,
token=token,
)
return model
def _save_pretrained(self, save_directory: Union[str, Path]):
model_name_or_path = getattr(self.config, "_name_or_path")
        # If the model was exported from a local path, we need to save the checkpoint (note that we also shard it)
weight_path = model_name_or_path if os.path.isdir(model_name_or_path) else None
self.save(save_directory, weight_path=weight_path)
def push_to_hub(
self,
save_directory: str,
repository_id: str,
private: Optional[bool] = None,
revision: Optional[str] = None,
token: Union[bool, str] = True,
endpoint: Optional[str] = None,
) -> str:
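        """
        Upload an exported model folder to the hub, skipping the sharded checkpoint
        directory when the original weights are already available on the hub.
        """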
api = HfApi(endpoint=endpoint)
api.create_repo(
token=token,
repo_id=repository_id,
exist_ok=True,
private=private,
)
ignore_patterns = []
checkpoint_id = self.neuron_config.checkpoint_id
if checkpoint_id is not None:
# Avoid uploading checkpoints when the original model is available on the hub
ignore_patterns = [self.CHECKPOINT_DIR + "/*"]
api.upload_folder(
repo_id=repository_id,
folder_path=save_directory,
token=token,
revision=revision,
ignore_patterns=ignore_patterns,
)