in lib/optim/extragradient.py [0:0]
def step(self, closure=None):
    """Performs a single optimization step (the extragradient update step).

    Every parameter is reset to the value saved by the preceding
    ``extrapolation`` call, and the update returned by ``self.update`` --
    evaluated with the current (extrapolated) parameters -- is applied on top
    of that saved value. The saved copies are released afterwards, so
    ``extrapolation`` must be called again before the next ``step``.

    Arguments:
        closure (callable, optional): A closure that reevaluates the model
            and returns the loss.

    Returns:
        The value returned by ``closure``, or ``None`` if no closure was given.

    Raises:
        RuntimeError: if no parameters were saved (``extrapolation`` was not
            called first), or if the number of saved parameters does not match
            the number of parameters currently in ``self.param_groups``.
    """
    if not self.params_copy:
        raise RuntimeError('Need to call extrapolation before calling step.')

    loss = None
    if closure is not None:
        loss = closure()

    # Pair every parameter with its group; the flattening order matches the
    # order in which the saved copies are consumed below.
    pairs = [(group, p) for group in self.param_groups for p in group['params']]

    # Validate before touching any parameter, so a mismatch cannot leave the
    # model half-updated (previously this surfaced mid-loop as an IndexError,
    # or extra saved copies were silently ignored).
    if len(pairs) != len(self.params_copy):
        raise RuntimeError(
            'Number of parameters ({}) does not match the number of parameters '
            'saved by extrapolation ({}); were param_groups modified between '
            'extrapolation and step?'.format(len(pairs), len(self.params_copy)))

    for (group, p), saved in zip(pairs, self.params_copy):
        u = self.update(p, group)
        if u is None:
            # No update available at the extrapolated point: still roll the
            # parameter back to its saved value so it does not remain at the
            # extrapolated point.
            p.data = saved
        else:
            # Apply the update to the parameters saved during the
            # extrapolation step (the saved copy is freed below, so updating
            # it in place is safe).
            p.data = saved.add_(u)

    # Free the saved parameters.
    self.params_copy = []
    return loss