def process()

in scripts/imagenet/utils.py [0:0]


    def process(input, target, all_reduce=None):
        """Evaluate one batch and fold its loss and top-1/top-5 stats into the meters.

        Runs entirely under ``torch.no_grad()``. Reads ``ten_crops``, ``model``,
        ``criterion``, ``accuracy_sum`` and the ``losses`` / ``top1`` / ``top5``
        meters from the enclosing scope.

        Args:
            input: Batch of inputs passed to ``model``. When ``ten_crops`` is
                truthy it is unpacked as ``(bs, ncrops, c, h, w)`` and
                flattened to ``(bs * ncrops, c, h, w)`` before the forward pass.
            target: Target tensor; it is moved to the CUDA device and its first
                dimension is used as the batch size.
            all_reduce: Optional callable invoked on the sample-count tensor and
                on each statistic tensor. Its return value is ignored, so it is
                expected to reduce its argument in place.

        Returns:
            None. Each meter is updated via ``meter.update(value, count)``.
        """
        with torch.no_grad():
            if ten_crops:
                # Fold the crop axis into the batch axis for a single forward pass.
                bs, ncrops, c, h, w = input.size()
                input = input.view(-1, c, h, w)

            target = target.cuda(non_blocking=True)

            # compute output
            output = model(input)
            if ten_crops:
                # Restore the crop axis and average the predictions over crops.
                output = output.view(bs, ncrops, -1).mean(1)
            loss = criterion(output, target)

            # measure accuracy and record loss
            # NOTE(review): accuracy_sum presumably returns per-batch *sums* of
            # correct predictions (they are divided by the sample count below)
            # -- confirm against its definition.
            prec1, prec5 = accuracy_sum(output.detach(), target, topk=(1, 5))

            batch_size = target.shape[0]
            # Scale the loss by the batch size so it can be summed across
            # processes and re-normalised by the global sample count.
            # NOTE(review): assumes criterion returns a batch mean -- confirm.
            loss *= batch_size
            count = target.new_tensor([batch_size], dtype=torch.long)
            if all_reduce:
                all_reduce(count)

            # Read the (possibly reduced) sample count back to the host once;
            # it does not change inside the loop below.
            total = count.item()

            for meter, val in (losses, loss), (top1, prec1), (top5, prec5):
                if all_reduce:
                    all_reduce(val)
                val /= total
                meter.update(val.item(), total)