in domainbed/algorithms.py [0:0]
def update(self, minibatches, unlabeled=None):
    """Perform one MLDG meta-update step (first-order approximation).

    Terms being computed:
        * Li = Loss(xi, yi, params)
        * Gi = Grad(Li, params)
        * Lj = Loss(xj, yj, Optimizer(params, grad(Li, params)))
        * Gj = Grad(Lj, params)
        * params = Optimizer(params, Grad(Li + beta * Lj, params))
        *        = Optimizer(params, Gi + beta * Gj)

    That is, when calling .step(), we want grads to be Gi + beta * Gj.
    For computational efficiency, we do not compute second derivatives:
    Gj is taken at the updated clone's parameters without backpropagating
    through the inner optimization step (first-order MLDG).

    Args:
        minibatches: list of (x, y) pairs, one per training domain.
        unlabeled: unused; kept for a uniform algorithm interface.

    Returns:
        dict mapping 'loss' to the meta-objective averaged over pairs.
    """
    num_mb = len(minibatches)
    objective = 0
    self.optimizer.zero_grad()
    # Ensure every parameter has a grad buffer to accumulate into,
    # since some parameters may never receive a gradient below
    # (autograd.grad is called with allow_unused=True).
    for p in self.network.parameters():
        if p.grad is None:
            p.grad = torch.zeros_like(p)

    for (xi, yi), (xj, yj) in random_pairs_of_minibatches(minibatches):
        # Fine-tune a clone of the network on meta-train task "i".
        inner_net = copy.deepcopy(self.network)

        inner_opt = torch.optim.Adam(
            inner_net.parameters(),
            lr=self.hparams["lr"],
            weight_decay=self.hparams['weight_decay']
        )

        inner_obj = F.cross_entropy(inner_net(xi), yi)

        inner_opt.zero_grad()
        inner_obj.backward()
        inner_opt.step()

        # The clone's .grad buffers now hold Gi, and its parameters are
        # P' = Optimizer(P, Gi). Accumulate Gi / num_mb into the main
        # network's grads. Grad buffers never require grad, so they can
        # be updated in place directly; torch.no_grad() makes the intent
        # explicit (the deprecated .data indirection is unnecessary).
        with torch.no_grad():
            for p_tgt, p_src in zip(self.network.parameters(),
                                    inner_net.parameters()):
                if p_src.grad is not None:
                    p_tgt.grad.add_(p_src.grad / num_mb)

        # `objective` is populated for reporting purposes only.
        objective += inner_obj.item()

        # Compute Gj: gradient of the task-"j" loss at the clone's
        # updated parameters (no second derivatives — see docstring).
        loss_inner_j = F.cross_entropy(inner_net(xj), yj)
        grad_inner_j = autograd.grad(loss_inner_j, inner_net.parameters(),
                                     allow_unused=True)

        # `objective` is populated for reporting purposes only.
        objective += (self.hparams['mldg_beta'] * loss_inner_j).item()

        # Accumulate beta * Gj / num_mb; the main network's grads now
        # hold the averaged (Gi + beta * Gj) over the processed pairs.
        with torch.no_grad():
            for p, g_j in zip(self.network.parameters(), grad_inner_j):
                if g_j is not None:
                    p.grad.add_(self.hparams['mldg_beta'] * g_j / num_mb)

    # One pair is drawn per minibatch, so average by num_mb
    # (same value as len(minibatches), bound once above).
    objective /= num_mb

    self.optimizer.step()

    return {'loss': objective}