def optimizer_state()

in src/hyperpod_nemo_adapter/collections/parts/fsdp_strategy.py


    def optimizer_state(self, optimizer: torch.optim.Optimizer) -> Dict[str, torch.Tensor]:
        """
        Return the optimizer state dict in full, sharded, or local format, depending on the configured checkpoint type.
        """
        assert isinstance(self.checkpoint_io, SageMakerCheckpointIO)
        typ = self.checkpoint_io.checkpoint_type
        if typ == SageMakerCheckpointType.LOCAL:
            return self.local_optimizer_state_dict(optimizer)
        if typ in (SageMakerCheckpointType.SHARDED, SageMakerCheckpointType.PEFT_SHARDED):
            return self.sharded_optimizer_state_dict(optimizer)
        if typ in (SageMakerCheckpointType.FULL, SageMakerCheckpointType.PEFT_FULL):
            return self.full_optimizer_state_dict(optimizer)
        raise NotImplementedError(f"Checkpoint type '{typ}' not implemented")
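
The method is a straightforward dispatch on the checkpoint type configured on the strategy's SageMakerCheckpointIO: local and sharded/PEFT-sharded types go to the local and sharded state-dict helpers, while full and PEFT-full both use the full state dict. The stand-alone sketch below mirrors that dispatch so the branching can be run in isolation; the enum member names follow the excerpt, but CheckpointType, DemoStrategy, and the placeholder return values are hypothetical stand-ins and are not part of the adapter's API.

    from enum import Enum, auto
    from typing import Dict


    class CheckpointType(Enum):
        # Stand-in enum; member names mirror SageMakerCheckpointType in the excerpt.
        LOCAL = auto()
        SHARDED = auto()
        PEFT_SHARDED = auto()
        FULL = auto()
        PEFT_FULL = auto()


    class DemoStrategy:
        """Hypothetical stand-in reproducing the dispatch in optimizer_state()."""

        def __init__(self, checkpoint_type: CheckpointType):
            self.checkpoint_type = checkpoint_type

        def optimizer_state(self, optimizer) -> Dict[str, str]:
            typ = self.checkpoint_type
            if typ == CheckpointType.LOCAL:
                return {"format": "local"}      # would call local_optimizer_state_dict
            if typ in (CheckpointType.SHARDED, CheckpointType.PEFT_SHARDED):
                return {"format": "sharded"}    # would call sharded_optimizer_state_dict
            if typ in (CheckpointType.FULL, CheckpointType.PEFT_FULL):
                return {"format": "full"}       # would call full_optimizer_state_dict
            raise NotImplementedError(f"Checkpoint type '{typ}' not implemented")


    if __name__ == "__main__":
        strategy = DemoStrategy(CheckpointType.PEFT_FULL)
        print(strategy.optimizer_state(optimizer=None))  # -> {'format': 'full'}

Note that an unrecognized checkpoint type falls through every branch and raises NotImplementedError, so callers only ever receive a state dict for one of the supported formats.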