def main_worker()

in data_augmentation/my_training.py


def main_worker(gpu, ngpus_per_node, args, ckpt_path, repo):
    global best_acc1
    args.gpu = gpu

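    # only the first process per node (or the single process when not distributed) creates the output directory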
    if not args.distributed or (args.distributed and args.rank % ngpus_per_node == 0):
        model_dir = create_repo(args, repo)

    if args.gpu is not None:
        print("Use GPU: {} for training".format(args.gpu))

    cur_device = torch.cuda.current_device()

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        net = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        if args.arch == "modified_resnet18":
            net = modified_resnet18(modify=args.modify, num_classes=args.num_classes)
        else:
            net = models.__dict__[args.arch](num_classes=args.num_classes)

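    # optionally wrap the backbone with a learnable-augmentation (Augerino-style) module trained jointly with the network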
    if args.augerino:
        if args.inv_per_class:
            assert args.disable_at_valid
            augerino_classes = args.num_classes
        else:
            augerino_classes = 1
        if args.transfos == ["tx", "ty", "scale"]:  # special case: the transforms are handled one by one
            if args.min_val:
                print("Using UniformAugEachMin")
                augerino = UniformAugEachMin(
                    transfos=args.transfos,
                    min_values=args.min_values,
                    shutvals=args.shutdown_vals,
                    num_classes=augerino_classes,
                )
            else:
                print("Using UniformAugEach")
                augerino = UniformAugEachPos(
                    transfos=args.transfos,
                    shutvals=args.shutdown_vals,
                    num_classes=augerino_classes,
                )
        else:
            if args.min_val:
                augerino = AugModuleMin(
                    transfos=args.transfos,
                    min_values=args.min_values,
                    shutvals=args.shutdown_vals,
                    num_classes=augerino_classes,
                )
            else:
                augerino = MyUniformAug(
                    transfos=args.transfos,
                    shutvals=args.shutdown_vals,
                    num_classes=augerino_classes,
                )

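        # initialise the augmentation widths from args.startwidth, one row per (augerino) class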
        augerino.set_width(
            torch.FloatTensor(args.startwidth)[None, :].repeat(augerino_classes, 1)
        )
        print(augerino.width)
        if args.fixed_augerino:
            augerino.width.requires_grad = False
        model = AugAveragedModel(
            net, augerino, disabled=False, ncopies=args.ncopies, onecopy=args.onecopy
        )
    else:
        model = net

    # save initial width
    if args.augerino:
        widths = [model.aug.width.clone().detach()]
    else:
        widths = []

    model = model.cuda(device=cur_device)

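    # wrap the model in DistributedDataParallel so each process drives one GPU and gradients are synchronised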
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            module=model, device_ids=[cur_device], output_device=cur_device
        )
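    # warm-start the backbone from a no-augmentation checkpoint only on a fresh run (not when resuming after preemption)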
    if (
        args.pretrained_noaugment and ckpt_path is None
    ):  # to ensure we're not restarting after preemption
        if args.distributed:
            net = torch.nn.parallel.DistributedDataParallel(
                module=net, device_ids=[cur_device], output_device=cur_device
            )
        checkpointing.restart_from_checkpoint(
            args.noaugment_path, args, state_dict=net
        )  # WARNING: Lr is not adjusted accordingly
        factor_lr = 1.0
    else:
        factor_lr = 0.1

    # define loss function (criterion) and optimizer
    if args.augerino:
        criterion = aug_losses.safe_unif_aug_loss_each
        model_param = (
            model.module.model.parameters()
            if args.distributed
            else model.model.parameters()
        )
        aug_param = (
            model.module.aug.parameters()
            if args.distributed
            else model.aug.parameters()
        )
        params = [
            {
                "name": "model",
                "params": model_param,
                "momentum": args.momentum,
                "weight_decay": args.weight_decay,
            },
            {
                "name": "aug",
                "params": aug_param,
                "momentum": args.momentum,
                "weight_decay": 0.0,
                "lr": args.lr * factor_lr,
            },
        ]
    else:
        criterion = nn.CrossEntropyLoss().cuda()
        params = [
            {
                "name": "model",
                "params": model.parameters(),
                "momentum": args.momentum,
                "weight_decay": args.weight_decay,
            }
        ]

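    # a single SGD optimizer; per-group lr, momentum and weight decay come from the param groups defined above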
    optimizer = torch.optim.SGD(params, args.lr)
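    # training state to resume from a checkpoint, if one exists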
    to_restore = {"epoch": 0, "best_acc1": 0.0, "all_acc1": [], "width": widths}
    if ckpt_path is not None:
        checkpointing.restart_from_checkpoint(
            ckpt_path,
            args,
            run_variables=to_restore,
            state_dict=model,
            optimizer=optimizer,
        )
    args.start_epoch = to_restore["epoch"]
    best_acc1 = to_restore["best_acc1"]
    all_acc1 = to_restore["all_acc1"]
    widths = to_restore["width"]
    print("Starting from Epoch", args.start_epoch)

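    # let cuDNN auto-tune convolution algorithms (fast when input sizes are fixed)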
    cudnn.benchmark = True
    # Data loading code
    traindir = os.path.join(args.data, "train")
    valdir = os.path.join(args.data, "val")

    train_loader, val_loader, train_sampler = functions.return_loader_and_sampler(
        args, traindir, valdir
    )

    if args.evaluate:
        validate(val_loader, model, criterion, args)
        # return

    for epoch in range(args.start_epoch, args.epochs):

        adjust_learning_rate(optimizer, epoch, args, factor_lr)

        if args.distributed:
            train_sampler.set_epoch(epoch)
        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, args)

        # evaluate on validation set
        acc1, acc5, val_loss = validate(val_loader, model, criterion, args)
        # remember best acc@1 and save checkpoint
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        all_acc1.append(acc1)
        if args.augerino:
            width = model.module.aug.width if args.distributed else model.aug.width
            widths.append(width.clone().detach())

        if not args.distributed or (
            args.distributed and args.rank % ngpus_per_node == 0
        ):
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": args.arch,
                    "state_dict": model.state_dict(),
                    "best_acc1": best_acc1,
                    "acc1": acc1,
                    "acc5": acc5,
                    "all_acc1": all_acc1,
                    "optimizer": optimizer.state_dict(),
                    "val_loss": val_loss,
                    "width": widths,
                },
                is_best,
                model_dir,
            )
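
How this function is launched lies outside this excerpt. A typical (assumed) entry point spawns one worker per GPU and passes the GPU index as the first positional argument; args, ckpt_path and repo below are hypothetical stand-ins for whatever the surrounding script builds:

    import torch
    import torch.multiprocessing as mp

    # hypothetical launcher sketch, not part of my_training.py
    ngpus_per_node = torch.cuda.device_count()
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args, ckpt_path, repo))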