in src/accelerate/accelerator.py
def _prepare_te(self, *args):
    if not is_transformer_engine_available():
        raise ImportError(
            "`transformer_engine` was not found on your system. Please ensure that `transformer_engine` is installed."
        )
    model, optimizer = None, None
    num_models, num_optimizers = 0, 0
    result = list(args)
    # Scan the objects handed to `prepare()` for the model and the optimizer.
    for obj in result:
        if isinstance(obj, torch.nn.Module):
            model = obj
            num_models += 1
        elif isinstance(obj, torch.optim.Optimizer):
            optimizer = obj
            num_optimizers += 1
    if optimizer is None and model is None:
        return result
    elif optimizer is None or model is None:
        raise ValueError(
            "You must pass a model and an optimizer together to `accelerate.prepare()` when using TransformerEngine."
        )
    elif num_models > 1 or num_optimizers > 1:
        raise ValueError(
            f"You can't use multiple models ({num_models}) or optimizers ({num_optimizers}) with TransformerEngine."
        )
    old_named_params = self._get_named_parameters(model)
    with torch.no_grad():
        convert_model(model)
    new_named_params = self._get_named_parameters(model)
    # Map each old parameter object to its replacement under the same name.
    mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
    # The optimizer's param groups still reference the old parameters, so point
    # them at the newly created TransformerEngine parameters.
    for param_group in optimizer.param_groups:
        param_group["params"] = [mapping[p] for p in param_group["params"]]
    return result
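
# Usage sketch (not part of accelerator.py): one way this code path is reached.
# This is a minimal, hedged example; it assumes `transformer_engine` is installed
# on an FP8-capable GPU and that `FP8RecipeKwargs` is available from
# `accelerate.utils` to select the TransformerEngine ("TE") backend.
import torch
from accelerate import Accelerator
from accelerate.utils import FP8RecipeKwargs

model = torch.nn.Sequential(torch.nn.Linear(128, 128), torch.nn.ReLU(), torch.nn.Linear(128, 128))
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[FP8RecipeKwargs(backend="TE")])
# `prepare` dispatches to `_prepare_te`, which must receive the model and the
# optimizer together so the optimizer's parameters can be remapped after conversion.
model, optimizer = accelerator.prepare(model, optimizer)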