in scripts/imagenet/utils.py [0:0]
def process(input, target, all_reduce=None):
    """Evaluate one batch and accumulate loss / top-1 / top-5 into the meters.

    This function relies on names it does not define (presumably from an
    enclosing scope or module globals -- confirm): ``ten_crops``, ``model``,
    ``criterion``, ``accuracy_sum`` and the meters ``losses``, ``top1``,
    ``top5``, each of which must provide ``update(value, n)``.

    Args:
        input: Batch of images. When ``ten_crops`` is truthy it is unpacked as
            ``(bs, ncrops, c, h, w)`` and flattened so every crop is scored;
            the per-crop outputs are then averaged back to one row per sample.
        target: Ground-truth labels; moved to CUDA with ``non_blocking=True``.
        all_reduce: Optional callable that reduces a tensor in place across
            workers (presumably a summing all-reduce -- confirm). When falsy,
            statistics are computed for the local batch only.
    """
    with torch.no_grad():
        if ten_crops:
            bs, ncrops, c, h, w = input.size()
            input = input.view(-1, c, h, w)
        target = target.cuda(non_blocking=True)

        # compute output; with ten crops, average predictions over the crops
        output = model(input)
        if ten_crops:
            output = output.view(bs, ncrops, -1).mean(1)
        loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy_sum(output, target, topk=(1, 5))
        # Scale the (presumably mean-reduced) loss into a per-batch sum so it
        # can be reduced across workers like the accuracy sums and then
        # divided by the global sample count.
        loss *= target.shape[0]
        count = target.new_tensor([target.shape[0]], dtype=torch.long)
        if all_reduce:
            all_reduce(count)
        # One host sync for the (global) sample count instead of two per meter.
        n = count.item()
        if n == 0:
            # Nothing to record; dividing would feed nan/inf into the meters.
            # `n` is the reduced count, so every worker takes this branch
            # together and the remaining collectives stay matched.
            return
        # The reduction order (count, loss, prec1, prec5) must be the same on
        # every worker for collective calls to pair up correctly.
        for meter, total in ((losses, loss), (top1, prec1), (top5, prec5)):
            if all_reduce:
                all_reduce(total)
            meter.update(total.item() / n, n)