def main_worker()

in scripts/image_classification/train.py
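
Per-process training worker: it sets up TensorBoard logging on rank 0, builds the criterion, model, optimizer, dataloaders and LR scheduler, optionally resumes from a checkpoint, then loops over epochs of training with periodic checkpointing and validation.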


def main_worker(cfg):
    # create tensorboard and logs
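    # only rank 0 creates the SummaryWriter so that a single process writes TensorBoard events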
    if cfg.DDP_CONFIG.GPU_WORLD_RANK == 0:
        tb_logdir = build_log_dir(cfg)
        writer = SummaryWriter(log_dir=tb_logdir)
    else:
        writer = None
    cfg.freeze()

    # create criterion
    if cfg.CONFIG.AUG.MIXUP > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif cfg.CONFIG.AUG.LABEL_SMOOTHING > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=cfg.CONFIG.AUG.LABEL_SMOOTHING)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    # create model
    # should follow the order: build_model -> build optimizer -> deploy model
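    # build the optimizer on the raw model parameters first, since deploy_model wraps the model (DDP / EMA / AMP) afterwards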
    print('Creating model: %s' % (cfg.CONFIG.MODEL.NAME))
    model = get_model(cfg)
    model_without_ddp = model

    # create optimizer
    # linearly scale the learning rates with the total batch size (reference batch size 512); this may not be optimal for transformers
    linear_scaled_lr = cfg.CONFIG.TRAIN.LR * cfg.CONFIG.TRAIN.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_warmup_lr = cfg.CONFIG.TRAIN.WARMUP_START_LR * cfg.CONFIG.TRAIN.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_min_lr = cfg.CONFIG.TRAIN.MIN_LR * cfg.CONFIG.TRAIN.BATCH_SIZE * dist.get_world_size() / 512.0
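    # e.g., assuming 128 images per GPU on 8 GPUs: 5e-4 * 128 * 8 / 512 = 1e-3 (the "5e-4 -> 0.001" note below)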
    cfg.defrost()
    cfg.CONFIG.TRAIN.LR = linear_scaled_lr    # 5e-4 -> 0.001
    cfg.CONFIG.TRAIN.WARMUP_START_LR = linear_scaled_warmup_lr
    cfg.CONFIG.TRAIN.MIN_LR = linear_scaled_min_lr
    cfg.freeze()
    optimizer = build_optimizer(cfg, model_without_ddp)

    model, optimizer, model_ema = deploy_model(model, optimizer, cfg)
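    # model.module is the underlying model without the DDP wrapper; keep it for checkpoint save/load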
    model_without_ddp = model.module
    num_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('Number of parameters in the model: %6.2fM' % (num_parameters / 1000000))

    # create dataset and dataloader
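    # when mixup/cutmix is enabled, mixup_fn mixes each batch and produces soft targets (hence SoftTargetCrossEntropy above)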
    train_loader, val_loader, train_sampler, val_sampler, mixup_fn = build_dataloader(cfg)

    # create lr scheduler
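    # len(train_loader) gives the number of iterations per epoch, so warmup and decay can be stepped per iteration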
    lr_scheduler = build_scheduler(cfg, optimizer, len(train_loader))

    # resume from a checkpoint
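    # the checkpoint dict holds 'model' and, optionally, 'optimizer', 'lr_scheduler', 'epoch', 'amp' and 'config'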
    if cfg.CONFIG.TRAIN.RESUME:
        if cfg.CONFIG.TRAIN.RESUME.startswith('https'):
            checkpoint = torch.hub.load_state_dict_from_url(
                cfg.CONFIG.TRAIN.RESUME, map_location='cpu', check_hash=True)
        else:
            checkpoint = torch.load(cfg.CONFIG.TRAIN.RESUME, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        if 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            cfg.defrost()
            cfg.CONFIG.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
            cfg.freeze()
            if 'amp' in checkpoint and cfg.CONFIG.TRAIN.AMP_LEVEL != "O0" and checkpoint['config'].CONFIG.TRAIN.AMP_LEVEL != "O0":
                amp.load_state_dict(checkpoint['amp'])
        print('Resumed from checkpoint of epoch %d at %s' % (checkpoint['epoch'], cfg.CONFIG.TRAIN.RESUME))

    print('Start training...')
    start_time = time.time()
    max_accuracy = 0.0
    for epoch in range(cfg.CONFIG.TRAIN.START_EPOCH, cfg.CONFIG.TRAIN.EPOCH_NUM):
        if cfg.DDP_CONFIG.DISTRIBUTED:
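            # reshuffle the DistributedSampler so each epoch sees a different data ordering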
            train_sampler.set_epoch(epoch)

        train_classification(cfg, model, model_ema, criterion, train_loader, optimizer, epoch, mixup_fn, lr_scheduler, writer)

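        # only rank 0 saves checkpoints, every SAVE_FREQ epochs and at the final epoch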
        if cfg.DDP_CONFIG.GPU_WORLD_RANK == 0 and (epoch % cfg.CONFIG.LOG.SAVE_FREQ == 0 or epoch == cfg.CONFIG.TRAIN.EPOCH_NUM - 1):
            save_checkpoint(cfg, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler)

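        # validate every VAL.FREQ epochs and at the final epoch; acc1/acc5 are top-1/top-5 accuracy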
        if epoch % cfg.CONFIG.VAL.FREQ == 0 or epoch == cfg.CONFIG.TRAIN.EPOCH_NUM - 1:
            acc1, acc5, loss = validate_classification(cfg, val_loader, model, epoch, writer)
            max_accuracy = max(max_accuracy, acc1)
            print(f"Top-1 accuracy of the network: {acc1:.2f}%")
            print(f"Max top-1 accuracy so far: {max_accuracy:.2f}%")

    if writer is not None:
        writer.close()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
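
For context, main_worker is spawned once per GPU from the script's command-line entry point. The sketch below is a minimal illustration of that launch pattern, not the script's actual __main__ block; the helper names and import paths (get_cfg_defaults, spawn_workers) are assumptions based on the surrounding repository.

# Hypothetical launch sketch; the helper names below are assumptions, not taken from this excerpt.
import argparse

from gluoncv.torch.engine.config import get_cfg_defaults  # assumed config helper
from gluoncv.torch.engine.launch import spawn_workers     # assumed DDP launch helper

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train an image classification model.')
    parser.add_argument('--config-file', type=str, required=True,
                        help='path to the YAML experiment config')
    args = parser.parse_args()

    cfg = get_cfg_defaults()               # default YACS config node
    cfg.merge_from_file(args.config_file)  # overlay the experiment config
    spawn_workers(main_worker, cfg)        # launch one main_worker process per GPU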