in src/accelerate/accelerator.py [0:0]
def load_state(self, input_dir: str | os.PathLike | None = None, load_kwargs: dict | None = None, **load_model_func_kwargs):
"""
Loads the current states of the model, optimizer, scaler, RNG generators, and registered objects.
<Tip>
Should only be used in conjunction with [`Accelerator.save_state`]. Files stored in the directory that were
not registered for checkpointing will not be loaded.
</Tip>
Args:
input_dir (`str` or `os.PathLike`, *optional*):
The name of the folder all relevant weights and states were saved in. Can be `None` if
`automatic_checkpoint_naming` is enabled, in which case the latest checkpoint inside
`{project_dir}/checkpoints` will be loaded.
load_kwargs (`dict`, *optional*):
Additional keyword arguments for the underlying `load` function used to restore the saved state dicts,
such as optional arguments for the model and optimizer state dicts.
load_model_func_kwargs (`dict`, *optional*):
Additional keyword arguments for loading model which can be passed to the underlying load function,
such as optional arguments for DeepSpeed's `load_checkpoint` function or a `map_location` to load the
model and optimizer on.
Example:
```python
>>> from accelerate import Accelerator
>>> accelerator = Accelerator()
>>> model, optimizer, lr_scheduler = ...
>>> model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
>>> accelerator.load_state("my_checkpoint")
```
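When `automatic_checkpoint_naming` is enabled, `input_dir` can be omitted and the most recent checkpoint
under `{project_dir}/checkpoints` is loaded. A minimal sketch (assuming a model and optimizer are already
defined):
```python
>>> from accelerate import Accelerator
>>> from accelerate.utils import ProjectConfiguration

>>> project_config = ProjectConfiguration(project_dir="my_project", automatic_checkpoint_naming=True)
>>> accelerator = Accelerator(project_config=project_config)
>>> model, optimizer = accelerator.prepare(model, optimizer)
>>> accelerator.save_state()  # saved under my_project/checkpoints/checkpoint_0
>>> accelerator.load_state()  # picks up the latest checkpoint automatically
```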
"""
if input_dir is not None:
# Check if folder exists
input_dir = os.path.expanduser(input_dir)
if not os.path.isdir(input_dir):
raise ValueError(f"Tried to find {input_dir} but folder does not exist")
elif self.project_configuration.automatic_checkpoint_naming:
# Pick up from automatic checkpoint naming
input_dir = os.path.join(self.project_dir, "checkpoints")
folders = [os.path.join(input_dir, folder) for folder in os.listdir(input_dir)]
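# Sort the checkpoint folders by their trailing number (e.g. `checkpoint_10`) so the most recent one comes last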
def _inner(folder):
return list(map(int, re.findall(r"[\/]?([0-9]+)(?=[^\/]*$)", folder)))[0]
folders.sort(key=_inner)
input_dir = folders[-1]
else:
raise ValueError("No input_dir provided and automatic checkpoint naming is disabled.")
logger.info(f"Loading states from {input_dir}")
# Load the models taking care of FSDP and DeepSpeed nuances
models = []
for i, model in enumerate(self._models):
if self.distributed_type == DistributedType.FSDP:
logger.info("Loading FSDP model")
load_fsdp_model(self.state.fsdp_plugin, self, model, input_dir, i)
logger.info(f"FSDP Model loaded from input dir {input_dir}")
elif self.distributed_type == DistributedType.DEEPSPEED:
logger.info("Loading DeepSpeed Model and Optimizer")
ckpt_id = f"{MODEL_NAME}" if i == 0 else f"{MODEL_NAME}_{i}"
model.load_checkpoint(input_dir, ckpt_id, **load_model_func_kwargs)
logger.info(f"DeepSpeed Model and Optimizer loaded from input dir {os.path.join(input_dir, ckpt_id)}")
elif self.distributed_type == DistributedType.MEGATRON_LM:
logger.info("Loading Megatron-LM Model, Optimizer and Scheduler")
model.load_checkpoint(input_dir)
logger.info(f"Megatron-LM Model , Optimizer and Scheduler loaded from input dir {input_dir}")
else:
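# Non-sharded models are collected here and restored later by `load_accelerator_state`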
models.append(model)
# Load the optimizers taking care of FSDP and DeepSpeed nuances
optimizers = []
if self.distributed_type == DistributedType.FSDP:
for i, opt in enumerate(self._optimizers):
logger.info("Loading FSDP Optimizer")
load_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], input_dir, i)
logger.info(f"FSDP Optimizer loaded from input dir {input_dir}")
elif self.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
optimizers = self._optimizers
# Load the lr schedulers taking care of DeepSpeed nuances
schedulers = []
if self.distributed_type == DistributedType.DEEPSPEED:
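# Schedulers managed by DeepSpeed are restored by `model.load_checkpoint` above; only plain schedulers are gathered here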
for i, scheduler in enumerate(self._schedulers):
if isinstance(scheduler, DeepSpeedSchedulerWrapper):
continue
schedulers.append(scheduler)
elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
schedulers = self._schedulers
dataloaders = self._dataloaders
# Call model loading hooks that might have been registered with
# accelerator.register_load_state_pre_hook
for hook in self._load_model_state_pre_hook.values():
hook(models, input_dir)
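# If no explicit `map_location` was given, decide where checkpoint tensors should be materialized:
# keep them on the local device for multi-device runs, otherwise load onto CPU first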
map_location = load_model_func_kwargs.pop("map_location", None)
if map_location is None:
if self.num_processes > 1 and self.distributed_type in (
DistributedType.MULTI_GPU,
DistributedType.MULTI_MLU,
DistributedType.MULTI_SDAA,
DistributedType.MULTI_MUSA,
DistributedType.MULTI_NPU,
DistributedType.MULTI_HPU,
):
map_location = "on_device"
else:
map_location = "cpu"
override_attributes = load_accelerator_state(
input_dir,
models,
optimizers,
schedulers,
dataloaders,
self.state.process_index,
self.scaler,
map_location,
load_kwargs,
**load_model_func_kwargs,
)
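# `load_accelerator_state` may return attributes (such as the training step) that should override the accelerator's current values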
if "step" in override_attributes:
self.step = override_attributes["step"]
custom_checkpoints = [
f for f in os.listdir(input_dir) if re.search(r"^custom_checkpoint_\d+\.pkl$", f) is not None
]
if len(custom_checkpoints) != len(self._custom_objects):
err = (
f"Number of custom checkpoints in folder {input_dir} does not match the number of registered objects:"
)
err += f"\n\tFound checkpoints: {len(custom_checkpoints)}"
err += f"\n\tRegistered objects: {len(self._custom_objects)}\n"
err += "Please make sure to only load checkpoints from folders that were created with the same set of registered objects,"
err += "or avoid using `custom_checkpoint` in the filename for files in that same directory and load them in manually."
raise RuntimeError(err)
else:
logger.info(f"Loading in {len(custom_checkpoints)} custom states")
for index, obj in enumerate(self._custom_objects):
load_custom_state(obj, input_dir, index)