def clear_optimizer()

in diffq/diffq.py [0:0]


    def clear_optimizer(self, optimizer: torch.optim.Optimizer):
        """Remove the ``logit`` parameters of ``self._qparams`` from ``optimizer``.

        Each param group is filtered in place: the existing ``group["params"]``
        list object is reused and only its contents change. A parameter is
        removed when it *is* (by identity) the ``logit`` of one of
        ``self._qparams``. The remaining parameters keep their order, and groups
        left empty stay in ``optimizer.param_groups``.

        Any per-parameter optimizer state held for the removed logits is
        discarded as well.

        Args:
            optimizer: optimizer whose param groups may contain the logits.
        """
        # Match on identity, not equality: ``==`` on tensors is elementwise.
        # ``self._qparams`` keeps every logit alive for the whole call, so the
        # ids are stable. One lookup table replaces the per-parameter scan
        # over all logits.
        logits = {id(qp.logit): qp.logit for qp in self._qparams}

        for group in optimizer.param_groups:
            # Slice assignment mutates the list in place so that any outside
            # reference to ``group["params"]`` sees the change.
            group["params"][:] = [
                p for p in group["params"] if id(p) not in logits
            ]

        # Leftover state for parameters that belong to no group makes
        # ``optimizer.state_dict()`` raise (it can only remap parameters still
        # present in a group), and it keeps buffers such as moments alive.
        for logit in logits.values():
            optimizer.state.pop(logit, None)