in src/accelerate/utils/dataclasses.py
def set_state_dict_type(self, state_dict_type=None):
    """
    Set the state dict config based on the `StateDictType`.
    """
    from torch.distributed.fsdp.fully_sharded_data_parallel import (
        FullOptimStateDictConfig,
        FullStateDictConfig,
        ShardedOptimStateDictConfig,
        ShardedStateDictConfig,
        StateDictType,
    )

    # Override the state_dict_type if provided; the typical use case is a user
    # who trains with sharded checkpoints but saves the final model as full.
    if state_dict_type is not None:
        self.state_dict_type = state_dict_type

    # Fall back to the environment variable, defaulting to FULL_STATE_DICT for
    # FSDP1 and SHARDED_STATE_DICT for FSDP2. (`os` is imported at module level.)
    if self.state_dict_type is None:
        self.state_dict_type = os.environ.get(
            "FSDP_STATE_DICT_TYPE", "FULL_STATE_DICT" if self.fsdp_version == 1 else "SHARDED_STATE_DICT"
        )
    # Accept either the enum's numeric value (e.g. "1") or its name
    # (e.g. "full_state_dict", case-insensitive).
    if isinstance(self.state_dict_type, str):
        if self.state_dict_type.isdigit():
            self.state_dict_type = StateDictType(int(self.state_dict_type))
        else:
            self.state_dict_type = StateDictType[self.state_dict_type.upper()]

    # Populate default model/optimizer state dict configs for the chosen type,
    # unless the user already supplied them.
    if self.state_dict_type == StateDictType.FULL_STATE_DICT:
        if self.state_dict_config is None:
            self.state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
        if self.optim_state_dict_config is None:
            self.optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
    elif self.state_dict_type == StateDictType.SHARDED_STATE_DICT:
        if self.state_dict_config is None:
            self.state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
        if self.optim_state_dict_config is None:
            self.optim_state_dict_config = ShardedOptimStateDictConfig(offload_to_cpu=True)

    if self.fsdp_version == 2 and self.state_dict_type == StateDictType.LOCAL_STATE_DICT:
        raise ValueError(
            "FSDP2 does not support LOCAL_STATE_DICT. "
            "Please set `fsdp_state_dict_type` to `SHARDED_STATE_DICT` or `FULL_STATE_DICT`."
        )
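
A minimal usage sketch of the "train sharded, save full" flow mentioned in the comment above, assuming this method lives on Accelerate's `FullyShardedDataParallelPlugin` and that `fsdp_version` is a field on that plugin; the training and save calls are illustrative, not part of this file:

    from accelerate import Accelerator, FullyShardedDataParallelPlugin

    # With neither the argument nor FSDP_STATE_DICT_TYPE set, fsdp_version=2
    # falls back to SHARDED_STATE_DICT (see the env-var fallback above).
    fsdp_plugin = FullyShardedDataParallelPlugin(fsdp_version=2)
    accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

    # ... training loop, with periodic accelerator.save_state(...) checkpoints ...

    # Before the final save, switch to a full (unsharded) state dict; a string
    # is accepted and normalized to the StateDictType enum by the method above.
    fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")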