optimum/neuron/models/inference/backend/pretrained_model.py (234 lines of code) (raw):
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import logging
import os
from functools import partial
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import List, Optional, Union
import neuronx_distributed.trace.hlo_utils as hlo_utils
import torch
from huggingface_hub import snapshot_download
from neuronx_distributed.trace.model_builder import ModelBuilder
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, PretrainedConfig
from .cache import neff_cache
from .config import NxDNeuronConfig
from .model_wrapper import NxDModelWrapper
from .modules.checkpoint import (
load_state_dict,
)
from .modules.flashdecode.utils import calculate_num_cores_per_group
# Module-level logger shared by the helpers and classes below; it is registered under the fixed name "Neuron".
logger = logging.getLogger("Neuron")
def normalize_path(path):
    """Return ``path`` in normalized form, terminated by a path separator.

    The path is first collapsed with ``os.path.normpath`` (redundant separators
    and ``..`` segments are resolved), then joined with an empty component so
    that the result always ends with a trailing separator.
    """
    return os.path.join(os.path.normpath(path), "")
def get_shards_path(dest_path):
    """Return the path of the ``weights`` sub-directory located under ``dest_path``."""
    shards_dir = os.path.join(dest_path, "weights")
    return shards_dir
def get_builder(
    neuron_config: NxDNeuronConfig,
    model_wrappers: List[NxDModelWrapper],
    debug: bool = False,
    checkpoint_loader=None,
    compiler_args: str = None,
):
    """Build a ModelBuilder populated with the given model wrappers.

    A ModelBuilder is instantiated from the parallelism and layout settings of
    ``neuron_config``, then every wrapper in ``model_wrappers`` is registered on
    it under the wrapper's tag. `NxDPreTrainedModel` calls this helper each time
    a model is compiled or has its weights sharded/loaded; the returned builder
    is typically discarded after use to save memory.

    The compiler working directory is read from the ``BASE_COMPILE_WORK_DIR``
    environment variable and defaults to ``/tmp/nxd_model/``.

    Args:
        neuron_config (NxDNeuronConfig): The Neuron configuration.
        model_wrappers (List[NxDModelWrapper]): The model wrappers to register on the builder.
        debug (bool): Whether to enable debug mode on the builder.
        checkpoint_loader (callable): A function to load the model's state dictionary and weights.
        compiler_args (str): Compiler arguments forwarded for every registered model.

    Returns:
        ModelBuilder: The populated ModelBuilder instance.
    """
    workdir = os.getenv("BASE_COMPILE_WORK_DIR", "/tmp/nxd_model/")
    # Collect the constructor arguments first so the builder creation stays a one-liner.
    builder_kwargs = {
        "router": None,
        "tp_degree": neuron_config.tp_degree,
        "pp_degree": neuron_config.pp_degree,
        "ep_degree": neuron_config.ep_degree,
        "world_size": neuron_config.world_size,
        "start_rank_id": neuron_config.start_rank_id,
        "local_ranks_size": neuron_config.local_ranks_size,
        "checkpoint_loader": checkpoint_loader,
        "compiler_workdir": workdir,
        "debug": debug,
        "num_cores_per_group": neuron_config.num_cores_per_group,
        "logical_nc_config": neuron_config.logical_nc_config,
        "weights_to_skip_layout_optimization": neuron_config.weights_to_skip_layout_optimization,
    }
    model_builder = ModelBuilder(**builder_kwargs)
    # Register each wrapper under its own tag, all sharing the same compiler arguments.
    for wrapper in model_wrappers:
        model_builder.add(
            key=wrapper.tag,
            model_instance=wrapper.get_model_instance(),
            example_inputs=wrapper.input_generator(),
            compiler_args=compiler_args,
            bucket_config=wrapper.get_bucket_config(),
            priority_model_idx=wrapper.priority_model_idx,
        )
    return model_builder
class NxDPreTrainedModel:
    """Base class pairing a traced Neuron model with its configuration and model wrappers.

    It provides compilation, serialization, checkpoint sharding and weight
    loading. Model-specific pieces (``forward``, the config classes, the
    state dict conversion and the tied-weights update) are left to subclasses.
    """

    # Prefix looked for on checkpoint parameter names, and the prefix that replaces it.
    _STATE_DICT_MODEL_PREFIX = "model."
    _NEW_STATE_DICT_MODEL_PREFIX = ""
    # When not empty, every parameter name is finally re-keyed as "<_FUSED_PREFIX>.<name>".
    _FUSED_PREFIX = ""
    # File name of the serialized traced model inside a saved model directory.
    COMPILED_MODEL_FILE_NAME = "model.pt"
    # Sub-directory of a saved model directory that receives the sharded checkpoint.
    CHECKPOINT_DIR = "checkpoint"

    def __init__(
        self,
        config: PretrainedConfig,
        neuron_config: NxDNeuronConfig,
        traced_model: torch.jit.ScriptModule,
        model_wrappers: List[NxDModelWrapper],
    ):
        """Store private copies of the configurations and bind the traced model to the wrappers.

        Args:
            config (PretrainedConfig): The model configuration (deep-copied).
            neuron_config (NxDNeuronConfig): The Neuron configuration (deep-copied).
            traced_model (torch.jit.ScriptModule): The compiled model; may be None until compiled or loaded.
            model_wrappers (List[NxDModelWrapper]): The wrappers whose ``model`` attribute is set
                to ``traced_model``.
        """
        self.config = copy.deepcopy(config)
        self.neuron_config = copy.deepcopy(neuron_config)
        # Override torch_dtype in config as it is used by the neuronx_distributed code to cast weights to the correct type
        self.config.torch_dtype = self.neuron_config.torch_dtype
        if neuron_config.flash_decoding_enabled:
            # FIXME: this should not be part of neuron_config but is used in downstream classes
            # Could it be deduced from tensor shapes ?
            self.neuron_config.num_cores_per_group = calculate_num_cores_per_group(
                config.num_attention_heads, config.num_key_value_heads, neuron_config.tp_degree
            )
        self._traced_model = traced_model
        self.model_wrappers = model_wrappers  # Required for loading weights
        for model_wrapper in self.model_wrappers:
            model_wrapper.model = self._traced_model

    def forward(self, **kwargs):
        """Forward pass for this model. Must be implemented by subclasses."""
        raise NotImplementedError("forward is not implemented")

    @classmethod
    def get_config_cls(cls) -> PretrainedConfig:
        """Gets the config class for this model. Must be implemented by subclasses."""
        raise NotImplementedError("get_config_cls is not implemented")

    @classmethod
    def get_neuron_config_cls(cls) -> NxDNeuronConfig:
        """Gets the Neuron config class for this model. Must be implemented by subclasses."""
        raise NotImplementedError("get_neuron_config_cls is not implemented")

    @classmethod
    def get_compiler_args(cls, neuron_config) -> str:
        """Gets the Neuron compiler arguments to use when compiling this model.

        The base implementation returns None.
        """
        return None

    @staticmethod
    def compile(neuron_config, model_wrappers: List[NxDModelWrapper], compiler_args: str, debug: bool = False):
        """Trace the given model wrappers without initializing the model weights.

        Args:
            neuron_config: The Neuron configuration used to create the builder.
            model_wrappers (List[NxDModelWrapper]): The wrappers to trace.
            compiler_args (str): Compiler arguments forwarded to the builder.
            debug (bool): Whether to enable debug mode on the builder.

        Returns:
            The result of ``ModelBuilder.trace``, invoked inside the ``neff_cache`` context.
        """
        builder = get_builder(neuron_config, model_wrappers, debug=debug, compiler_args=compiler_args)
        with neff_cache():
            return builder.trace(initialize_model_weights=False)

    def save(self, dest_path, weight_path: Optional[str] = None):
        """Save the configurations and the traced model under ``dest_path``.

        Args:
            dest_path: The destination directory.
            weight_path (Optional[str]): When provided, the weights found at this path are
                sharded into the ``CHECKPOINT_DIR`` sub-directory of ``dest_path``.

        Raises:
            ValueError: If there is no traced model to save.
        """
        if self._traced_model is None:
            raise ValueError("Model has not been compiled or loaded")
        dest_path = normalize_path(dest_path)
        self.config.save_pretrained(dest_path)
        self.neuron_config.save_pretrained(dest_path)
        torch.jit.save(self._traced_model, os.path.join(dest_path, self.COMPILED_MODEL_FILE_NAME))
        if weight_path is not None:
            self.shard_checkpoint(
                src_path=weight_path,
                dest_path=os.path.join(dest_path, self.CHECKPOINT_DIR),
            )

    def shard_checkpoint(self, src_path, dest_path, debug: bool = False):
        """Shard the checkpoint found at ``src_path`` and serialize the shards under ``dest_path``.

        Args:
            src_path: The path the checkpoint loader reads the weights from.
            dest_path: The directory whose ``weights`` sub-directory receives the shards.
            debug (bool): Whether to enable debug mode on the builder.
        """
        shards_path = get_shards_path(dest_path)
        checkpoint_loader = partial(self.checkpoint_loader_fn, src_path, self.config, self.neuron_config)
        sharder = get_builder(
            self.neuron_config,
            self.model_wrappers,
            debug=debug,
            checkpoint_loader=checkpoint_loader,
            compiler_args=self.get_compiler_args(self.neuron_config),
        )
        sharder.shard_checkpoint(serialize_path=shards_path)
        # The layout transformation is only applied when the corresponding environment variable is set.
        if hlo_utils.NXD_LAYOUT_TRANSFORMATION_OPTIONS in os.environ:
            sharder.transform_weight_layout_with_overriden_option(sharded_checkpoint_dir=shards_path)

    def _load_weights_from_path(self, weights_path):
        """Loads the model weights to the Neuron device.

        If a sharded checkpoint file exists for the first local rank under the
        ``weights`` sub-directory of ``weights_path``, one shard per local rank is
        read from disk. Otherwise the weights are loaded from ``weights_path`` by the
        checkpoint loader and sharded in memory, one rank at a time.

        Args:
            weights_path: The directory containing the weights.

        Raises:
            ValueError: If there is no traced model to initialize.
        """
        weights_path = normalize_path(weights_path)
        if self._traced_model is None:
            raise ValueError("Model is not loaded")

        start_rank_id = self.neuron_config.start_rank_id
        local_ranks_size = self.neuron_config.local_ranks_size

        # Use the module logger (not the root logger) like the rest of this file.
        logger.info(f"loading models for ranks {start_rank_id}...{start_rank_id + local_ranks_size - 1}")
        weights = []
        shards_path = get_shards_path(weights_path)

        def get_shard_name(rank):
            return os.path.join(shards_path, f"tp{rank}_sharded_checkpoint.safetensors")

        if os.path.exists(get_shard_name(start_rank_id)):
            # If sharded checkpoints exist, load them
            logger.info(f"Loading sharded checkpoint from {shards_path}")
            for rank in range(start_rank_id, start_rank_id + local_ranks_size):
                ckpt = load_file(get_shard_name(rank))
                weights.append(ckpt)
        else:
            logger.info("There are no saved sharded checkpoints.")
            checkpoint_loader = partial(self.checkpoint_loader_fn, weights_path, self.config, self.neuron_config)
            sharder = get_builder(
                self.neuron_config,
                self.model_wrappers,
                debug=False,
                checkpoint_loader=checkpoint_loader,
                compiler_args=self.get_compiler_args(self.neuron_config),
            )
            # The first registered model is used as the source for sharding the weights.
            source_model_key = list(sharder.model_collection.keys())[0]
            for rank in range(start_rank_id, start_rank_id + local_ranks_size):
                logger.info(f"Sharding and loading rank {rank}")
                ckpt = sharder.shard_weights(rank, sharder.model_collection[source_model_key])
                weights.append(ckpt)
        start_rank_tensor = torch.tensor([start_rank_id], dtype=torch.int32, device="cpu")
        self._traced_model.nxd_model.initialize(weights, start_rank_tensor)

    def load_weights(
        self,
        model_name_or_path: Union[str, Path],
        token: Optional[Union[bool, str]] = None,
        cache_dir: Optional[str] = None,
        force_download: Optional[bool] = False,
        local_files_only: Optional[bool] = False,
    ) -> None:
        """Loads the model weights from the given path.

        The lookup order is:

        1. the ``CHECKPOINT_DIR`` sub-directory of ``model_name_or_path``, if it exists,
        2. ``model_name_or_path`` itself, if it exists and loading does not raise
           ``FileNotFoundError``,
        3. the safetensors files of ``neuron_config.checkpoint_id``, downloaded into a
           temporary directory.

        Args:
            model_name_or_path (Union[str, Path]): The local path to look for weights.
            token, cache_dir, force_download, local_files_only: Forwarded to ``snapshot_download``.

        Raises:
            ValueError: If no local weights were found and ``neuron_config.checkpoint_id`` is None.
        """
        if os.path.exists(model_name_or_path):
            # Look first for pre-sharded weights
            checkpoint_path = os.path.join(model_name_or_path, self.CHECKPOINT_DIR)
            if os.path.exists(checkpoint_path):
                self._load_weights_from_path(checkpoint_path)
                return
            # Fall-back to standard model weights, if any
            try:
                self._load_weights_from_path(model_name_or_path)
                return
            except FileNotFoundError:
                logger.info(f"Checkpoint file not found in {model_name_or_path}, trying to load from HuggingFace Hub.")
        if self.neuron_config.checkpoint_id is None:
            raise ValueError(f"Checkpoint file not found under {model_name_or_path}.")
        # Fetch weights from the checkpoint. The context manager guarantees that the
        # temporary directory is removed even if the download or the loading fails.
        with TemporaryDirectory() as checkpoint_dir:
            os.chmod(checkpoint_dir, 0o775)
            snapshot_download(
                repo_id=self.neuron_config.checkpoint_id,
                revision=self.neuron_config.checkpoint_revision,
                cache_dir=cache_dir,
                force_download=force_download,
                local_files_only=local_files_only,
                token=token,
                local_dir=checkpoint_dir,
                allow_patterns=["*.safetensors*"],
            )
            self._load_weights_from_path(checkpoint_dir)

    def checkpoint_loader_fn(self, checkpoint_path, config, neuron_config):
        """This function loads the model's state dictionary and weights from the hf model

        When ``neuron_config.torch_dtype`` is not ``torch.float32``, every floating point
        tensor whose dtype differs from it is converted to that dtype.

        Returns:
            dict: The state dict returned by ``get_state_dict``, after the dtype conversion.
        """
        model_sd = self.get_state_dict(checkpoint_path, config, neuron_config)
        if neuron_config.torch_dtype != torch.float32:
            # Only existing keys are re-assigned, so the dict size does not change while iterating.
            for name, param in model_sd.items():
                if torch.is_floating_point(param) and param.dtype is not neuron_config.torch_dtype:
                    logger.debug(f"Converting {name} to {neuron_config.torch_dtype}")
                    model_sd[name] = param.to(neuron_config.torch_dtype)
        return model_sd

    @classmethod
    def get_state_dict(cls, model_path: str, config: PretrainedConfig, neuron_config: NxDNeuronConfig) -> dict:
        """Gets the state dict for this model.

        The state dict loaded from ``model_path`` goes through the following steps:

        1. the leading ``_STATE_DICT_MODEL_PREFIX`` of parameter names is replaced by
           ``_NEW_STATE_DICT_MODEL_PREFIX``,
        2. ``convert_hf_to_neuron_state_dict`` is applied,
        3. ``update_state_dict_for_tied_weights`` is applied if ``config.tie_word_embeddings`` is truthy,
        4. if ``_FUSED_PREFIX`` is not empty, every name is re-keyed as ``<_FUSED_PREFIX>.<name>``.
        """
        model_sd = load_state_dict(model_path)
        for param_name in list(model_sd.keys()):
            if param_name.startswith(cls._STATE_DICT_MODEL_PREFIX):
                updated_param_name = param_name.replace(
                    cls._STATE_DICT_MODEL_PREFIX, cls._NEW_STATE_DICT_MODEL_PREFIX, 1
                )
                # Pop before assigning, so a parameter is kept when the new name equals the old one.
                model_sd[updated_param_name] = model_sd.pop(param_name)
        model_sd = cls.convert_hf_to_neuron_state_dict(model_sd, config, neuron_config)
        if getattr(config, "tie_word_embeddings", False):
            cls.update_state_dict_for_tied_weights(model_sd)

        if cls._FUSED_PREFIX != "":
            for param_name in list(model_sd.keys()):
                model_sd[f"{cls._FUSED_PREFIX}.{param_name}"] = model_sd.pop(param_name)
        return model_sd

    @staticmethod
    def convert_hf_to_neuron_state_dict(
        state_dict: dict, config: PretrainedConfig, neuron_config: Optional[NxDNeuronConfig] = None
    ) -> dict:
        """This function should be over-ridden in child classes as needed

        ``get_state_dict`` calls it with three positional arguments, so the base
        implementation accepts ``neuron_config`` too (optional, to remain compatible with
        two-argument calls). The base implementation returns ``state_dict`` unchanged.
        """
        return state_dict

    @staticmethod
    def load_hf_model(model_path):
        """Loads the HuggingFace model from the given checkpoint path."""
        return AutoModelForCausalLM.from_pretrained(model_path)

    @staticmethod
    def update_state_dict_for_tied_weights(state_dict):
        """Implement state_dict update for each model class with tied weights"""
        raise NotImplementedError("State-dict update not implemented")

    @property
    def device(self) -> torch.device:
        """
        `torch.device`: Always the CPU device, whatever the actual location of the weights.
        """
        # We dont want HF to move parameters to device
        return torch.device("cpu")

    def reset(self):
        """Resets the model state. Can be implemented by subclasses."""
        pass