in src/accelerate/utils/dataclasses.py
def __post_init__(self):
from torch.distributed.fsdp import (
BackwardPrefetch,
ShardingStrategy,
)
_fsdp2_warnings = set()
env_prefix = "FSDP_"
# Strategy: by default, assume values were passed in explicitly; otherwise fall back to the environment variables
if self.fsdp_version is None:
self.fsdp_version = int(os.environ.get(env_prefix + "VERSION", "1"))
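# FSDP2 is only available on sufficiently new PyTorch, so fail fast before doing any further normalization.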
if self.fsdp_version == 2:
if not is_torch_version(">=", FSDP2_PYTORCH_VERSION):
raise ImportError(f"FSDP2 requires PyTorch >= {FSDP2_PYTORCH_VERSION}")
if self.sharding_strategy is not None:
# We cannot properly detect all of the cases, as by default `args.fsdp_sharding_strategy` is set to `fully_shard`
# Therefore we issue a warning only if the user has explicitly set it inside their plugin
_fsdp2_warnings.add(
"sharding_strategy is deprecated in favor of reshard_after_forward. "
"This will be removed in a future version of Accelerate."
)
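# FSDP1 only: resolve `sharding_strategy` from the plugin value or the environment, then map
# readable names (e.g. "FULL_SHARD") or 1-based indices onto the `ShardingStrategy` enum.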
if self.fsdp_version == 1:
if self.sharding_strategy is None:
self.sharding_strategy = os.environ.get(env_prefix + "SHARDING_STRATEGY", "FULL_SHARD")
if isinstance(self.sharding_strategy, str):
if self.sharding_strategy.upper() in FSDP_SHARDING_STRATEGY:
self.sharding_strategy = FSDP_SHARDING_STRATEGY.index(self.sharding_strategy.upper()) + 1
if isinstance(self.sharding_strategy, int) or self.sharding_strategy.isdigit():
self.sharding_strategy = ShardingStrategy(int(self.sharding_strategy))
else:
self.sharding_strategy = ShardingStrategy[self.sharding_strategy.upper()]
# Resolve `reshard_after_forward` from the environment only when neither it nor `sharding_strategy`
# was set explicitly (the default is a bool for FSDP2 and a strategy name for FSDP1)
if self.reshard_after_forward is None and self.sharding_strategy is None:
reshard_after_forward = os.environ.get(
env_prefix + "RESHARD_AFTER_FORWARD", "true" if self.fsdp_version == 2 else "FULL_SHARD"
)
if self.fsdp_version == 2:
self.reshard_after_forward = str_to_bool(reshard_after_forward.lower(), to_bool=True)
else:
self.reshard_after_forward = reshard_after_forward
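# Normalize string values: FSDP2 expects a bool ("true"/"false"), FSDP1 expects a sharding
# strategy name or 1-based index that maps onto the `ShardingStrategy` enum.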
if isinstance(self.reshard_after_forward, str):
if self.fsdp_version == 2:
self.reshard_after_forward = str_to_bool(self.reshard_after_forward.lower(), to_bool=True)
else:
# We need to remap based on custom enum values for user readability
if self.reshard_after_forward.upper() in FSDP_SHARDING_STRATEGY:
self.reshard_after_forward = FSDP_SHARDING_STRATEGY.index(self.reshard_after_forward.upper()) + 1
if isinstance(self.reshard_after_forward, int) or self.reshard_after_forward.isdigit():
self.reshard_after_forward = ShardingStrategy(int(self.reshard_after_forward))
else:
self.reshard_after_forward = ShardingStrategy[self.reshard_after_forward.upper()]
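# Cross-check the resolved value against the FSDP version: FSDP2 takes a bool, FSDP1 takes a
# `ShardingStrategy` (or a string naming one).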
if self.fsdp_version == 2 and not isinstance(self.reshard_after_forward, bool):
raise ValueError(
f"reshard_after_forward set to {self.reshard_after_forward}. This is not supported with FSDP2, please set to a `bool`"
)
if self.fsdp_version == 1 and isinstance(self.reshard_after_forward, bool):
raise ValueError(
f"reshard_after_forward set to {self.reshard_after_forward}. This is not supported with FSDP1, please set to a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`"
)
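# Resolve CPU offloading; the version-specific wrapping happens in `set_cpu_offload()` so that
# FSDP1/FSDP2-specific imports stay out of this method.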
if self.cpu_offload is None:
self.cpu_offload = str_to_bool(os.environ.get(env_prefix + "OFFLOAD_PARAMS", "False")) == 1
self.set_cpu_offload() # abstracted away to hide imports due to version checks
self.validate_cpu_offload()
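# Backward prefetch: "NO_PREFETCH" is treated as None, other string/int values are mapped onto the
# `BackwardPrefetch` enum; FSDP2 does not support it, so the value is dropped and a warning is collected.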
if self.backward_prefetch is None:
self.backward_prefetch = os.environ.get(env_prefix + "BACKWARD_PREFETCH", None)
if isinstance(self.backward_prefetch, str) and self.backward_prefetch.upper() == "NO_PREFETCH":
self.backward_prefetch = None
if self.backward_prefetch is not None and not isinstance(self.backward_prefetch, BackwardPrefetch):
if isinstance(self.backward_prefetch, str) and self.backward_prefetch.upper() in FSDP_BACKWARD_PREFETCH:
self.backward_prefetch = FSDP_BACKWARD_PREFETCH.index(self.backward_prefetch.upper()) + 1
if isinstance(self.backward_prefetch, int) or self.backward_prefetch.isdigit():
self.backward_prefetch = BackwardPrefetch(int(self.backward_prefetch))
else:
self.backward_prefetch = BackwardPrefetch[self.backward_prefetch.upper()]
if self.fsdp_version == 2 and self.backward_prefetch is not None:
_fsdp2_warnings.add("backward_prefetch is not supported in FSDP2. Setting backward prefetch to None.")
self.backward_prefetch = None
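# State dict handling (full vs. sharded checkpoints and related options) is resolved in a helper as well.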
self.set_state_dict_type()
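# Auto wrap policy: TRANSFORMER_BASED_WRAP wraps the layer classes listed in
# `transformer_cls_names_to_wrap` (e.g. FSDP_TRANSFORMER_CLS_TO_WRAP="BertLayer,T5Block"),
# SIZE_BASED_WRAP wraps modules with more than `min_num_params` parameters, and NO_WRAP disables auto wrapping.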
if self.auto_wrap_policy is None:
self.auto_wrap_policy = os.environ.get(env_prefix + "AUTO_WRAP_POLICY", "NO_WRAP")
if isinstance(self.auto_wrap_policy, str):
if self.auto_wrap_policy.upper() not in FSDP_AUTO_WRAP_POLICY:
raise ValueError(
f"Invalid auto wrap policy: {self.auto_wrap_policy}. Must be one of {FSDP_AUTO_WRAP_POLICY}"
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
if self.auto_wrap_policy.upper() == "TRANSFORMER_BASED_WRAP":
self.auto_wrap_policy = transformer_auto_wrap_policy
if self.transformer_cls_names_to_wrap is None:
self.transformer_cls_names_to_wrap = os.environ.get(env_prefix + "TRANSFORMER_CLS_TO_WRAP", None)
if isinstance(self.transformer_cls_names_to_wrap, str):
self.transformer_cls_names_to_wrap = self.transformer_cls_names_to_wrap.split(",")
elif self.auto_wrap_policy.upper() == "SIZE_BASED_WRAP":
self.auto_wrap_policy = size_based_auto_wrap_policy
if self.min_num_params is None:
self.min_num_params = int(os.environ.get(env_prefix + "MIN_NUM_PARAMS", 0))
elif not isinstance(self.min_num_params, int):
raise ValueError(
f"`min_num_params` must be an integer. Got {self.min_num_params} of type {type(self.min_num_params)}"
)
elif self.auto_wrap_policy.upper() == "NO_WRAP":
self.auto_wrap_policy = None
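# `use_orig_params` only applies to FSDP1; FSDP2 always uses the original parameters.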
if self.use_orig_params is None and self.fsdp_version == 1:
self.use_orig_params = str_to_bool(os.environ.get(env_prefix + "USE_ORIG_PARAMS", "False")) == 1
if self.fsdp_version == 2 and self.use_orig_params is not None:
_fsdp2_warnings.add("use_orig_params is obsolete in FSDP2, as FSDP2 always uses the original parameters.")
self.use_orig_params = None
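# `sync_module_states` only applies to FSDP1; FSDP2 no longer needs it.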
if self.sync_module_states is None and self.fsdp_version == 1:
self.sync_module_states = str_to_bool(os.environ.get(env_prefix + "SYNC_MODULE_STATES", "False")) == 1
if self.fsdp_version == 2 and self.sync_module_states is not None:
_fsdp2_warnings.add(
"sync_module_states is obsolete in FSDP2, as it is not needed anymore."
"Setting sync_module_states to None."
)
self.sync_module_states = None
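# `forward_prefetch` is FSDP1-only for now; FSDP2 rejects it outright.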
if self.forward_prefetch is None and self.fsdp_version == 1:
self.forward_prefetch = str_to_bool(os.environ.get(env_prefix + "FORWARD_PREFETCH", "False")) == 1
if self.fsdp_version == 2 and self.forward_prefetch is not None:
raise ValueError("forward_prefetch is not yet implemented in FSDP2, set to None or use `fsdp_version=1`")
if self.activation_checkpointing is None:
self.activation_checkpointing = (
str_to_bool(os.environ.get(env_prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1
)
if self.cpu_ram_efficient_loading is None:
self.cpu_ram_efficient_loading = (
str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False")) == 1
)
# There's no need to specify sync_module_states in FSDP2
if self.fsdp_version == 1 and self.cpu_ram_efficient_loading and not self.sync_module_states:
warnings.warn(
"sync_module_states cannot be False since efficient cpu ram loading enabled. "
"Setting sync_module_states to True."
)
self.sync_module_states = True
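# Keep the environment variable consistent with the final attribute, since other components may read
# FSDP_CPU_RAM_EFFICIENT_LOADING directly rather than inspecting the plugin object.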
if self.cpu_ram_efficient_loading != bool(
str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False"))
):
env_var = env_prefix + "CPU_RAM_EFFICIENT_LOADING"
warnings.warn(
f"The `cpu_ram_efficient_loading` flag for `FullyShardedDataParallelPlugin` does not match the environment variable {env_var}. "
"Setting environment variable to match `cpu_ram_efficient_loading`."
)
os.environ[env_var] = str(self.cpu_ram_efficient_loading)
if isinstance(self.mixed_precision_policy, dict):
self.set_mixed_precision(self.mixed_precision_policy)
if self.mixed_precision_policy is not None:
self.validate_mixed_precision_policy()
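# Pick the local accelerator device so that `param_init_fn` (defined below) can materialize parameters
# on it with `to_empty` when module states are synced across ranks.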
if self.sync_module_states:
if is_npu_available():
device = torch.npu.current_device()
elif is_mlu_available():
device = torch.mlu.current_device()
elif is_musa_available():
device = torch.musa.current_device()
elif is_cuda_available():
device = torch.cuda.current_device()
elif is_xpu_available():
device = torch.xpu.current_device()
elif is_hpu_available():
device = torch.hpu.current_device()
else:
raise RuntimeError(
"There are currently no available devices found, must be one of 'XPU', 'CUDA', 'MLU', 'NPU', 'MUSA', or 'HPU'."
)
# Create a function that will be used to initialize the parameters of the model
# when using `sync_module_states`
self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)
# Single warning for all deprecation warnings due to FSDP2 conversion
if _fsdp2_warnings:
logger.warning("Multiple deprecation warnings due to FSDP2 conversion:\n".join(_fsdp2_warnings))