in src/accelerate/accelerator.py
def _prepare_msamp(self, *args, device_placement):
    if not is_msamp_available():
        raise ImportError(
            "MS-AMP was not found on your system. Please ensure that MS-AMP is available "
            "or choose `'te'` as the backend for FP8 mixed precision training."
        )
    # We've already checked for FSDP + MS-AMP during `__init__`
    import msamp
    model, optimizer = None, None
    optimizer_index = None
    num_models, num_optimizers = 0, 0
    result = list(args)
    for i, obj in enumerate(result):
        if isinstance(obj, torch.nn.Module):
            model = obj
            num_models += 1
        elif isinstance(obj, torch.optim.Optimizer):
            optimizer = obj
            optimizer_index = i
            num_optimizers += 1
    # Only DataLoaders/schedulers were passed, so there is nothing for MS-AMP to wrap
    if optimizer is None and model is None:
        return result, device_placement
    elif optimizer is None or model is None:
        raise ValueError(
            "You must pass a model and an optimizer together to `accelerate.prepare()` when using MS-AMP."
        )
    elif num_models > 1 or num_optimizers > 1:
        raise ValueError(
            f"You can't use multiple models ({num_models}) or optimizers ({num_optimizers}) with MS-AMP."
        )
    else:
        # DEPRECATE @ 2.0: `fp8_recipe_handler` still takes precedence over `msamp_recipe_handler` until then
        if self.fp8_recipe_handler is not None:
            opt_level = self.fp8_recipe_handler.opt_level
        else:
            opt_level = self.msamp_recipe_handler.opt_level
        # Let MS-AMP wrap the model and optimizer for FP8 at the requested optimization level
        model, optimizer = msamp.initialize(model, optimizer, opt_level=opt_level)
    # Swap the MS-AMP-wrapped model and optimizer back into their original positions
    for i in range(len(result)):
        if isinstance(result[i], torch.nn.Module):
            result[i] = model
        elif isinstance(result[i], torch.optim.Optimizer):
            result[i] = optimizer
    if optimizer_index is not None:
        # NOTE: MS-AMP moves the optimizer, but *not* the model, to the right device
        device_placement[optimizer_index] = False
    return tuple(result), device_placement
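
# Illustrative sketch (not part of accelerator.py): roughly how `_prepare_msamp` is reached
# from user code. It assumes MS-AMP is installed and that FP8 mixed precision with the
# MS-AMP backend is requested via `FP8RecipeKwargs`, whose `opt_level` feeds the
# `fp8_recipe_handler` branch above. The model, optimizer, and hyperparameters are
# placeholders chosen only for the example.
import torch
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

model = torch.nn.Linear(64, 64)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

accelerator = Accelerator(
    mixed_precision="fp8",
    kwargs_handlers=[FP8RecipeKwargs(backend="msamp", opt_level="O2")],
)
# `prepare()` must receive the model and optimizer together so that `_prepare_msamp`
# can hand both to `msamp.initialize` in a single call.
model, optimizer = accelerator.prepare(model, optimizer)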