# scripts/train_imagenet.py — main() training entry point


def main():
    """Entry point for (optionally distributed) ImageNet training.

    Parses CLI args, joins the distributed process group when launched with
    WORLD_SIZE > 1, builds the model/criterion/optimizer from the config
    file, optionally resumes from a checkpoint, then runs the
    train/validate loop, saving checkpoints from rank 0 only.
    """
    global args, best_prec1, logger, conf, tb
    args = parser.parse_args()

    # Bind this process to its assigned GPU before any CUDA allocation.
    torch.cuda.set_device(args.local_rank)

    # WORLD_SIZE is set by the distributed launcher; a missing or malformed
    # value means single-process training. Catch only the exceptions the
    # lookup/conversion can raise — a bare except also swallows
    # KeyboardInterrupt and real bugs.
    try:
        world_size = int(os.environ["WORLD_SIZE"])
        distributed = world_size > 1
    except (KeyError, ValueError):
        distributed = False
        world_size = 1

    if distributed:
        dist.init_process_group(backend=args.dist_backend, init_method="env://")

    rank = 0 if not distributed else dist.get_rank()
    init_logger(rank, args.log_dir)
    # Only rank 0 writes TensorBoard events, avoiding duplicate logs.
    tb = SummaryWriter(args.log_dir) if rank == 0 else None

    # Load configuration
    conf = config.load_config(args.config)

    # Create model from the architecture named in the config.
    model_params = utils.get_model_params(conf["network"])
    model = models.__dict__["net_" + conf["network"]["arch"]](**model_params)

    model.cuda()
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank
        )
    else:
        model = SingleGPU(model)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer, scheduler = utils.create_optimizer(conf["optimizer"], model)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            logger.info("=> loading checkpoint '{}'".format(args.resume))
            # map_location keeps restored tensors on this rank's GPU rather
            # than the GPU the checkpoint was saved from (prevents every
            # rank piling its copy onto device 0 on resume).
            checkpoint = torch.load(
                args.resume, map_location="cuda:{}".format(args.local_rank)
            )
            args.start_epoch = checkpoint["epoch"]
            best_prec1 = checkpoint["best_prec1"]
            model.load_state_dict(checkpoint["state_dict"])
            optimizer.load_state_dict(checkpoint["optimizer"])
            logger.info(
                "=> loaded checkpoint '{}' (epoch {})".format(
                    args.resume, checkpoint["epoch"]
                )
            )
        else:
            logger.warning("=> no checkpoint found at '{}'".format(args.resume))
            # Fall back to a fresh run so start_epoch/best_prec1 are defined
            # when the resume path was given but the file is missing.
            args.start_epoch = 0
            best_prec1 = 0.0
    else:
        init_weights(model)
        args.start_epoch = 0
        # Initialize the running best here; it is read unconditionally in
        # the epoch loop below.
        best_prec1 = 0.0

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, "train")
    valdir = os.path.join(args.data, "val")

    train_transforms, val_transforms = utils.create_transforms(conf["input"])
    train_dataset = datasets.ImageFolder(traindir, transforms.Compose(train_transforms))

    # In distributed mode each rank sees a disjoint shard of the data;
    # shuffling is then the sampler's job, not the DataLoader's.
    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        # The config batch size is global; divide it across ranks.
        batch_size=conf["optimizer"]["batch_size"] // world_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        pin_memory=True,
        sampler=train_sampler,
    )

    val_dataset = datasets.ImageFolder(valdir, transforms.Compose(val_transforms))
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=conf["optimizer"]["batch_size"] // world_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True,
        sampler=TestDistributedSampler(val_dataset),
    )

    if args.evaluate:
        utils.validate(
            val_loader,
            model,
            criterion,
            print_freq=args.print_freq,
            tb=tb,
            logger=logger.info,
        )
        return

    for epoch in range(args.start_epoch, conf["optimizer"]["schedule"]["epochs"]):
        if distributed:
            # Reseed the sampler so each epoch sees a different shuffle.
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, scheduler, epoch)

        # evaluate on validation set
        prec1 = utils.validate(
            val_loader,
            model,
            criterion,
            it=epoch * len(train_loader),
            print_freq=args.print_freq,
            tb=tb,
            logger=logger.info,
        )

        # remember best prec@1 and save checkpoint
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        # Only rank 0 writes checkpoints to avoid concurrent writers.
        if rank == 0:
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "arch": conf["network"]["arch"],
                    "state_dict": model.state_dict(),
                    "best_prec1": best_prec1,
                    "optimizer": optimizer.state_dict(),
                },
                is_best,
                args.log_dir,
            )