in src/hyperpod_nemo_adapter/collections/parts/fsdp_strategy.py
def lightning_module_state_dict(self) -> Dict[str, Any]:
    """
    Return the model state dict in full, sharded, or local format,
    depending on the configured checkpoint type.
    """
    assert isinstance(self.checkpoint_io, SageMakerCheckpointIO)
    typ = self.checkpoint_io.checkpoint_type
    if typ == SageMakerCheckpointType.LOCAL:
        return self.local_model_state_dict
    if typ == SageMakerCheckpointType.SHARDED:
        return self.sharded_model_state_dict
    if typ == SageMakerCheckpointType.FULL:
        return self.full_model_state_dict
    if typ == SageMakerCheckpointType.PEFT_FULL:
        return self.full_model_state_dict
    # For PEFT_SHARDED, we do not need to store the model state_dict as the adapter weights
    # are stored separately
    if typ == SageMakerCheckpointType.PEFT_SHARDED:
        return None
    raise NotImplementedError(f"Checkpoint type '{typ}' not implemented")
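
The dispatch above is driven entirely by checkpoint_io.checkpoint_type, so a caller selects the format before asking for the state dict. A minimal usage sketch under that assumption follows; the strategy variable and the surrounding trainer setup are hypothetical and not part of this file.

    # Hypothetical sketch: pick the checkpoint format, then fetch the matching
    # state dict. `strategy` is assumed to be an FSDP strategy instance whose
    # checkpoint_io is a SageMakerCheckpointIO.
    strategy.checkpoint_io.checkpoint_type = SageMakerCheckpointType.SHARDED
    state_dict = strategy.lightning_module_state_dict()  # sharded format

    # PEFT_SHARDED yields None here because adapter weights are saved separately.
    strategy.checkpoint_io.checkpoint_type = SageMakerCheckpointType.PEFT_SHARDED
    assert strategy.lightning_module_state_dict() is None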