def load_sharded_optim_state_dict(self, trainer, checkpoint, path)

in src/hyperpod_nemo_adapter/collections/parts/fsdp_strategy.py


    def load_sharded_optim_state_dict(self, trainer, checkpoint, path):
        typ = self.checkpoint_io.checkpoint_type
        # A PEFT_SHARDED checkpoint does not contain the model state_dict.
        # Use self.sharded_model_state_dict instead; the adapter weights have
        # already been loaded into the model by this point.
        if typ == SageMakerCheckpointType.PEFT_SHARDED:
            checkpoint_state_dict = self.sharded_model_state_dict
        else:
            checkpoint_state_dict = checkpoint["state_dict"]
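        # Each optimizer's state was saved under its own key
        # ("{OPTIMIZER_KEY_PREFIX}_{index}"); restore them one by one.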
        for i, optimizer in enumerate(trainer.optimizers):
            optimizer_key = f"{OPTIMIZER_KEY_PREFIX}_{i}"
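            # Read this rank's shard of the optimizer state from disk; the model
            # state_dict supplies the sharding metadata used to place the values.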
            state_dict = load_sharded_optimizer_state_dict(
                model_state_dict=checkpoint_state_dict,
                optimizer_key=optimizer_key,
                storage_reader=DistributedFileSystemReader(path),
                process_group=self.pytorch_model.process_group,
            )
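            # Convert the loaded state into the flattened form that matches the
            # FSDP-wrapped model's parameters before handing it to the optimizer.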
            flattened_osd = FSDP.optim_state_dict_to_load(
                model=self.pytorch_model, optim=optimizer, optim_state_dict=state_dict[optimizer_key]
            )
            optimizer.load_state_dict(flattened_osd)
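
For reference, the same load pattern can be written against the public
torch.distributed.checkpoint APIs. The sketch below is a hypothetical
standalone helper, not the adapter's code: restore_sharded_optimizer, model,
optimizer, and ckpt_dir are placeholder names; torch's built-in
FileSystemReader stands in for DistributedFileSystemReader; and the
process_group argument is omitted because the stock torch
load_sharded_optimizer_state_dict does not take one.

    import torch.distributed.checkpoint as dist_cp
    from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import StateDictType


    def restore_sharded_optimizer(model, optimizer, ckpt_dir, optimizer_key="optimizer_0"):
        # Hypothetical helper: `model` is assumed to be an FSDP-wrapped module and
        # `ckpt_dir` a directory written by torch.distributed.checkpoint (DCP).
        with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
            model_state = model.state_dict()

        # Read this rank's shard of the optimizer state; the model state_dict
        # supplies the sharding metadata used to place the loaded values.
        loaded = load_sharded_optimizer_state_dict(
            model_state_dict=model_state,
            optimizer_key=optimizer_key,
            storage_reader=dist_cp.FileSystemReader(ckpt_dir),
        )

        # Re-flatten to match FSDP's flat-parameter layout, then hand the
        # result to the optimizer.
        flattened = FSDP.optim_state_dict_to_load(
            model=model, optim=optimizer, optim_state_dict=loaded[optimizer_key]
        )
        optimizer.load_state_dict(flattened)

The two-step shape mirrors the method above: DCP restores sharded, unflattened
optimizer tensors keyed by optimizer_key, and FSDP.optim_state_dict_to_load
re-flattens them into the layout the wrapped model's optimizer actually holds.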