in src/hyperpod_nemo_adapter/collections/parts/fsdp_strategy.py [0:0]
def load_sharded_optim_state_dict(self, trainer, checkpoint, path):
    typ = self.checkpoint_io.checkpoint_type
    # For PEFT_SHARDED, the checkpoint does not contain the model state_dict.
    # Use the current sharded_model_state_dict instead, since the checkpoint's
    # adapter weights have already been loaded into the model by this point.
    if typ == SageMakerCheckpointType.PEFT_SHARDED:
        checkpoint_state_dict = self.sharded_model_state_dict
    else:
        checkpoint_state_dict = checkpoint["state_dict"]
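    # Restore each optimizer's sharded state from the checkpoint directory,
    # keyed as f"{OPTIMIZER_KEY_PREFIX}_{i}".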
    for i, optimizer in enumerate(trainer.optimizers):
        optimizer_key = f"{OPTIMIZER_KEY_PREFIX}_{i}"
        state_dict = load_sharded_optimizer_state_dict(
            model_state_dict=checkpoint_state_dict,
            optimizer_key=optimizer_key,
            storage_reader=DistributedFileSystemReader(path),
            process_group=self.pytorch_model.process_group,
        )
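        # Convert the loaded optimizer state into the flat-parameter layout
        # expected by this rank's FSDP-wrapped model before loading it.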
        flattened_osd = FSDP.optim_state_dict_to_load(
            model=self.pytorch_model, optim=optimizer, optim_state_dict=state_dict[optimizer_key]
        )
        optimizer.load_state_dict(flattened_osd)
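

# --- Illustrative save-side sketch (not part of the adapter) ---
# A minimal example of how a sharded checkpoint that the loader above can read
# might be produced with the public torch.distributed.checkpoint APIs. This is
# a sketch under assumptions: the adapter uses its own storage writer and key
# layout; "model", "optimizer", "path", and the default optimizer_key are
# placeholders, and the key must match f"{OPTIMIZER_KEY_PREFIX}_{i}" used above.
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint import FileSystemWriter
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, StateDictType


def save_sharded_optim_state_dict(model, optimizer, path, optimizer_key="optimizer_0"):
    # With SHARDED_STATE_DICT, each rank contributes only its own shards.
    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        state_dict = {
            "state_dict": model.state_dict(),
            optimizer_key: FSDP.optim_state_dict(model, optimizer),
        }
    # dcp.save requires a recent PyTorch; older releases expose the same
    # behavior as dcp.save_state_dict.
    dcp.save(state_dict, storage_writer=FileSystemWriter(path))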