in diffq/lsq.py [0:0]
def setup_optimizer(self, optimizer: torch.optim.Optimizer, **kwargs):
"""
Setup the optimizer to tune the scale parameter.
Following [Esser et al. 2019], we use the same LR and weight decay
as the base optimizer, unless specified otherwise.
Args:
optimizer (torch.Optimizer): optimizer to use.
kwargs (dict): overrides for optimization parameters
"""
    assert not self._optimizer_setup
    self._optimizer_setup = True

    # Collect the step-size (scale) parameters introduced by LSQ.
    params = [qp.scale for qp in self._qparams]

    # The scales must not already be tracked by the base optimizer, which
    # would happen if the quantizer was created before the optimizer.
    for group in optimizer.param_groups:
        for q in list(group["params"]):
            for p in params:
                if p is q:
                    raise RuntimeError("You should create the optimizer "
                                       "before the quantizer!")

    # Register the scales as their own param group, inheriting the optimizer
    # defaults (LR, weight decay) unless overridden through kwargs.
    group = {"params": params}
    group.update(kwargs)
    optimizer.add_param_group(group)
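
A minimal usage sketch of the intended call order, under the assumption that the
quantizer class defined in diffq/lsq.py is named LSQ and accepts the model and a
bit width (the class name and constructor arguments are assumptions, not confirmed
by this excerpt): the base optimizer is created first, so the scale parameters are
not yet tracked, and setup_optimizer then registers them as their own param group.

    import torch
    from diffq.lsq import LSQ  # assumed import path / class name

    model = torch.nn.Linear(16, 4)  # placeholder model
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    quantizer = LSQ(model, bits=4)        # hypothetical constructor arguments
    quantizer.setup_optimizer(optimizer)  # scales reuse Adam's LR and weight decay
    # Overrides can be passed through kwargs, e.g.:
    # quantizer.setup_optimizer(optimizer, weight_decay=0.0)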