in optimum/habana/transformers/trainer.py [0:0]
def create_optimizer(self):
"""
Setup the optimizer.
We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
Trainer's init through `optimizers`, or subclass and override this method in a subclass.
"""
    if self.optimizer is None:
        decay_parameters = self.get_decay_parameter_names(self.model)
        optimizer_grouped_parameters = []
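        # `get_decay_parameter_names` returns the names of the parameters that should receive weight
        # decay (typically everything except biases and normalization weights). The zip below pairs
        # the decay group with `self.args.weight_decay` and the no-decay group with 0.0.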
        for t_params, t_weight_decay in zip(
            [
                [p for n, p in self.model.named_parameters() if n in decay_parameters and p.requires_grad],
                [p for n, p in self.model.named_parameters() if n not in decay_parameters and p.requires_grad],
            ],
            [self.args.weight_decay, 0.0],
        ):
            # Empty parameter groups are filtered out because they make FusedAdamW crash
            if t_params:
                optimizer_grouped_parameters.append(
                    {
                        "params": t_params,
                        "weight_decay": t_weight_decay,
                    }
                )
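        # Prefer Habana's fused AdamW implementation when it is enabled in the Gaudi config and
        # training runs on HPU.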
        if self.gaudi_config.use_fused_adam and self.args.use_habana:
            try:
                from habana_frameworks.torch.hpex.optimizers import FusedAdamW
            except ImportError as error:
                raise ImportError(
                    f"Could not import 'FusedAdamW' from 'habana_frameworks.torch.hpex.optimizers'. {error}"
                ) from error
            optimizer_cls = FusedAdamW
            optimizer_kwargs = {
                "lr": self.args.learning_rate,
                "betas": (self.args.adam_beta1, self.args.adam_beta2),
                "eps": self.args.adam_epsilon,
            }
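        # Otherwise honor an explicitly provided (optimizer class, kwargs) pair, falling back to the
        # default resolution from the training arguments.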
        elif self.optimizer_cls_and_kwargs is not None:
            optimizer_cls, optimizer_kwargs = self.optimizer_cls_and_kwargs
        else:
            optimizer_cls, optimizer_kwargs = self.get_optimizer_cls_and_kwargs(self.args, self.model)
        # Overwrite `optimizer_grouped_parameters` with `params` if it was created by
        # `get_optimizer_cls_and_kwargs`, e.g. for the GaLore optimizer.
        if "params" in optimizer_kwargs:
            optimizer_grouped_parameters = optimizer_kwargs.pop("params")

        # Overwrite `optimizer_grouped_parameters` with `model` if it was created by
        # `get_optimizer_cls_and_kwargs`, e.g. for the LOMO optimizer.
        if "model" in optimizer_kwargs:
            optimizer_grouped_parameters = optimizer_kwargs.pop("model")

        # For layer-wise dummy optimizers, overwrite `optimizer_grouped_parameters` with
        # `optimizer_dict` to avoid argument conflicts.
        if "optimizer_dict" in optimizer_kwargs:
            optimizer_grouped_parameters = optimizer_kwargs.pop("optimizer_dict")
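        # At this point `optimizer_grouped_parameters` is either the decay/no-decay split built above
        # or whatever `get_optimizer_cls_and_kwargs` supplied (`params`, `model` or `optimizer_dict`).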
        self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

    return self.optimizer
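
For illustration only (not part of trainer.py): a minimal standalone sketch of the same decay/no-decay grouping, assuming stock PyTorch. Here `torch.optim.AdamW` stands in for `FusedAdamW`, the model and hyperparameters are made up, and the decay-name selection is a simplified approximation of `get_decay_parameter_names`.

import torch
from torch import nn

# Hypothetical model and hyperparameters, for demonstration only.
model = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32), nn.Linear(32, 4))
weight_decay, learning_rate = 0.01, 3e-4

# Rough stand-in for `get_decay_parameter_names`: decay everything except biases and
# normalization weights.
decay_parameters = {
    name
    for module_name, module in model.named_modules()
    for name, _ in module.named_parameters(prefix=module_name, recurse=False)
    if not isinstance(module, nn.LayerNorm) and not name.endswith("bias")
}

optimizer_grouped_parameters = []
for t_params, t_weight_decay in zip(
    [
        [p for n, p in model.named_parameters() if n in decay_parameters and p.requires_grad],
        [p for n, p in model.named_parameters() if n not in decay_parameters and p.requires_grad],
    ],
    [weight_decay, 0.0],
):
    if t_params:  # skip empty groups, mirroring the FusedAdamW workaround above
        optimizer_grouped_parameters.append({"params": t_params, "weight_decay": t_weight_decay})

optimizer = torch.optim.AdamW(
    optimizer_grouped_parameters, lr=learning_rate, betas=(0.9, 0.999), eps=1e-8
)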