in optimum/graphcore/trainer.py
def create_optimizer(self):
    """
    Sets up the optimizer.

    We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in
    the trainer's init through `optimizers`, or subclass and override this method.
    """
    if self.optimizer is None:
        decay_parameters = get_parameter_names(self.model, [nn.LayerNorm])
        decay_parameters = {name for name in decay_parameters if "bias" not in name}
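        # Weight decay is applied only to parameters that are neither LayerNorm weights
        # nor biases; the grouping below differs depending on the optimizer.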
        if self.args.lamb or self.args.lamb_no_bias_correction:
            bias_parameters = {n for n, _ in self.model.named_parameters() if "bias" in n}
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p for n, p in self.model.named_parameters() if (n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    # Disable LAMB updates for bias parameters
                    "params": [
                        p for n, p in self.model.named_parameters() if (n in bias_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                    "max_weight_norm": 0.0,
                },
                {
                    "params": [
                        p
                        for n, p in self.model.named_parameters()
                        if n not in decay_parameters and n not in bias_parameters and p.requires_grad
                    ],
                    "weight_decay": 0.0,
                },
            ]
            optimizer_cls = LAMB
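            # max_weight_norm=None puts no cap on the weight norm used in the LAMB trust
            # ratio (the bias group above opts out via max_weight_norm=0.0), and bias
            # correction is disabled when `lamb_no_bias_correction` is set.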
            optimizer_kwargs = {
                "max_weight_norm": None,
                "bias_correction": not self.args.lamb_no_bias_correction,
                "eps": self.args.adam_epsilon,
            }
        else:
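            # AdamW path: the usual two-way split, with weight decay applied to everything
            # except LayerNorm weights and biases.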
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p for n, p in self.model.named_parameters() if (n in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": self.args.weight_decay,
                },
                {
                    "params": [
                        p
                        for n, p in self.model.named_parameters()
                        if (n not in decay_parameters and p.requires_grad)
                    ],
                    "weight_decay": 0.0,
                },
            ]
            optimizer_cls = AdamW
            optimizer_kwargs = {
                # TODO: max_grad_norm is disabled because it makes things fail; fix it.
                # "max_grad_norm": self.args.max_grad_norm,
                "betas": (self.args.adam_beta1, self.args.adam_beta2),
                "eps": self.args.adam_epsilon,
                "bias_correction": False,
            }
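
        # Shared kwargs: keep the first-order optimizer state in fp16 unless fp32 is
        # requested, which saves optimizer-state memory on the IPU; the second-order
        # moments are always accumulated in fp32 for numerical stability.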
        first_order_type = torch.float32 if self.args.fp32 else torch.float16
        optimizer_kwargs["lr"] = self.args.learning_rate
        optimizer_kwargs["loss_scaling"] = self.args.loss_scaling
        optimizer_kwargs["accum_type"] = first_order_type
        optimizer_kwargs["first_order_momentum_accum_type"] = first_order_type
        optimizer_kwargs["second_order_momentum_accum_type"] = torch.float32
        self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
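
        # markAsConstant marks these attributes as constant for poptorch, i.e. they are
        # not treated as variable (updatable) optimizer attributes after compilation.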
        if self.args.lamb or self.args.lamb_no_bias_correction:
            self.optimizer.variable_attrs.markAsConstant("max_weight_norm")

        self.optimizer.variable_attrs.markAsConstant("weight_decay")

    return self.optimizer
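
# A minimal sketch (not part of the original file) of the override path mentioned in the
# docstring: subclassing the trainer and replacing the default optimizer. It assumes the
# public `optimum.graphcore.IPUTrainer` class and the `poptorch.optim.SGD` optimizer;
# argument names mirror the ones used above and should be checked against your versions.
#
#     import poptorch
#     from optimum.graphcore import IPUTrainer
#
#     class SGDTrainer(IPUTrainer):
#         def create_optimizer(self):
#             if self.optimizer is None:
#                 # Plain SGD with momentum instead of the AdamW/LAMB default.
#                 self.optimizer = poptorch.optim.SGD(
#                     self.model.parameters(),
#                     lr=self.args.learning_rate,
#                     momentum=0.9,
#                     loss_scaling=self.args.loss_scaling,
#                 )
#             return self.optimizer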