def __post_init__()

in src/accelerate/utils/dataclasses.py [0:0]


    def __post_init__(self):
        """
        Resolve and validate every FSDP option after dataclass initialization.

        Each option left as `None` is filled from its `FSDP_*` environment variable (falling back to a hard-coded
        default). String values are converted into the corresponding `torch.distributed.fsdp` enums or wrap-policy
        callables, and options are validated against the selected `fsdp_version`. Options that are deprecated,
        obsolete, or unsupported under FSDP2 are reported in a single consolidated log warning at the end.

        Raises:
            ImportError: If `fsdp_version == 2` and the installed PyTorch is older than `FSDP2_PYTORCH_VERSION`.
            ValueError: If `reshard_after_forward` has the wrong type for the selected FSDP version, if
                `auto_wrap_policy` is an unknown string, if `min_num_params` is not an `int`, or if
                `forward_prefetch` is set while using FSDP2.
            KeyError: Propagated from the `ShardingStrategy[...]` / `BackwardPrefetch[...]` lookups when a string
                option is neither a known name nor a digit string.
            RuntimeError: If `sync_module_states` is enabled but no supported accelerator device is available.
        """
        from torch.distributed.fsdp import (
            BackwardPrefetch,
            ShardingStrategy,
        )

        # Deprecation / obsolescence notices collected while resolving options and emitted once at the end.
        # A list (rather than a set) keeps the emitted order deterministic; each message below is appended at
        # most once per call, so no de-duplication is required.
        _fsdp2_warnings = []

        env_prefix = "FSDP_"
        # Strategy: By default we should always assume that values are passed in, else we check the environment variables
        if self.fsdp_version is None:
            self.fsdp_version = int(os.environ.get(env_prefix + "VERSION", "1"))

        if self.fsdp_version == 2:
            if not is_torch_version(">=", FSDP2_PYTORCH_VERSION):
                raise ImportError(f"FSDP2 requires PyTorch >= {FSDP2_PYTORCH_VERSION}")

        if self.sharding_strategy is not None:
            # We cannot properly detect all of the cases, as by default `args.fsdp_sharding_strategy` is set to `fully_shard`
            # Therefore we issue a warning only if the user has explicitly set it inside their plugin
            _fsdp2_warnings.append(
                "sharding_strategy is deprecated in favor of reshard_after_forward. "
                "This will be removed in a future version of Accelerate."
            )
        if self.fsdp_version == 1:
            if self.sharding_strategy is None:
                self.sharding_strategy = os.environ.get(env_prefix + "SHARDING_STRATEGY", "FULL_SHARD")
            if isinstance(self.sharding_strategy, str):
                # Accept a name listed in FSDP_SHARDING_STRATEGY (mapped to its 1-based enum value), a digit
                # string, or a raw `ShardingStrategy` member name.
                if self.sharding_strategy.upper() in FSDP_SHARDING_STRATEGY:
                    self.sharding_strategy = FSDP_SHARDING_STRATEGY.index(self.sharding_strategy.upper()) + 1
                if isinstance(self.sharding_strategy, int) or self.sharding_strategy.isdigit():
                    self.sharding_strategy = ShardingStrategy(int(self.sharding_strategy))
                else:
                    self.sharding_strategy = ShardingStrategy[self.sharding_strategy.upper()]

        # Fall back to the `FSDP_RESHARD_AFTER_FORWARD` environment variable when neither option was provided.
        # NOTE: under FSDP1 `sharding_strategy` is always populated above, so in practice this branch only runs
        # for FSDP2 (or any other `fsdp_version` value).
        if self.reshard_after_forward is None and self.sharding_strategy is None:
            reshard_after_forward = os.environ.get(
                env_prefix + "RESHARD_AFTER_FORWARD", "true" if self.fsdp_version == 2 else "FULL_SHARD"
            )
            if self.fsdp_version == 2:
                self.reshard_after_forward = str_to_bool(reshard_after_forward.lower(), to_bool=True)
            else:
                self.reshard_after_forward = reshard_after_forward
        if isinstance(self.reshard_after_forward, str):
            if self.fsdp_version == 2:
                # FSDP2 expects a plain bool.
                self.reshard_after_forward = str_to_bool(self.reshard_after_forward.lower(), to_bool=True)
            else:
                # We need to remap based on custom enum values for user readability
                if self.reshard_after_forward.upper() in FSDP_SHARDING_STRATEGY:
                    self.reshard_after_forward = FSDP_SHARDING_STRATEGY.index(self.reshard_after_forward.upper()) + 1
                if isinstance(self.reshard_after_forward, int) or self.reshard_after_forward.isdigit():
                    self.reshard_after_forward = ShardingStrategy(int(self.reshard_after_forward))
                else:
                    self.reshard_after_forward = ShardingStrategy[self.reshard_after_forward.upper()]

        # Type checks: FSDP2 needs a bool, FSDP1 must not receive one.
        if self.fsdp_version == 2 and not isinstance(self.reshard_after_forward, bool):
            raise ValueError(
                f"reshard_after_forward set to {self.reshard_after_forward}. This is not supported with FSDP2, please set to a `bool`"
            )
        if self.fsdp_version == 1 and isinstance(self.reshard_after_forward, bool):
            raise ValueError(
                f"reshard_after_forward set to {self.reshard_after_forward}. This is not supported with FSDP1, please set to a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`"
            )

        if self.cpu_offload is None:
            self.cpu_offload = str_to_bool(os.environ.get(env_prefix + "OFFLOAD_PARAMS", "False")) == 1

        self.set_cpu_offload()  # abstracted away to hide imports due to version checks
        self.validate_cpu_offload()

        if self.backward_prefetch is None:
            self.backward_prefetch = os.environ.get(env_prefix + "BACKWARD_PREFETCH", None)
        if isinstance(self.backward_prefetch, str) and self.backward_prefetch.upper() == "NO_PREFETCH":
            self.backward_prefetch = None
        if self.backward_prefetch is not None and not isinstance(self.backward_prefetch, BackwardPrefetch):
            # Same name -> 1-based value -> enum mapping as for the sharding strategy above.
            if isinstance(self.backward_prefetch, str) and self.backward_prefetch.upper() in FSDP_BACKWARD_PREFETCH:
                self.backward_prefetch = FSDP_BACKWARD_PREFETCH.index(self.backward_prefetch.upper()) + 1
            if isinstance(self.backward_prefetch, int) or self.backward_prefetch.isdigit():
                self.backward_prefetch = BackwardPrefetch(int(self.backward_prefetch))
            else:
                self.backward_prefetch = BackwardPrefetch[self.backward_prefetch.upper()]
        if self.fsdp_version == 2 and self.backward_prefetch is not None:
            _fsdp2_warnings.append("backward_prefetch is not supported in FSDP2. Setting backward prefetch to None.")
            self.backward_prefetch = None

        self.set_state_dict_type()

        if self.auto_wrap_policy is None:
            self.auto_wrap_policy = os.environ.get(env_prefix + "AUTO_WRAP_POLICY", "NO_WRAP")
        if isinstance(self.auto_wrap_policy, str):
            if self.auto_wrap_policy.upper() not in FSDP_AUTO_WRAP_POLICY:
                raise ValueError(
                    f"Invalid auto wrap policy: {self.auto_wrap_policy}. Must be one of {FSDP_AUTO_WRAP_POLICY}"
                )
            from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy

            # Replace the policy name with the actual torch wrap-policy callable (or None for NO_WRAP).
            if self.auto_wrap_policy.upper() == "TRANSFORMER_BASED_WRAP":
                self.auto_wrap_policy = transformer_auto_wrap_policy
                if self.transformer_cls_names_to_wrap is None:
                    self.transformer_cls_names_to_wrap = os.environ.get(env_prefix + "TRANSFORMER_CLS_TO_WRAP", None)
                if isinstance(self.transformer_cls_names_to_wrap, str):
                    # Comma-separated class names -> list of names.
                    self.transformer_cls_names_to_wrap = self.transformer_cls_names_to_wrap.split(",")
            elif self.auto_wrap_policy.upper() == "SIZE_BASED_WRAP":
                self.auto_wrap_policy = size_based_auto_wrap_policy
                if self.min_num_params is None:
                    self.min_num_params = int(os.environ.get(env_prefix + "MIN_NUM_PARAMS", 0))
                elif not isinstance(self.min_num_params, int):
                    raise ValueError(
                        f"`min_num_params` must be an integer. Got {self.min_num_params} of type {type(self.min_num_params)}"
                    )
            elif self.auto_wrap_policy.upper() == "NO_WRAP":
                self.auto_wrap_policy = None

        if self.use_orig_params is None and self.fsdp_version == 1:
            self.use_orig_params = str_to_bool(os.environ.get(env_prefix + "USE_ORIG_PARAMS", "False")) == 1
        if self.fsdp_version == 2 and self.use_orig_params is not None:
            _fsdp2_warnings.append("use_orig_params is obsolete in FSDP2, as FSDP2 always uses the original parameters.")
            self.use_orig_params = None

        if self.sync_module_states is None and self.fsdp_version == 1:
            self.sync_module_states = str_to_bool(os.environ.get(env_prefix + "SYNC_MODULE_STATES", "False")) == 1
        if self.fsdp_version == 2 and self.sync_module_states is not None:
            _fsdp2_warnings.append(
                "sync_module_states is obsolete in FSDP2, as it is not needed anymore. "
                "Setting sync_module_states to None."
            )
            self.sync_module_states = None

        if self.forward_prefetch is None and self.fsdp_version == 1:
            self.forward_prefetch = str_to_bool(os.environ.get(env_prefix + "FORWARD_PREFETCH", "False")) == 1
        if self.fsdp_version == 2 and self.forward_prefetch is not None:
            raise ValueError("forward_prefetch is not yet implemented in FSDP2, set to None or use `fsdp_version=1`")

        if self.activation_checkpointing is None:
            self.activation_checkpointing = (
                str_to_bool(os.environ.get(env_prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1
            )

        if self.cpu_ram_efficient_loading is None:
            self.cpu_ram_efficient_loading = (
                str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False")) == 1
            )
        # There's no need to specify sync_module_states in FSDP2
        if self.fsdp_version == 1 and self.cpu_ram_efficient_loading and not self.sync_module_states:
            warnings.warn(
                "sync_module_states cannot be False since efficient cpu ram loading enabled. "
                "Setting sync_module_states to True."
            )
            self.sync_module_states = True

        # Keep the environment variable in sync with the resolved flag (the value is written back as "True"/"False").
        if self.cpu_ram_efficient_loading != bool(
            str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False"))
        ):
            env_var = env_prefix + "CPU_RAM_EFFICIENT_LOADING"
            warnings.warn(
                f"The `cpu_ram_efficient_loading` flag for `FullyShardedDataParallelPlugin` does not match the environment variable {env_var}. "
                "Setting environment variable to match `cpu_ram_efficient_loading`."
            )
            os.environ[env_var] = str(self.cpu_ram_efficient_loading)

        if isinstance(self.mixed_precision_policy, dict):
            self.set_mixed_precision(self.mixed_precision_policy)
        if self.mixed_precision_policy is not None:
            self.validate_mixed_precision_policy()

        if self.sync_module_states:
            # Use the current device of the first available accelerator backend, in the priority order below.
            if is_npu_available():
                device = torch.npu.current_device()
            elif is_mlu_available():
                device = torch.mlu.current_device()
            elif is_musa_available():
                device = torch.musa.current_device()
            elif is_cuda_available():
                device = torch.cuda.current_device()
            elif is_xpu_available():
                device = torch.xpu.current_device()
            elif is_hpu_available():
                device = torch.hpu.current_device()
            else:
                raise RuntimeError(
                    "There are currently no available devices found, must be one of 'XPU', 'CUDA', 'MLU', 'NPU', 'MUSA', or 'HPU'."
                )
            # Create a function that will be used to initialize the parameters of the model
            # when using `sync_module_states`
            # NOTE(review): this unconditionally overwrites any user-supplied `param_init_fn` — confirm intended.
            self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)

        # Single warning for all deprecation warnings due to FSDP2 conversion. The header is prepended once rather
        # than used as the `join` separator, which would drop it for a single message and interleave it otherwise.
        if _fsdp2_warnings:
            logger.warning("Multiple deprecation warnings due to FSDP2 conversion:\n" + "\n".join(_fsdp2_warnings))