in domainbed/algorithms.py [0:0]
def neum(v, model, batch, num_terms=10, divergence_threshold=10):
    """Approximate an inverse-Hessian-vector product with a Neumann series.

    Accumulates ``v + (I - H) v + (I - H)^2 v + ...`` where ``H`` is the
    Hessian of the cross-entropy loss of ``model`` on ``batch`` with respect
    to ``model.weight``. Each ``H @ term`` product is obtained by double
    backpropagation, so ``H`` is never materialised.

    Args:
        v: Tensor multiplied against the flattened gradient of
            ``model.weight``; it must therefore be a flat vector with as
            many elements as ``model.weight``. It is not modified.
        model: Callable module exposing a ``weight`` tensor and
            ``eval()`` / ``train()``.
        batch: Indexable pair; ``batch[0]`` is fed to ``model`` and
            ``batch[1]`` is the target passed to ``F.cross_entropy``.
        num_terms: Number of series terms added after the initial ``v``.
        divergence_threshold: The iteration stops early once the largest
            absolute entry of the running estimate exceeds this value.

    Returns:
        A detached tensor holding the accumulated estimate. It never shares
        storage with ``v``.

    Side effects:
        ``model.weight.grad`` is zeroed if it exists, and the model is put in
        eval mode for the computation and in train mode afterwards.
    """

    def hvp(y, w, vec):
        """Return the flattened Hessian-vector product ``d2y/dw2 @ vec``."""
        # First backprop: keep the graph so the gradient is differentiable.
        first_grads = autograd.grad(
            y, w, retain_graph=True, create_graph=True, allow_unused=True
        )
        first_grads = torch.nn.utils.parameters_to_vector(first_grads)
        # Scalar <grad, vec>; its gradient w.r.t. w is H @ vec.
        grad_dot_vec = first_grads @ vec
        # Second backprop: the caller detaches the result straight away, so
        # no higher-order graph is needed here.
        second_grads = autograd.grad(grad_dot_vec, w)
        return torch.nn.utils.parameters_to_vector(second_grads)

    # `detach()` shares storage with the caller's tensor, so every update
    # below is out-of-place; the clone keeps the running sum independent of
    # both the caller's `v` and the current series term.
    term = v.detach()
    h_estimate = term.clone()
    weight = model.weight

    model.eval()
    try:
        # `autograd.grad` never writes to `.grad`, so clearing once suffices;
        # `.grad` is None until a first backward pass has run.
        if weight.grad is not None:
            weight.grad.zero_()
        inputs = batch[0].detach()
        targets = batch[1].detach()
        for _ in range(num_terms):
            loss = F.cross_entropy(model(inputs), targets)
            # Next series term: (I - H) @ term.
            term = (term - hvp(loss, weight, term)).detach()
            h_estimate = (h_estimate + term).detach()
            # The series is not converging; stop with the current estimate.
            if h_estimate.abs().max() > divergence_threshold:
                break
    finally:
        model.train()
    return h_estimate