in domainbed/lib/misc.py [0:0]
def accuracy(network, loader, weights, device):
    """Return the (optionally weighted) accuracy of ``network`` over ``loader``.

    The network is switched to evaluation mode for the duration of the pass,
    gradients are disabled, and ``network.train()`` is called before
    returning (also when an exception propagates).

    Args:
        network: Object exposing ``eval()``, ``train()`` and ``predict(x)``.
            ``predict`` must return a tensor with at least two dimensions
            whose second dimension holds either a single logit (binary case,
            predicted positive when the logit is > 0) or one score per class
            (predicted class is the argmax over dimension 1).
        loader: Iterable yielding ``(x, y)`` batches of tensors.
        weights: ``None`` to weight every example equally, or a tensor of
            per-example weights laid out in the same order the loader yields
            examples; it is consumed ``len(x)`` entries at a time.
        device: Device the inputs, labels and weights are moved to.

    Returns:
        The sum of the weights of correctly classified examples divided by
        the sum of all weights, as a Python float.

    Raises:
        ZeroDivisionError: If the loader yields no batches or the total
            weight is zero.
    """
    correct = 0
    total = 0
    weights_offset = 0

    network.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device)
                p = network.predict(x)

                batch_size = len(x)
                if weights is None:
                    # Build the unit weights directly on the target device.
                    batch_weights = torch.ones(batch_size, device=device)
                else:
                    batch_weights = weights[
                        weights_offset:weights_offset + batch_size
                    ]
                    weights_offset += batch_size
                    batch_weights = batch_weights.to(device)
                # One weight per example, regardless of how it was stored.
                batch_weights = batch_weights.reshape(-1)

                if p.size(1) == 1:
                    # Single-logit (binary) output. Flatten both predictions
                    # and labels so that labels of shape (B,) are compared
                    # element-wise; comparing (B, 1) with (B,) would
                    # broadcast to a (B, B) matrix and over-count matches.
                    hits = p.gt(0).reshape(-1).eq(y.reshape(-1))
                else:
                    hits = p.argmax(1).eq(y)

                correct += (hits.float() * batch_weights).sum().item()
                total += batch_weights.sum().item()
    finally:
        # Leave the network in training mode even if evaluation fails.
        network.train()

    return correct / total