def _load_weights_from_path()

in optimum/neuron/models/inference/backend/pretrained_model.py [0:0]


    def _load_weights_from_path(self, weights_path) -> None:
        """Loads the model weights to the Neuron device.

        If a pre-sharded checkpoint file exists for the first local rank, one
        sharded checkpoint per local rank is loaded from the shards directory.
        Otherwise the weights are sharded on the fly, rank by rank, using a
        builder created from the model configuration. The per-rank weights are
        then handed to the traced model for initialization.

        Args:
            weights_path: Location of the model weights. It is passed through
                ``normalize_path`` before being used.

        Raises:
            ValueError: If the traced model has not been loaded yet.
        """
        weights_path = normalize_path(weights_path)

        if self._traced_model is None:
            raise ValueError("Model is not loaded")

        start_rank_id = self.neuron_config.start_rank_id
        local_ranks_size = self.neuron_config.local_ranks_size
        # Ranks handled by this process; shared by both loading strategies below.
        local_ranks = range(start_rank_id, start_rank_id + local_ranks_size)

        logger.info(f"loading models for ranks {start_rank_id}...{start_rank_id + local_ranks_size - 1}")
        shards_path = get_shards_path(weights_path)

        def get_shard_name(rank):
            return os.path.join(shards_path, f"tp{rank}_sharded_checkpoint.safetensors")

        # Only the first local rank's shard is probed to choose the strategy.
        if os.path.exists(get_shard_name(start_rank_id)):
            # If sharded checkpoints exist, load them
            logger.info(f"Loading sharded checkpoint from {shards_path}")
            weights = [load_file(get_shard_name(rank)) for rank in local_ranks]
        else:
            logger.info("There are no saved sharded checkpoints.")
            checkpoint_loader = partial(self.checkpoint_loader_fn, weights_path, self.config, self.neuron_config)
            sharder = get_builder(
                self.neuron_config,
                self.model_wrappers,
                debug=False,
                checkpoint_loader=checkpoint_loader,
                compiler_args=self.get_compiler_args(self.neuron_config),
            )
            # The first model registered in the collection is used as the
            # source for sharding the weights of every rank.
            source_model_key = list(sharder.model_collection.keys())[0]
            weights = []
            for rank in local_ranks:
                logger.info(f"Sharding and loading rank {rank}")
                ckpt = sharder.shard_weights(rank, sharder.model_collection[source_model_key])
                weights.append(ckpt)

        start_rank_tensor = torch.tensor([start_rank_id], dtype=torch.int32, device="cpu")
        self._traced_model.nxd_model.initialize(weights, start_rank_tensor)