def set_state_dict_type()

in src/accelerate/utils/dataclasses.py [0:0]


    def set_state_dict_type(self, state_dict_type=None):
        """
        Set the state dict config based on the `StateDictType`.

        Resolves `self.state_dict_type` to a `StateDictType` member and makes sure that
        `self.state_dict_config` and `self.optim_state_dict_config` hold config objects of the
        kind that matches the resolved type.

        Args:
            state_dict_type (`str` or `StateDictType`, *optional*):
                When given, overrides the currently stored `self.state_dict_type`. Strings may be
                a member name (case-insensitive, e.g. `"full_state_dict"`) or the decimal value of
                a member. When omitted and nothing is stored yet, the `FSDP_STATE_DICT_TYPE`
                environment variable is used, falling back to `FULL_STATE_DICT` when
                `self.fsdp_version == 1` and to `SHARDED_STATE_DICT` otherwise.

        Raises:
            KeyError: If a non-numeric string does not name a `StateDictType` member.
            ValueError: If a numeric string is not a valid `StateDictType` value, or if
                `LOCAL_STATE_DICT` is selected while `self.fsdp_version == 2`.
        """
        from torch.distributed.fsdp.fully_sharded_data_parallel import (
            FullOptimStateDictConfig,
            FullStateDictConfig,
            ShardedOptimStateDictConfig,
            ShardedStateDictConfig,
            StateDictType,
        )

        # Override the state_dict_type if provided, typical use case:
        # user trains with sharded, but final save is with full
        if state_dict_type is not None:
            self.state_dict_type = state_dict_type

        if self.state_dict_type is None:
            self.state_dict_type = os.environ.get(
                "FSDP_STATE_DICT_TYPE", "FULL_STATE_DICT" if self.fsdp_version == 1 else "SHARDED_STATE_DICT"
            )
        if isinstance(self.state_dict_type, str):
            if self.state_dict_type.isdigit():
                self.state_dict_type = StateDictType(int(self.state_dict_type))
            else:
                self.state_dict_type = StateDictType[self.state_dict_type.upper()]

        # A config is (re)created when it is missing OR when it belongs to a different state dict
        # type than the one just selected. Checking only for `None` would leave e.g. a sharded
        # config in place after switching to FULL_STATE_DICT, and the type/config pair would no
        # longer agree. A config that already is an instance of the expected class (including one
        # supplied by the user) is left untouched. `isinstance(None, cls)` is False, so the
        # missing-config case is covered by the same test.
        if self.state_dict_type == StateDictType.FULL_STATE_DICT:
            if not isinstance(self.state_dict_config, FullStateDictConfig):
                self.state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
            if not isinstance(self.optim_state_dict_config, FullOptimStateDictConfig):
                self.optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
        elif self.state_dict_type == StateDictType.SHARDED_STATE_DICT:
            if not isinstance(self.state_dict_config, ShardedStateDictConfig):
                self.state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
            if not isinstance(self.optim_state_dict_config, ShardedOptimStateDictConfig):
                self.optim_state_dict_config = ShardedOptimStateDictConfig(offload_to_cpu=True)

        # LOCAL_STATE_DICT gets no default configs above; it is only rejected for FSDP version 2.
        if self.fsdp_version == 2 and self.state_dict_type == StateDictType.LOCAL_STATE_DICT:
            raise ValueError(
                "FSDP2 does not support LOCAL_STATE_DICT. "
                "Please set `fsdp_state_dict_type` to `SHARDED_STATE_DICT` or `FULL_STATE_DICT`."
            )