def main_worker()

in data_augmentation/test.py


def main_worker(gpu, ngpus_per_node, args, ckpt_path):
    args.gpu = gpu

    if args.gpu is not None:
        print("Use GPU: {} for testing".format(args.gpu))

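    # Device index for this process; assumed to have been selected earlier
    # (e.g. via torch.cuda.set_device in the caller's distributed setup).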
    cur_device = torch.cuda.current_device()
    # create model
    print("=> creating model '{}'".format(args.arch))
    net = models.__dict__[args.arch](num_classes=args.num_classes)

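    # Optionally wrap the backbone with an Augerino-style augmentation module.
    # With inv_per_class, each class gets its own augmentation parameters;
    # otherwise a single set is shared across all classes.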
    if args.augerino:
        if args.inv_per_class:
            augerino_classes = args.num_classes
        else:
            augerino_classes = 1

        if args.transfos == ["tx", "ty", "scale"]:  # special case: these transforms are handled one by one
            if args.min_val:
                print("Using UniformAugEachMin")
                augerino = UniformAugEachMin(
                    transfos=args.transfos,
                    min_values=args.min_values,
                    shutvals=args.shutdown_vals,
                    num_classes=augerino_classes,
                )
            else:
                print("Using UniformAugEach")
                augerino = UniformAugEachPos(
                    transfos=args.transfos,
                    shutvals=args.shutdown_vals,
                    num_classes=augerino_classes,
                )
        else:
            if args.min_val:
                augerino = AugModuleMin(
                    transfos=args.transfos,
                    min_values=args.min_values,
                    shutvals=args.shutdown_vals,
                    num_classes=augerino_classes,
                )
            else:
                augerino = MyUniformAug(
                    transfos=args.transfos,
                    shutvals=args.shutdown_vals,
                    num_classes=augerino_classes,
                )
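        # AugAveragedModel averages the backbone's predictions over `ncopies`
        # sampled augmentations of each input.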
        model = AugAveragedModel(net, augerino, ncopies=args.ncopies)
    else:
        model = net

    model = model.cuda(device=cur_device)

    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            module=model, device_ids=[cur_device], output_device=cur_device
        )

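    # Restore model weights and bookkeeping variables (epoch, accuracy history)
    # from the checkpoint at ckpt_path.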
    to_restore = {"epoch":0.0, "all_acc1": 0.0, "best_acc1": 0.0}
    checkpointing.restart_from_checkpoint(ckpt_path, args, 
                    run_variables=to_restore,state_dict=model)

    best_acc1 = to_restore["best_acc1"]
    model_acc1 = to_restore["all_acc1"][-1]
    print("Best Acc1 was", best_acc1)
    print("Model Acc1 was", model_acc1)

    # Data loading code
    traindir = os.path.join(args.data, "train")
    valdir = os.path.join(args.data, "val")
    _, test_loader, _ = functions_bis.return_loader_and_sampler(
        args, traindir, valdir, return_train=False
    )

    compute_per_sample = True  # TODO: expose as a command-line flag instead of hard-coding
    if compute_per_sample:
        assert len(args.test_seeds) == 1, "per-sample results are only supported for a single test seed"
    top1s = []
    top5s = []
    tops1class = []
    per_img_dict = {}
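    # Run the evaluation once per test seed and average the metrics afterwards.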
    for seed in args.test_seeds:
        # Set seed for testing 
        random.seed(seed)
        torch.manual_seed(seed)

        top1, top5, top1class = test(test_loader, model, args, per_img_dict, args.no_aug_test, compute_per_sample)
        print('{:.5f},{:.5f},{:.5f}'.format(top1, top5, top1class.mean()))
        top1s.append(torch.FloatTensor([top1]))
        top5s.append(torch.FloatTensor([top5]))
        tops1class.append(top1class[None,:])
    print(top1s)
    top1s = torch.cat(top1s).mean()
    top5s = torch.cat(top5s).mean()
    tops1class = torch.cat(tops1class, dim=0).mean(0)

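    # Aggregate the metrics and save them next to the checkpoint; in the
    # distributed case only the per-node rank-0 process writes to disk.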
    ckpt = {'top1': top1s, 'top5': top5s, 'top1class': tops1class}
    name = 'test3_{0}_{1}_{2}.pth'.format(args.augment_valid, args.scale_mag, args.no_aug_test)
    ckpt_folder = os.path.dirname(ckpt_path)
    if not args.distributed or (
        args.distributed and args.rank % ngpus_per_node == 0
    ):
        torch.save(ckpt, os.path.join(ckpt_folder, name))
        if compute_per_sample:
            name = 'per_sample_{0}_{1}_{2}.pth'.format(args.augment_valid, args.scale_mag, args.no_aug_test)
            torch.save(per_img_dict, os.path.join(ckpt_folder, name))
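
Usage note: a minimal sketch of how a caller might launch main_worker, assuming args comes from the repository's argparse setup. The main() wrapper below is illustrative, not part of this file; the spawn pattern follows the standard torch.multiprocessing convention.

import torch
import torch.multiprocessing as mp


def main(args, ckpt_path):
    # One worker process per visible GPU; mp.spawn passes the process index as
    # the first positional argument, which becomes main_worker's `gpu`.
    ngpus_per_node = torch.cuda.device_count()
    if args.distributed:
        mp.spawn(main_worker, nprocs=ngpus_per_node,
                 args=(ngpus_per_node, args, ckpt_path))
    else:
        # Single-process run on the GPU chosen via args.gpu (may be None).
        main_worker(args.gpu, ngpus_per_node, args, ckpt_path)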