def create_optimizer()

in optimum/graphcore/trainer.py [0:0]

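The excerpt below relies on names imported elsewhere in trainer.py. A minimal sketch of the assumed imports (the IPU-aware `AdamW` and `LAMB` are assumed to come from `poptorch.optim`, and `get_parameter_names` from `transformers.trainer_pt_utils`; verify against the actual file):

import torch
from torch import nn

# Assumed sources for the names used in the method below.
from poptorch.optim import AdamW, LAMB
from transformers.trainer_pt_utils import get_parameter_names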

    def create_optimizer(self):
        """
        Sets up the optimizer.

        We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
        trainer's init through `optimizers`, or subclass and override this method.
        """
        if self.optimizer is None:
            decay_parameters = get_parameter_names(self.model, [nn.LayerNorm])
            decay_parameters = {name for name in decay_parameters if "bias" not in name}
            if self.args.lamb or self.args.lamb_no_bias_correction:
                bias_parameters = {n for n, _ in self.model.named_parameters() if "bias" in n}
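                # Three parameter groups: decayed weights, biases (for which the
                # LAMB-specific update is disabled via max_weight_norm=0), and the rest.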
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in self.model.named_parameters() if (n in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        # Disable LAMB updates for bias parameters
                        "params": [
                            p for n, p in self.model.named_parameters() if (n in bias_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                        "max_weight_norm": 0.0,
                    },
                    {
                        "params": [
                            p
                            for n, p in self.model.named_parameters()
                            if n not in decay_parameters and n not in bias_parameters and p.requires_grad
                        ],
                        "weight_decay": 0.0,
                    },
                ]
                optimizer_cls = LAMB
                optimizer_kwargs = {
                    "max_weight_norm": None,
                    "bias_correction": not self.args.lamb_no_bias_correction,
                    "eps": self.args.adam_epsilon,
                }
            else:
                optimizer_grouped_parameters = [
                    {
                        "params": [
                            p for n, p in self.model.named_parameters() if (n in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": self.args.weight_decay,
                    },
                    {
                        "params": [
                            p
                            for n, p in self.model.named_parameters()
                            if (n not in decay_parameters and p.requires_grad)
                        ],
                        "weight_decay": 0.0,
                    },
                ]
                optimizer_cls = AdamW
                optimizer_kwargs = {
                    # TODO: disabled max_grad_norm because it makes things fail, fix it.
                    #  "max_grad_norm": self.args.max_grad_norm,
                    "betas": (self.args.adam_beta1, self.args.adam_beta2),
                    "eps": self.args.adam_epsilon,
                    "bias_correction": False,
                }

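            # Shared settings: gradient accumulation and first-order momentum follow the
            # training precision, while second-order momentum always stays in fp32.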
            first_order_type = torch.float32 if self.args.fp32 else torch.float16
            optimizer_kwargs["lr"] = self.args.learning_rate
            optimizer_kwargs["loss_scaling"] = self.args.loss_scaling
            optimizer_kwargs["accum_type"] = first_order_type
            optimizer_kwargs["first_order_momentum_accum_type"] = first_order_type
            optimizer_kwargs["second_order_momentum_accum_type"] = torch.float32

            self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

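            # Mark these attributes as constant so PopTorch does not treat them as
            # values that can change between steps.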
            if self.args.lamb or self.args.lamb_no_bias_correction:
                self.optimizer.variable_attrs.markAsConstant("max_weight_norm")

            self.optimizer.variable_attrs.markAsConstant("weight_decay")

        return self.optimizer
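
As the docstring notes, the default optimizer can be bypassed entirely. A hedged sketch of both routes, assuming `IPUTrainer` mirrors the `transformers.Trainer` constructor and accepts an `optimizers=(optimizer, lr_scheduler)` tuple; `model`, `ipu_config`, `training_args`, and `train_ds` are placeholders:

from optimum.graphcore import IPUTrainer
from poptorch.optim import AdamW

# Route 1: hand the trainer a ready-made optimizer (scheduler left to the default).
optimizer = AdamW(model.parameters(), lr=1e-4, loss_scaling=training_args.loss_scaling)
trainer = IPUTrainer(
    model=model,
    ipu_config=ipu_config,          # placeholder IPUConfig
    args=training_args,
    train_dataset=train_ds,
    optimizers=(optimizer, None),   # None -> let the trainer build the LR scheduler
)

# Route 2: subclass and override create_optimizer().
class MyIPUTrainer(IPUTrainer):
    def create_optimizer(self):
        if self.optimizer is None:
            # Hypothetical simpler default: one parameter group, no decay split.
            self.optimizer = AdamW(
                self.model.parameters(),
                lr=self.args.learning_rate,
                loss_scaling=self.args.loss_scaling,
            )
        return self.optimizer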