optimum/neuron/models/inference/backend/modules/decoder/modeling_decoder.py (759 lines of code) (raw):

# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Optional, Tuple, Union

import neuronx_distributed as nxd
import torch
import torch_xla.core.xla_model as xm
from huggingface_hub import HfApi, snapshot_download
from neuronx_distributed.operators.argmax import argmax as nxd_argmax
from neuronx_distributed.parallel_layers.layers import SPMDRank
from neuronx_distributed.parallel_layers.mappings import (
    _gather_along_dim,
    _reduce_scatter_along_dim,
    gather_from_sequence_parallel_region,
)
from torch import nn
from transformers import AutoConfig, PretrainedConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

from ......cache.entries.single_model import SingleModelCacheEntry
from ......cache.hub_cache import hub_neuronx_cache
from ......modeling_decoder import NeuronModelForCausalLM
from ...config import NxDNeuronConfig
from ...pretrained_model import NxDPreTrainedModel
from ...utils.random import set_random_seed
from ..attention import utils as attn_utils
from ..autobucketing import generate_buckets
from ..flashdecode.utils import (
    get_cache_size,
    mask_util,
    turn_2d_mask_to_4d,
)
from ..generation.generation_utils import NxDGenerationMixin
from ..generation.sampling import (
    Sampler,
    mask_padded_logits,
    prepare_sampling_params,
    validate_sampling_params,
)
from ..kvcache.kv_cache_manager import (
    KVCacheManager,
    _slice_kv_cacheline,
)
from .decoder_wrapper import (
    CONTEXT_ENCODING_MODEL_TAG,
    SPECULATION_MODEL_TAG,
    TOKEN_GENERATION_MODEL_TAG,
    NxDDecoderWrapper,
)


logger = logging.getLogger("Neuron")


class NxDDecoderModel(nn.Module):
    """
    Base decoder model that NeuronXXXModel classes inherit from.

    The forward() function will be traced and compiled by NxD.

    This class reads ``self.embed_tokens``, ``self.layers``, ``self.norm`` and
    ``self.lm_head`` but never creates them: subclasses are expected to define them.
    """

    def __init__(self, config: PretrainedConfig, neuron_config: NxDNeuronConfig):
        super().__init__()

        self.config = config
        self.sampler = None
        self.kv_mgr = None
        self.neuron_config = neuron_config
        self.batch_size = neuron_config.batch_size
        self.n_positions = neuron_config.sequence_length
        self.vocab_size = config.vocab_size
        self.speculation_length = neuron_config.speculation_length
        self.padding_side = neuron_config.padding_side
        self.max_length = neuron_config.sequence_length
        self.sequence_parallel_enabled = neuron_config.sequence_parallel_enabled
        # Hidden states are scattered/gathered along the sequence axis when sequence parallelism is on.
        self.sequence_dimension = 1 if self.sequence_parallel_enabled else None
        self.rank_util = SPMDRank(world_size=neuron_config.tp_degree)
        self.num_cores_per_group = neuron_config.num_cores_per_group
        if neuron_config.on_device_sampling:
            # Instantiate a multinomial Sampler (it can still be used for greedy by passing topk=1)
            self.sampler = Sampler(neuron_config, do_sample=True)
        self.kv_mgr = KVCacheManager(config, neuron_config, num_kv_head=config.num_key_value_heads)

    def initialize_process_group(self, seed: int = 0):
        """Initialize torch.distributed (XLA backend) and NxD model parallelism, then seed RNGs.

        Each initialization step is skipped (with a warning) when it was already performed.

        Args:
            seed: value passed to ``set_random_seed`` once the process groups are ready.
        """
        # NOTE: this must be `torch.distributed`; `torch.dist` is the p-norm distance function.
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="xla")
        else:
            logger.warning("torch.distributed was already initialized, skipping...")

        if not nxd.parallel_layers.parallel_state.model_parallel_is_initialized():
            nxd.parallel_layers.initialize_model_parallel(
                tensor_model_parallel_size=self.neuron_config.tp_degree,
                pipeline_model_parallel_size=self.neuron_config.pp_degree,
                expert_model_parallel_size=self.neuron_config.ep_degree,
            )
        else:
            logger.warning("NxD was already initialized, skipping...")

        # set seed
        set_random_seed(seed)

    def _is_context_encoding(self, input_ids: torch.Tensor):
        """Return True when more than one token is passed and it is not a speculation step."""
        return input_ids.shape[-1] > 1 and input_ids.shape[-1] != self.speculation_length

    def _is_for_speculation(self, input_ids: torch.Tensor):
        """Return True when the number of input tokens equals the speculation length."""
        return input_ids.shape[-1] == self.speculation_length

    def _create_context_attn_mask(self, attention_mask, **kwargs):
        """Build the context-encoding attention mask.

        Returns a boolean mask expanded to (batch_size, 1, n_positions, n_positions).
        With chunked prefill, a block-diagonal mask is built from ``kwargs`` instead.
        """
        # Block diagonal causal mask for chunked prefill
        if self.neuron_config.is_chunked_prefill:
            return self._create_chunked_prefill_attn_mask(**kwargs)

        # Lower triangle causal mask for classic attention
        mask = torch.full((self.n_positions, self.n_positions), True, device=attention_mask.device).tril(diagonal=0)
        mask = mask[None, None, :, :].expand(self.batch_size, 1, self.n_positions, self.n_positions)

        if self.padding_side == "right":
            return mask

        # Left padding: additionally mask out the padded key positions given by attention_mask.
        expanded_mask = (
            attention_mask[:, None, None, :]
            .expand(self.batch_size, 1, self.n_positions, self.n_positions)
            .to(torch.bool)
        )
        return torch.logical_and(mask, expanded_mask)

    def _create_chunked_prefill_attn_mask(
        self,
        query_lens: torch.Tensor,
        key_lens: torch.Tensor,
        max_query_len: int,
        max_key_len: int,
        **kwargs,
    ) -> torch.Tensor:
        """Delegate block-diagonal mask creation to ``attn_utils.create_block_diagonal_attn_mask``."""
        return attn_utils.create_block_diagonal_attn_mask(query_lens, key_lens, max_query_len, max_key_len, **kwargs)

    def _create_spec_attn_mask(self, attention_mask):
        """Expand ``attention_mask`` to (batch_size, 1, speculation_length, n_positions) booleans."""
        return (
            attention_mask[:, None, None, :]
            .expand(self.batch_size, 1, self.speculation_length, self.n_positions)
            .to(torch.bool)
        )

    def _create_simple_attn_mask(self, attention_mask):
        """Expand ``attention_mask`` to (batch_size, 1, 1, n_positions) booleans for token generation."""
        return attention_mask[:, None, None, :].expand(self.batch_size, 1, 1, self.n_positions).to(torch.bool)

    def create_attn_mask(self, attention_mask, is_for_context_encoding, is_for_speculation, **kwargs):
        """Dispatch to the mask builder matching the current step (context / speculation / token gen)."""
        if is_for_context_encoding:
            return self._create_context_attn_mask(attention_mask, **kwargs)
        elif is_for_speculation:
            return self._create_spec_attn_mask(attention_mask)
        else:
            return self._create_simple_attn_mask(attention_mask)

    def _slice_kv_cache(self, kv_cache, n_positions):
        """Slice every layer's (key, value) pair of ``kv_cache`` down to ``n_positions``.

        Returns a list with one ``[k_cache, v_cache]`` entry per layer.
        """
        past_key_values = []
        for idx in range(len(kv_cache)):
            k_cache = _slice_kv_cacheline(self.neuron_config.padding_side, n_positions, kv_cache[idx][0])
            v_cache = _slice_kv_cacheline(self.neuron_config.padding_side, n_positions, kv_cache[idx][1])
            past_key_values.append([k_cache, v_cache])
        return past_key_values

    def _is_reorder_needed(self, is_for_context_encoding, is_for_speculation):
        """Return True only for token generation steps with continuous batching enabled."""
        return not is_for_context_encoding and not is_for_speculation and self.neuron_config.continuous_batching

    def forward(
        self,
        input_ids,
        attention_mask,
        position_ids,
        seq_ids,
        sampling_params,
        scatter_index=None,
        # In llava context encoding model, input_embeds is precomputed
        inputs_embeds: Optional[torch.FloatTensor] = None,
        kv_cache: Optional[torch.Tensor] = None,
    ):
        """Run one traced decoding step (context encoding, speculation or token generation).

        The step kind is derived from the number of input tokens.

        Returns:
            A list whose elements are, in order:
            - the sampled tokens (on-device sampling) or the logits;
            - the gathered logits, only when ``neuron_config.output_logits`` is set;
            - the updated KV cache entries returned by ``kv_mgr.update_cache``.
        """
        is_for_context_encoding = self._is_context_encoding(input_ids)
        is_for_speculation = self._is_for_speculation(input_ids)

        cache_size = (
            get_cache_size(self.n_positions, self.num_cores_per_group, is_for_context_encoding)
            if self.neuron_config.flash_decoding_enabled
            else self.n_positions
        )

        # It is either for context encoding or for token generation
        if is_for_context_encoding:
            past_key_values = None
        elif kv_cache is None:
            past_key_values = self.kv_mgr.get_cache(cache_size)
        else:
            past_key_values = self._slice_kv_cache(kv_cache, cache_size)

        # Prepare attention mask(s)
        attention_mask = self.create_attn_mask(
            attention_mask,
            is_for_context_encoding,
            is_for_speculation,
        )
        active_mask = None
        if is_for_speculation:
            # Causal mask among the speculated tokens themselves.
            active_mask = torch.full(
                (self.speculation_length, self.speculation_length),
                True,
                device=attention_mask.device,
            ).tril(diagonal=0)
            active_mask = active_mask[None, None, :, :].expand(
                self.batch_size, 1, self.speculation_length, self.speculation_length
            )

        # Flash-decoding masks: computed per rank and replace the masks built above.
        active_mask_2d = None
        if self.neuron_config.flash_decoding_enabled and not is_for_context_encoding:
            rank_id = self.rank_util.get_rank()
            active_mask_2d, attention_mask_2d = mask_util(
                pos_ids=position_ids,
                rank_id=rank_id,
                num_cores_per_group=self.num_cores_per_group,
                cache_size=cache_size,
            )
            active_mask = turn_2d_mask_to_4d(active_mask_2d, n_positions=1, batch_size=self.batch_size)
            attention_mask = turn_2d_mask_to_4d(attention_mask_2d, n_positions=cache_size, batch_size=self.batch_size)

        hidden_states, past_key_values = self.get_model_output(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            active_mask=active_mask,
            inputs_embeds=inputs_embeds,
        )

        updated_kv_cache = self.kv_mgr.update_cache(
            is_for_context_encoding=is_for_context_encoding,
            seq_ids=seq_ids,
            position_ids=position_ids,
            new_key_values=past_key_values,
            seq_len=cache_size,
            scatter_index=scatter_index,
            active_mask=active_mask_2d,
            kvcache_buffer=kv_cache,
        )

        # Keep only the hidden state of the last real token of each sequence before the LM head.
        batch_size, num_tokens, hidden_size = hidden_states.shape
        if self.padding_side == "left":
            # With left padding the last real token is always the last position.
            index = torch.tensor([num_tokens - 1], device=hidden_states.device)
            index = index.unsqueeze(1).expand(batch_size, 1, hidden_size)
            hidden_states = torch.gather(hidden_states, dim=1, index=index)
        elif not (position_ids.shape[-1] == self.speculation_length or position_ids.shape[-1] == 1):
            # Right-padded context encoding: the last real token has the largest position id.
            index = torch.max(position_ids, dim=1, keepdim=True).indices
            index = index.unsqueeze(1).expand(batch_size, 1, hidden_size)
            hidden_states = torch.gather(hidden_states, dim=1, index=index)

        logits = self.lm_head(hidden_states)
        logits = logits.float()

        if hasattr(self.lm_head, "pad_size"):
            # Mask vocabulary padding added to make the LM head shardable.
            if self.lm_head.gather_output:
                rank_id = torch.tensor(0, device=logits.device, dtype=torch.int32)
                world_size = 1
            else:
                rank_id = self.rank_util.get_rank()
                world_size = torch.distributed.get_world_size(group=self.lm_head.tensor_parallel_group)
            logits = mask_padded_logits(logits, rank_id, world_size, pad_size=self.lm_head.pad_size)

        res = logits
        if self.neuron_config.on_device_sampling:
            # perform sampling on Neuron to get tokens
            # FIXME, logits[:, -1, :] is not correct for speculation model, this is a temporary fix.
            if is_for_speculation:
                res = nxd_argmax(tensor=logits, dim=2, gather_dim=2, keepdim=False)
            else:
                res = self.sampler(logits[:, -1, :], sampling_params)

        outputs = [res]
        if self.neuron_config.output_logits:
            logits = _gather_along_dim(
                logits,
                partition_dim=2,
            )
            outputs += [logits]
        outputs += updated_kv_cache

        return outputs

    def get_input_embeddings(self):
        """Return the token embedding module (``self.embed_tokens``, defined by subclasses)."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module."""
        self.embed_tokens = value

    def get_model_output(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        active_mask: Optional[torch.Tensor] = None,
        # In llava context encoding model, input_embeds is precomputed
        inputs_embeds: Optional[torch.FloatTensor] = None,
    ):
        """Run the embedding, all decoder layers and the final norm.

        Each decoder layer must return exactly ``(hidden_states, present_key_value, cos_cache, sin_cache)``;
        the rotary caches produced by one layer are passed to the next.

        Returns:
            ``(hidden_states, next_decoder_cache)`` where ``next_decoder_cache`` is a tuple holding
            one ``present_key_value`` per layer.
        """
        seq_length = input_ids.shape[1]

        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device  # noqa
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        # The KV cache (and therefore the attention mask) is managed by this class: the 4D mask built
        # in forward() is passed as-is to the layers. HF's mask preparation is intentionally not used
        # because it does not support left aligned RHS padding; bypassing it lets Neuron achieve higher
        # performance and extensibility.

        # embed positions
        if self.sequence_parallel_enabled:
            # TODO: Replace this with rankid + scatter call once supported
            hidden_states = _reduce_scatter_along_dim(
                inputs_embeds,
                self.sequence_dimension,
                xm.REDUCE_MAX,
            )
        else:
            hidden_states = inputs_embeds

        # decoder layers
        next_decoder_cache = ()
        cos_cache = None
        sin_cache = None
        for idx, decoder_layer in enumerate(self.layers):
            past_key_value = past_key_values[idx] if past_key_values is not None else None

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                active_mask=active_mask,
                cos_cache=cos_cache,
                sin_cache=sin_cache,
            )

            hidden_states = layer_outputs[0]
            next_decoder_cache += (layer_outputs[1],)
            cos_cache, sin_cache = layer_outputs[2:]

        hidden_states = self.norm(hidden_states)

        if self.sequence_parallel_enabled:
            hidden_states = gather_from_sequence_parallel_region(
                hidden_states,
                self.sequence_dimension,
            )

        return (hidden_states, next_decoder_cache)


class NxDModelForCausalLM(NxDGenerationMixin, NxDPreTrainedModel, NeuronModelForCausalLM):
    """Host-side causal LM that routes requests to the compiled context-encoding,
    token-generation and (optional) speculation models.
    """

    # Decoder model class (an NxDDecoderModel subclass) used to build the wrappers; set by subclasses.
    _model_cls = None

    def __init__(
        self,
        config: PretrainedConfig,
        neuron_config: NxDNeuronConfig,
        traced_model: torch.jit.ScriptModule,
        context_encoding_model: NxDDecoderWrapper,
        token_generation_model: Optional[NxDDecoderWrapper] = None,
        speculation_model: Optional[NxDDecoderWrapper] = None,
    ):
        self.context_encoding_model = context_encoding_model
        self.token_generation_model = token_generation_model
        self.speculation_model = speculation_model
        # Model wrappers are used by the parent class to assign weights to the model.
        model_wrappers = [self.context_encoding_model]
        if self.token_generation_model is not None:
            model_wrappers.append(self.token_generation_model)
        if self.speculation_model is not None:
            model_wrappers.append(self.speculation_model)
        super().__init__(
            config=config, neuron_config=neuron_config, traced_model=traced_model, model_wrappers=model_wrappers
        )

        self.text_config = self.config.get_text_config()
        self.vocab_size = self.text_config.vocab_size
        self.padding_side = self.neuron_config.padding_side
        self.kv_cache_populated = False

        # async related
        self.async_mode = self.neuron_config.async_mode
        # CPU-side inputs of the next token-generation request (only set in async mode).
        self.next_cpu_inputs = None
        # Outputs of the request submitted ahead of time (only set in async mode).
        self.prior_outputs = None
        self.unequal_batching = self.neuron_config.ctx_batch_size != self.neuron_config.tkg_batch_size
        if self.async_mode:
            os.environ["NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS"] = "2"

        self.sampler = None

    @staticmethod
    def create_context_encoding_wrapper(model_cls, config, neuron_config, **model_init_kwargs):
        """Create the context-encoding wrapper (batch ``ctx_batch_size``, up to ``max_context_length`` tokens)."""
        new_neuron_config = copy.deepcopy(neuron_config)
        new_neuron_config.batch_size = neuron_config.ctx_batch_size
        new_neuron_config.n_active_tokens = neuron_config.max_context_length

        if new_neuron_config.enable_bucketing:
            buckets = generate_buckets(128, new_neuron_config.max_context_length)
        else:
            buckets = generate_buckets(
                new_neuron_config.max_context_length,
                new_neuron_config.max_context_length,
            )

        return NxDDecoderWrapper(
            config=config,
            neuron_config=new_neuron_config,
            buckets=buckets,
            bucket_n_active_tokens=True,
            model_cls=model_cls,
            tag=CONTEXT_ENCODING_MODEL_TAG,
            model_init_kwargs=model_init_kwargs,
        )

    @staticmethod
    def create_token_generation_wrapper(
        model_cls, config, neuron_config, enable_wlt_optimization: bool = True, **model_init_kwargs
    ):
        """Create the token-generation wrapper (batch ``tkg_batch_size``, one active token).

        Args:
            enable_wlt_optimization: when True, mark this model as priority model 0 to turn on
                weight layout optimization.
        """
        new_neuron_config = copy.deepcopy(neuron_config)
        new_neuron_config.batch_size = neuron_config.tkg_batch_size
        new_neuron_config.n_active_tokens = 1
        # shouldn't be used in token gen models
        new_neuron_config.sequence_parallel_enabled = False

        if new_neuron_config.enable_bucketing:
            buckets = generate_buckets(128, neuron_config.sequence_length)
        else:
            buckets = generate_buckets(neuron_config.sequence_length, neuron_config.sequence_length)

        return NxDDecoderWrapper(
            config=config,
            neuron_config=new_neuron_config,
            buckets=buckets,
            bucket_n_active_tokens=False,
            model_cls=model_cls,
            tag=TOKEN_GENERATION_MODEL_TAG,
            priority_model_idx=0 if enable_wlt_optimization else None,  # to turn on weight layout optimization
            model_init_kwargs=model_init_kwargs,
        )

    @staticmethod
    def create_speculation_wrapper(model_cls, config, neuron_config, **model_init_kwargs):
        """Create the speculation wrapper (batch ``tkg_batch_size``, ``speculation_length`` active tokens)."""
        new_neuron_config = copy.deepcopy(neuron_config)
        new_neuron_config.batch_size = neuron_config.tkg_batch_size
        new_neuron_config.n_active_tokens = neuron_config.speculation_length
        new_neuron_config.sequence_parallel_enabled = False

        if new_neuron_config.enable_bucketing:
            buckets = generate_buckets(128, neuron_config.sequence_length)
        else:
            buckets = generate_buckets(neuron_config.sequence_length, neuron_config.sequence_length)

        return NxDDecoderWrapper(
            config=config,
            neuron_config=new_neuron_config,
            buckets=buckets,
            bucket_n_active_tokens=False,
            model_cls=model_cls,
            tag=SPECULATION_MODEL_TAG,
            priority_model_idx=0,  # to turn on weight layout optimization
            model_init_kwargs=model_init_kwargs,
        )

    @staticmethod
    def create_model_wrappers(model_cls, config, neuron_config, **model_init_kwargs):
        """Create all model wrappers.

        Returns:
            ``(context_encoding_model, token_generation_model, speculation_model)``; the speculation
            model is None unless ``neuron_config.speculation_length > 0``.
        """
        context_encoding_model = NxDModelForCausalLM.create_context_encoding_wrapper(
            model_cls,
            config,
            neuron_config,
            **model_init_kwargs,
        )
        token_generation_model = NxDModelForCausalLM.create_token_generation_wrapper(
            model_cls,
            config,
            neuron_config,
            **model_init_kwargs,
        )
        speculation_model = (
            NxDModelForCausalLM.create_speculation_wrapper(
                model_cls,
                config,
                neuron_config,
                **model_init_kwargs,
            )
            if neuron_config.speculation_length > 0
            else None
        )
        return context_encoding_model, token_generation_model, speculation_model

    def forward(
        self,
        input_ids: torch.LongTensor,
        position_ids: Optional[torch.LongTensor],
        seq_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sampling_params: Optional[torch.FloatTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run one generation step and wrap the result in a ``CausalLMOutputWithPast``.

        Missing inputs are filled in:
        - ``attention_mask`` is inferred from ``position_ids``;
        - ``seq_ids`` defaults to ``arange(batch_size)``;
        - ``sampling_params`` gets a dummy tensor when sampling happens on the host.

        Raises:
            ValueError: if on-device sampling is enabled and ``sampling_params`` is None.
        """
        if self.async_mode:
            # derive future cpu inputs from current cpu inputs
            if position_ids.shape[1] == input_ids.shape[1]:
                next_position_ids = torch.amax(position_ids, 1, keepdim=True)
            else:
                next_position_ids = position_ids
            next_position_ids = next_position_ids + 1
            next_attention_mask = self._infer_attention_mask(next_position_ids)
            self.next_cpu_inputs = {
                "attention_mask": next_attention_mask,
                "position_ids": next_position_ids,
            }

        # infer attention_mask from position_ids if not provided
        if attention_mask is None:
            attention_mask = self._infer_attention_mask(position_ids)

        if seq_ids is None:
            seq_ids = torch.arange(input_ids.shape[0])

        if sampling_params is None:
            if self.neuron_config.on_device_sampling:
                raise ValueError("The sampling params tensor is required for on-device sampling.")
            # Just pass a dummy tensor to the model, it will be ignored
            sampling_params = prepare_sampling_params(seq_ids.shape[0])
        elif self.neuron_config.on_device_sampling:
            validate_sampling_params(sampling_params, self.neuron_config.max_topk)

        output_attentions, output_hidden_states, return_dict = self._setup_func_config(
            output_attentions, output_hidden_states, return_dict
        )

        logits_or_next_tokens = self._get_model_outputs(
            input_ids,
            attention_mask,
            position_ids,
            seq_ids,
            sampling_params,
        )

        logger.debug("---output---")
        logger.debug(
            "%s = %s, ",
            "tokens" if self.neuron_config.on_device_sampling else "logits",
            logits_or_next_tokens,
        )

        return self._construct_output(logits_or_next_tokens)

    def _setup_func_config(self, output_attentions, output_hidden_states, return_dict):
        """Resolve None output flags from the text config (``return_dict`` from ``config.use_return_dict``)."""
        output_attentions = output_attentions if output_attentions is not None else self.text_config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.text_config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else getattr(self.config, "use_return_dict", None)
        return output_attentions, output_hidden_states, return_dict

    def _infer_attention_mask(self, position_ids):
        """Build an attention mask from ``position_ids`` (same dtype as ``position_ids``).

        - For a single position per row, the mask spans ``sequence_length`` columns, and column ``j``
          is set when ``j < position``.
        - Otherwise, the mask has the shape of ``position_ids``, and entry ``j`` is set when
          ``position_ids[:, j] >= j``.
        """
        assert position_ids is not None, "need to call forward with position_ids if attention_mask is not provided"
        batch_size, seq_len = position_ids.shape
        if position_ids.shape[-1] == 1:
            seq_len = self.neuron_config.sequence_length
            position_ids_to_compare = position_ids.expand(batch_size, seq_len) - 1
        else:
            seq_len = position_ids.shape[-1]
            position_ids_to_compare = position_ids
        mask = torch.arange(seq_len).view(1, -1).expand(batch_size, seq_len)
        attention_mask = (position_ids_to_compare >= mask).to(dtype=position_ids.dtype)
        return attention_mask

    def _get_async_output(
        self,
        ranked_async_tensor,
    ):
        """Copy the first output of every rank to CPU (blocking on all of them) and return rank 0's."""
        cpu_outputs = [async_tensor[0].cpu() for async_tensor in ranked_async_tensor]
        return cpu_outputs[0]

    def _get_model_outputs(
        self,
        input_ids,
        attention_mask,
        position_ids,
        seq_ids,
        sampling_params,
    ):
        """Route the inputs to the relevant compiled model and return its outputs.

        Routing rules:
        - context encoding: more than one token and the smallest position id is 0;
        - speculation: the number of tokens equals ``speculation_length``;
        - token generation: any other case.

        In async mode, a token-generation request is submitted one step ahead using
        ``self.next_cpu_inputs``. Its pending outputs are kept in ``self.prior_outputs``
        and are returned on the following call.
        """
        # casting inputs to int32
        input_ids = input_ids.to(torch.int32)
        attention_mask = attention_mask.to(torch.int32)
        position_ids = position_ids.to(torch.int32)
        seq_ids = seq_ids.to(torch.int32)

        if input_ids.shape[-1] > 1 and not position_ids.min().item():
            outputs = self.context_encoding_model(
                input_ids,
                attention_mask,
                position_ids,
                seq_ids,
                sampling_params,
            )

            self.kv_cache_populated = True
            if self.async_mode:
                if not self.unequal_batching:
                    # for now only cte + tkg flow is supported with async (this will be enforced at config level)
                    next_outputs = self.token_generation_model(
                        outputs,
                        self.next_cpu_inputs["attention_mask"],
                        self.next_cpu_inputs["position_ids"],
                        seq_ids,
                        sampling_params,
                    )
                    outputs = self._get_async_output(outputs)  # block on cte call
                    self.prior_outputs = next_outputs
                else:
                    # in case the outputs weren't passed through `torch.cat` in model_wrapper.py
                    if isinstance(outputs, list):
                        outputs = self._get_async_output(outputs)  # block on cte call

                    self.prior_outputs = None

        elif input_ids.shape[-1] == self.neuron_config.speculation_length:
            outputs = self.speculation_model(
                input_ids,
                attention_mask,
                position_ids,
                seq_ids,
                sampling_params,
            )
        else:
            # Both attributes are only ever set in async mode, so this branch is async-only.
            if self.next_cpu_inputs is not None and self.prior_outputs is not None:
                _input_ids = self.prior_outputs
                _attention_mask = self.next_cpu_inputs["attention_mask"]
                _position_ids = self.next_cpu_inputs["position_ids"]
            else:
                _input_ids = input_ids
                _attention_mask = attention_mask
                _position_ids = position_ids

            next_outputs = self.token_generation_model(
                _input_ids,
                _attention_mask,
                _position_ids,
                seq_ids,
                sampling_params,
            )
            if self.async_mode:
                if self.prior_outputs is None:
                    # this means that next_outputs is processing token to be returned
                    self.prior_outputs = next_outputs
                    next_outputs = self.token_generation_model(  # submit future token request
                        next_outputs,
                        self.next_cpu_inputs["attention_mask"],
                        self.next_cpu_inputs["position_ids"],
                        seq_ids,
                        sampling_params,
                    )
                outputs = self.prior_outputs
                if isinstance(outputs, list):
                    # block on prior (sometimes current) token gen request
                    outputs = self._get_async_output(self.prior_outputs)
                self.prior_outputs = next_outputs
            else:
                outputs = next_outputs

        return outputs

    def _construct_output(self, logits_or_next_tokens):
        """Wrap raw model outputs in a ``CausalLMOutputWithPast``.

        ``logits`` is None when sampling happens on device. The raw outputs are also stored in
        ``hidden_states`` and in an extra ``tokens`` attribute.
        """
        next_tokens = logits_or_next_tokens

        output = CausalLMOutputWithPast(
            logits=None if self.neuron_config.on_device_sampling else logits_or_next_tokens,
            hidden_states=logits_or_next_tokens,
            attentions=None,
        )

        output.tokens = next_tokens

        return output

    def reset(self):
        """Reset the KV cache flag for a new batch of inference.

        Once the flag is reset, the subsequent run will invoke the context encoding model.
        """
        self.kv_cache_populated = False

    def get_required_kwargs(self) -> List[str]:
        """The list of required kwargs to the model's forward"""
        return []

    @classmethod
    def get_compiler_args(cls, neuron_config: NxDNeuronConfig) -> str:
        """Return the neuronx-cc command-line arguments for this configuration."""
        tensorizer_options = (
            "--enable-ccop-compute-overlap "
            f"--cc-pipeline-tiling-factor={neuron_config.cc_pipeline_tiling_factor} "
            "--vectorize-dge-dma "
            "--vectorize-strided-dma "
        )

        compiler_args = (
            "--auto-cast=none --model-type=transformer "
            f"--tensorizer-options='{tensorizer_options}'"
            " -O2 "
            f" --internal-num-neuroncores-per-sengine={neuron_config.logical_nc_config}"
        )

        if neuron_config.target:
            compiler_args += f" --target {neuron_config.target}"

        logger.info("neuronx-cc compiler_args are: %s", compiler_args)
        return compiler_args

    # NeuronModelForCausalLM methods
    @classmethod
    def _from_pretrained(
        cls,
        model_id: Union[str, "Path"],
        config: "PretrainedConfig",
        revision: Optional[str] = None,
        token: Optional[Union[bool, str]] = None,
        cache_dir: Optional[str] = None,
        force_download: Optional[bool] = False,
        subfolder: Optional[str] = "",
        local_files_only: Optional[bool] = False,
        trust_remote_code: Optional[bool] = False,
        **kwargs,
    ) -> "NeuronModelForCausalLM":
        """Load an already-compiled model from a local directory or a hub repository, then its weights."""
        if len(kwargs) > 0:
            logger.warning("Ignoring the following kwargs as they are not supported by neuron: %s", kwargs.keys())
        neuron_config = cls.get_neuron_config_cls().from_pretrained(model_id)
        context_encoding_model, token_generation_model, speculation_model = cls.create_model_wrappers(
            model_cls=cls._model_cls,
            config=config,
            neuron_config=neuron_config,
        )
        if not os.path.exists(model_id):
            # The model_id is a model hub id: download the model from the hub.
            with TemporaryDirectory() as tmpdir:
                snapshot_download(
                    repo_id=model_id,
                    revision=revision,
                    cache_dir=cache_dir,
                    local_dir=tmpdir,
                    force_download=force_download,
                    local_files_only=local_files_only,
                    token=token,
                    allow_patterns=[cls.COMPILED_MODEL_FILE_NAME],
                )
                # Load while the temporary directory still exists.
                traced_model = torch.jit.load(os.path.join(tmpdir, cls.COMPILED_MODEL_FILE_NAME))
        else:
            traced_model = torch.jit.load(os.path.join(model_id, cls.COMPILED_MODEL_FILE_NAME))
        model = cls(
            config=config,
            neuron_config=neuron_config,
            traced_model=traced_model,
            context_encoding_model=context_encoding_model,
            token_generation_model=token_generation_model,
            speculation_model=speculation_model,
        )
        model.load_weights(
            model_id,
            cache_dir=cache_dir,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
        )
        return model

    @classmethod
    def export(
        cls,
        model_id: str,
        config: "PretrainedConfig",
        neuron_config: "NxDNeuronConfig",
        token: Optional[Union[bool, str]] = None,
        revision: Optional[str] = None,
        cache_dir: Optional[str] = None,
        force_download: Optional[bool] = False,
        subfolder: Optional[str] = "",
        local_files_only: Optional[bool] = False,
        trust_remote_code: Optional[bool] = False,
        load_weights: bool = True,
        **kwargs,
    ) -> "NeuronModelForCausalLM":
        """Compile ``model_id`` for Neuron and return the model instance.

        Weights are loaded only when ``load_weights`` is True. For hub model ids, compilation
        runs under a hub cache entry.
        """
        if len(kwargs) > 0:
            logger.warning("Ignoring the following kwargs as they are not supported by neuron: %s", kwargs.keys())
        config = AutoConfig.from_pretrained(
            model_id,
            token=token,
            revision=revision,
            cache_dir=cache_dir,
            force_download=force_download,
            trust_remote_code=trust_remote_code,
        )
        # Override torch_dtype in config as it is used by the neuronx_distributed code to cast weights to the correct type
        config.torch_dtype = neuron_config.torch_dtype
        context_encoding_model, token_generation_model, speculation_model = cls.create_model_wrappers(
            model_cls=cls._model_cls,
            config=config,
            neuron_config=neuron_config,
        )
        model_wrappers = [
            wrapper
            for wrapper in (context_encoding_model, token_generation_model, speculation_model)
            if wrapper is not None
        ]
        # The model NEFF files will be cached locally, but if the model_id corresponds
        # to a hub model, we also create a cache entry for it.
        cache_entry = (
            None
            if os.path.exists(model_id)
            else SingleModelCacheEntry(model_id, task="text-generation", config=config, neuron_config=neuron_config)
        )
        with hub_neuronx_cache(entry=cache_entry):
            traced_model = NxDPreTrainedModel.compile(
                neuron_config=neuron_config,
                model_wrappers=model_wrappers,
                compiler_args=cls.get_compiler_args(neuron_config),
            )
        model = cls(
            config=config,
            neuron_config=neuron_config,
            traced_model=traced_model,
            context_encoding_model=context_encoding_model,
            token_generation_model=token_generation_model,
            speculation_model=speculation_model,
        )
        if load_weights:
            model.load_weights(
                model_id,
                cache_dir=cache_dir,
                force_download=force_download,
                local_files_only=local_files_only,
                token=token,
            )
        return model

    def _save_pretrained(self, save_directory: Union[str, Path]):
        """Save the compiled model to ``save_directory``."""
        model_name_or_path = self.config._name_or_path
        # If the model was exported from a local path, we need to save the checkpoint (note that we also shard it)
        weight_path = model_name_or_path if os.path.isdir(model_name_or_path) else None
        self.save(save_directory, weight_path=weight_path)

    def push_to_hub(
        self,
        save_directory: str,
        repository_id: str,
        private: Optional[bool] = None,
        revision: Optional[str] = None,
        token: Union[bool, str] = True,
        endpoint: Optional[str] = None,
    ) -> str:
        """Upload ``save_directory`` to ``repository_id``, creating the repository if needed.

        Checkpoint files are skipped when ``neuron_config.checkpoint_id`` is set.

        Returns:
            The value returned by ``HfApi.upload_folder`` (commit information for the upload).
        """
        api = HfApi(endpoint=endpoint)

        api.create_repo(
            token=token,
            repo_id=repository_id,
            exist_ok=True,
            private=private,
        )
        ignore_patterns = []
        checkpoint_id = self.neuron_config.checkpoint_id
        if checkpoint_id is not None:
            # Avoid uploading checkpoints when the original model is available on the hub
            ignore_patterns = [self.CHECKPOINT_DIR + "/*"]
        return api.upload_folder(
            repo_id=repository_id,
            folder_path=save_directory,
            token=token,
            revision=revision,
            ignore_patterns=ignore_patterns,
        )