def lightning_module_state_dict()

in src/hyperpod_nemo_adapter/collections/parts/fsdp_strategy.py [0:0]


    def lightning_module_state_dict(self) -> Dict[str, Any]:
        """
        Return the model state dict in the format selected by the checkpoint IO.

        The format comes from ``self.checkpoint_io.checkpoint_type``:

        * ``LOCAL``                -> ``self.local_model_state_dict``
        * ``SHARDED``              -> ``self.sharded_model_state_dict``
        * ``FULL`` / ``PEFT_FULL`` -> ``self.full_model_state_dict``
        * ``PEFT_SHARDED``         -> ``None``; the adapter weights are stored
          separately, so no model state dict is produced here.

        Note: despite the ``Dict[str, Any]`` annotation, ``None`` is returned for
        ``PEFT_SHARDED``. NOTE(review): the annotation should probably be
        ``Optional[Dict[str, Any]]``.

        Raises:
            AssertionError: if ``self.checkpoint_io`` is not a ``SageMakerCheckpointIO``.
            NotImplementedError: for any checkpoint type not listed above.
        """
        assert isinstance(self.checkpoint_io, SageMakerCheckpointIO)
        ckpt_type = self.checkpoint_io.checkpoint_type

        # Store attribute *names* rather than values so that only the matching
        # state dict is ever evaluated (these may be computed properties).
        state_dict_sources = (
            (SageMakerCheckpointType.LOCAL, "local_model_state_dict"),
            (SageMakerCheckpointType.SHARDED, "sharded_model_state_dict"),
            (SageMakerCheckpointType.FULL, "full_model_state_dict"),
            (SageMakerCheckpointType.PEFT_FULL, "full_model_state_dict"),
        )
        for candidate, attr_name in state_dict_sources:
            if ckpt_type == candidate:
                return getattr(self, attr_name)

        # PEFT_SHARDED: adapter weights are persisted separately, so there is
        # no model state dict to hand back.
        if ckpt_type == SageMakerCheckpointType.PEFT_SHARDED:
            return None
        raise NotImplementedError(f"Checkpoint type '{ckpt_type}' not implemented")