def setup_optimizer()

in diffq/diffq.py [0:0]


    def setup_optimizer(self, optimizer: torch.optim.Optimizer,
                        lr: float = 1e-3, **kwargs):
        """
        Setup the optimizer to tune the number of bits. In particular, this will deactivate
        weight decay for the bits parameters.

        The bits parameters (the `logit` of every entry in `self._qparams`) are added to
        `optimizer` as a new, dedicated parameter group. This can only be done once.

        Args:
            optimizer (torch.optim.Optimizer): optimizer to use.
            lr (float): specific learning rate for the bits parameters. 1e-3
                is perfect for Adam.
            kwargs (dict): overrides for other optimization parameters for the bits.
                They are applied last, so they can also override `lr` and
                `weight_decay`.

        Raises:
            RuntimeError: if the optimizer was already set up for this object, or if
                the bits parameters are already managed by `optimizer` (i.e. the
                optimizer was created after the quantizer).
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if self._optimizer_setup:
            raise RuntimeError("setup_optimizer() can only be called once.")

        params = [qp.logit for qp in self._qparams]

        # Tensors overload `==` elementwise, so membership must be by identity.
        # A set of ids gives an O(1) lookup per optimizer parameter.
        bits_ids = {id(p) for p in params}
        for group in optimizer.param_groups:
            if any(id(q) in bits_ids for q in group["params"]):
                raise RuntimeError("You should create the optimizer "
                                   "before the quantizer!")

        group = {"params": params, "lr": lr, "weight_decay": 0}
        group.update(kwargs)
        optimizer.add_param_group(group)
        # Mark as set up only after the group was actually registered, so a failed
        # attempt (check above or `add_param_group` rejecting kwargs) can be retried.
        self._optimizer_setup = True