def setup_optimizer()

in diffq/lsq.py


    def setup_optimizer(self, optimizer: torch.optim.Optimizer, **kwargs):
        """
        Set up the optimizer to tune the scale parameters.
        Following [Esser et al. 2019], we use the same LR and weight decay
        as the base optimizer, unless specified otherwise.

        Args:
            optimizer (torch.optim.Optimizer): base optimizer to extend with
                the scale parameters.
            kwargs (dict): overrides for the optimization parameters
                (e.g. lr, weight_decay).
        """
        assert not self._optimizer_setup
        self._optimizer_setup = True

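        # Collect the learnable scale (step size) of every quantized parameter.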
        params = [qp.scale for qp in self._qparams]

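        # Sanity check: none of the scale parameters may already belong to
        # the optimizer, which would mean the quantizer was created first.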
        scale_ids = {id(p) for p in params}
        for group in optimizer.param_groups:
            for q in group["params"]:
                if id(q) in scale_ids:
                    raise RuntimeError("You should create the optimizer "
                                       "before the quantizer!")

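        # Register the scales as a new param group; options not given in
        # kwargs fall back to the optimizer defaults, so by default the
        # scales reuse the same LR and weight decay as the base parameters.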
        group = {"params": params}
        group.update(kwargs)
        optimizer.add_param_group(group)
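
For reference, a minimal usage sketch. The import path and the LSQ
constructor call are assumptions inferred from the file name diffq/lsq.py,
not confirmed by this snippet; only setup_optimizer() itself is documented
above.

    import torch
    from torch import nn

    from diffq.lsq import LSQ  # assumed class name and import path

    model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

    # Create the optimizer before the quantizer; otherwise the scale
    # parameters are already registered with it and setup_optimizer()
    # raises "You should create the optimizer before the quantizer!".
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3,
                                 weight_decay=1e-5)

    quantizer = LSQ(model)  # hypothetical constructor signature

    # Default: the scales reuse the base LR and weight decay,
    # following [Esser et al. 2019].
    quantizer.setup_optimizer(optimizer)

    # Alternatively, override the hyperparameters for the scales only
    # (setup_optimizer() may be called at most once per quantizer):
    #   quantizer.setup_optimizer(optimizer, lr=1e-4, weight_decay=0.0)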