in src/hyperpod_nemo_adapter/collections/parts/fsdp_strategy.py [0:0]
def optimizer_state(self, optimizer: torch.optim.Optimizer) -> Dict[str, torch.Tensor]:
    """
    Return the optimizer state dict in the layout required by the active checkpoint type.

    The layout is selected from ``self.checkpoint_io.checkpoint_type``:

    * ``LOCAL``                     -> ``self.local_optimizer_state_dict(optimizer)``
    * ``SHARDED`` / ``PEFT_SHARDED`` -> ``self.sharded_optimizer_state_dict(optimizer)``
    * ``FULL`` / ``PEFT_FULL``       -> ``self.full_optimizer_state_dict(optimizer)``

    Args:
        optimizer: The optimizer whose state is collected.

    Returns:
        Whatever the selected ``*_optimizer_state_dict`` helper returns.

    Raises:
        AssertionError: If ``self.checkpoint_io`` is not a ``SageMakerCheckpointIO``
            (only when assertions are enabled, i.e. not under ``python -O``).
        NotImplementedError: If the checkpoint type matches none of the types above.
    """
    # Narrows checkpoint_io before reading checkpoint_type from it.
    assert isinstance(self.checkpoint_io, SageMakerCheckpointIO)
    ckpt_type = self.checkpoint_io.checkpoint_type

    if ckpt_type == SageMakerCheckpointType.LOCAL:
        return self.local_optimizer_state_dict(optimizer)

    sharded_types = (SageMakerCheckpointType.SHARDED, SageMakerCheckpointType.PEFT_SHARDED)
    if ckpt_type in sharded_types:
        return self.sharded_optimizer_state_dict(optimizer)

    # PEFT_FULL uses the same full (unsharded) optimizer state as FULL.
    full_types = (SageMakerCheckpointType.FULL, SageMakerCheckpointType.PEFT_FULL)
    if ckpt_type in full_types:
        return self.full_optimizer_state_dict(optimizer)

    raise NotImplementedError(f"Checkpoint type '{ckpt_type}' not implemented")