in optimum/neuron/models/inference/backend/pretrained_model.py [0:0]
def _load_weights_from_path(self, weights_path):
    """Load the model weights onto the Neuron device.

    If pre-sharded safetensors checkpoints exist under the shards directory
    for ``weights_path``, they are loaded directly; otherwise the weights are
    sharded on the fly for each local rank using the builder returned by
    ``get_builder``. The resulting per-rank weight dicts are handed to the
    traced model's ``nxd_model.initialize``.

    Args:
        weights_path: Path to the model weights (normalized before use).

    Raises:
        ValueError: If the traced model has not been loaded yet.
    """
    # Note: the docstring must be the first statement of the function;
    # normalization happens after it so it remains a real docstring.
    weights_path = normalize_path(weights_path)
    if self._traced_model is None:
        raise ValueError("Model is not loaded")

    start_rank_id = self.neuron_config.start_rank_id
    local_ranks_size = self.neuron_config.local_ranks_size
    # Use the module logger consistently (was logging.info, bypassing `logger`).
    logger.info(f"loading models for ranks {start_rank_id}...{start_rank_id + local_ranks_size - 1}")

    weights = []
    shards_path = get_shards_path(weights_path)

    def get_shard_name(rank):
        # Per-rank sharded checkpoint file name convention.
        return os.path.join(shards_path, f"tp{rank}_sharded_checkpoint.safetensors")

    if os.path.exists(get_shard_name(start_rank_id)):
        # If sharded checkpoints exist, load them directly — no resharding needed.
        logger.info(f"Loading sharded checkpoint from {shards_path}")
        for rank in range(start_rank_id, start_rank_id + local_ranks_size):
            ckpt = load_file(get_shard_name(rank))
            weights.append(ckpt)
    else:
        # No saved shards: shard the full checkpoint on the fly for each local rank.
        logger.info("There are no saved sharded checkpoints.")
        checkpoint_loader = partial(self.checkpoint_loader_fn, weights_path, self.config, self.neuron_config)
        sharder = get_builder(
            self.neuron_config,
            self.model_wrappers,
            debug=False,
            checkpoint_loader=checkpoint_loader,
            compiler_args=self.get_compiler_args(self.neuron_config),
        )
        # All wrappers share the same weights; shard from the first model entry.
        # next(iter(...)) avoids building a throwaway list of all keys.
        source_model_key = next(iter(sharder.model_collection))
        for rank in range(start_rank_id, start_rank_id + local_ranks_size):
            logger.info(f"Sharding and loading rank {rank}")
            ckpt = sharder.shard_weights(rank, sharder.model_collection[source_model_key])
            weights.append(ckpt)

    # initialize expects the starting rank as a CPU int32 tensor.
    start_rank_tensor = torch.tensor([start_rank_id], dtype=torch.int32, device="cpu")
    self._traced_model.nxd_model.initialize(weights, start_rank_tensor)