def main()

in src/fine_tune.py [0:0]


def main(args):
    """Fine-tune a classifier, or evaluate a fine-tuned checkpoint.

    ``args`` is a nested config dict. The keys read here are:

    * ``args['meta']``: ``model_name``, ``master_port``, ``load_checkpoint``,
      ``training``, ``copy_data``, ``use_fp16``, ``device``
    * ``args['data']``: ``unlabeled_frac``, ``normalize``, ``root_path``,
      ``image_folder``, ``dataset``, ``subset_path``, ``num_classes`` and,
      only when the dataset name contains ``'cifar10'``, ``data_seed``
    * ``args['optimization']``: ``weight_decay``, ``lr``, ``use_lars``,
      ``zero_init``, ``epochs``
    * ``args['logging']``: ``folder``, ``write_tag``, ``pretrain_path``

    When ``training`` is false the run is a single evaluation epoch: the
    batch size is 16, ``unlabeled_frac`` is forced to 0.0, the fine-tuned
    checkpoint is always loaded and the model is put in eval mode.

    While training, rank 0 writes a checkpoint to
    ``<folder>/<write_tag>-fine-tune.pth.tar`` every time the validation
    top-1 accuracy improves on the best value seen so far.

    Returns:
        ``(train_top1, val_top1)``: top-1 accuracies (in percent) from the
        last epoch that ran, or ``(None, None)`` if no epoch ran (e.g. the
        loaded checkpoint's start epoch is already >= the configured number
        of epochs).
    """

    # -- META
    meta_cfg = args['meta']
    model_name = meta_cfg['model_name']
    port = meta_cfg['master_port']
    load_checkpoint = meta_cfg['load_checkpoint']
    training = meta_cfg['training']
    copy_data = meta_cfg['copy_data']
    use_fp16 = meta_cfg['use_fp16']
    device_str = meta_cfg['device']
    device = torch.device(device_str)
    torch.cuda.set_device(device)

    # -- DATA
    data_cfg = args['data']
    unlabeled_frac = data_cfg['unlabeled_frac']
    normalize = data_cfg['normalize']
    root_path = data_cfg['root_path']
    image_folder = data_cfg['image_folder']
    dataset_name = data_cfg['dataset']
    subset_path = data_cfg['subset_path']
    num_classes = data_cfg['num_classes']
    # -- the split seed and the tighter crop scale only apply to cifar10
    is_cifar10 = 'cifar10' in dataset_name
    data_seed = data_cfg['data_seed'] if is_cifar10 else None
    crop_scale = (0.5, 1.0) if is_cifar10 else (0.08, 1.0)

    # -- OPTIMIZATION
    opt_cfg = args['optimization']
    wd = float(opt_cfg['weight_decay'])
    ref_lr = opt_cfg['lr']
    use_lars = opt_cfg['use_lars']
    zero_init = opt_cfg['zero_init']
    num_epochs = opt_cfg['epochs']

    # -- LOGGING
    log_cfg = args['logging']
    folder = log_cfg['folder']
    tag = log_cfg['write_tag']
    r_file_enc = log_cfg['pretrain_path']

    # -- log/checkpointing paths
    # -- r_enc_path: weights handed to init_model
    # -- w_enc_path: fine-tuned checkpoint that is saved and/or resumed from
    r_enc_path = os.path.join(folder, r_file_enc)
    w_enc_path = os.path.join(folder, f'{tag}-fine-tune.pth.tar')

    # -- init distributed
    world_size, rank = init_distributed(port)
    logger.info(f'initialized rank/world-size: {rank}/{world_size}')

    # -- optimization/evaluation params
    if training:
        batch_size = 256
    else:
        # -- evaluation: one pass over the data with the saved checkpoint
        batch_size = 16
        unlabeled_frac = 0.0
        load_checkpoint = True
        num_epochs = 1

    # -- init loss
    criterion = torch.nn.CrossEntropyLoss()

    # -- make train data transforms and data loaders/samples
    transform, init_transform = make_transforms(
        dataset_name=dataset_name,
        subset_path=subset_path,
        unlabeled_frac=unlabeled_frac,
        training=training,
        crop_scale=crop_scale,
        split_seed=data_seed,
        basic_augmentations=True,
        normalize=normalize)
    (data_loader,
     dist_sampler) = init_data(
         dataset_name=dataset_name,
         transform=transform,
         init_transform=init_transform,
         u_batch_size=None,
         s_batch_size=batch_size,
         classes_per_batch=None,
         world_size=world_size,
         rank=rank,
         root_path=root_path,
         image_folder=image_folder,
         training=training,
         copy_data=copy_data)

    ipe = len(data_loader)
    logger.info(f'initialized data-loader (ipe {ipe})')

    # -- make val data transforms and data loaders/samples
    # -- (built with world_size=1/rank=0, i.e. not sharded across ranks)
    val_transform, val_init_transform = make_transforms(
        dataset_name=dataset_name,
        subset_path=subset_path,
        unlabeled_frac=-1,
        training=True,
        basic_augmentations=True,
        force_center_crop=True,
        normalize=normalize)
    (val_data_loader,
     val_dist_sampler) = init_data(
         dataset_name=dataset_name,
         transform=val_transform,
         init_transform=val_init_transform,
         u_batch_size=None,
         s_batch_size=batch_size,
         classes_per_batch=None,
         world_size=1,
         rank=0,
         root_path=root_path,
         image_folder=image_folder,
         training=True,
         copy_data=copy_data)
    logger.info(f'initialized val data-loader (ipe {len(val_data_loader)})')

    # -- init model and optimizer
    scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)
    encoder, optimizer, scheduler = init_model(
        device=device,
        device_str=device_str,
        num_classes=num_classes,
        training=training,
        use_fp16=use_fp16,
        r_enc_path=r_enc_path,
        iterations_per_epoch=ipe,
        world_size=world_size,
        ref_lr=ref_lr,
        weight_decay=wd,
        use_lars=use_lars,
        zero_init=zero_init,
        num_epochs=num_epochs,
        model_name=model_name)

    best_acc = None
    start_epoch = 0
    # -- load checkpoint
    if not training or load_checkpoint:
        encoder, optimizer, scheduler, start_epoch, best_acc = load_from_path(
            r_path=w_enc_path,
            encoder=encoder,
            opt=optimizer,
            sched=scheduler,
            scaler=scaler,
            device_str=device_str,
            use_fp16=use_fp16)
    if not training:
        logger.info('putting model in eval mode')
        encoder.eval()
        # -- log the number of trainable parameters outside any 'fc' module
        logger.info(sum(p.numel() for n, p in encoder.named_parameters()
                        if p.requires_grad and ('fc' not in n)))
        # -- evaluation always runs its single epoch, whatever the
        # -- checkpoint's stored epoch is
        start_epoch = 0

    def percent(correct, total):
        # -- accuracy in percent; 0. for an empty loader rather than a
        # -- ZeroDivisionError
        return 100. * correct / total if total > 0 else 0.

    def train_step(epoch):
        """One pass over data_loader; returns the running top-1 accuracy."""
        # -- update distributed-data-loader epoch
        dist_sampler.set_epoch(epoch)
        top1_correct, top5_correct, total = 0, 0, 0
        for i, data in enumerate(data_loader):
            # -- gradients are only needed when we are going to call
            # -- backward(); skipping the graph in evaluation saves memory
            with torch.set_grad_enabled(bool(training)):
                with torch.cuda.amp.autocast(enabled=use_fp16):
                    inputs, labels = data[0].to(device), data[1].to(device)
                    outputs = encoder(inputs)
                    loss = criterion(outputs, labels)
            total += inputs.shape[0]
            # -- clamp k so models with fewer than 5 classes do not crash
            k = min(5, outputs.shape[1])
            top5_hits = outputs.topk(k, dim=1).indices.eq(labels.unsqueeze(1))
            top5_correct += float(top5_hits.sum())
            top1_correct += float(outputs.max(dim=1).indices.eq(labels).sum())
            if training:
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
                optimizer.zero_grad()
            if i % log_freq == 0:
                logger.info('[%d, %5d] %.3f%% %.3f%% (loss: %.3f)'
                            % (epoch + 1, i,
                               percent(top1_correct, total),
                               percent(top5_correct, total),
                               loss))
        return percent(top1_correct, total)

    def val_step(epoch):
        """One pass over val_data_loader; returns top-1 accuracy."""
        # -- evaluate a copy so the train/eval mode of `encoder` is untouched
        val_encoder = copy.deepcopy(encoder).eval()
        top1_correct, total = 0, 0
        last_itr = 0
        for i, data in enumerate(val_data_loader):
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = val_encoder(inputs)
            total += inputs.shape[0]
            top1_correct += float(outputs.max(dim=1).indices.eq(labels).sum())
            last_itr = i
        top1_acc = percent(top1_correct, total)

        logger.info('[%d, %5d] %.3f%%' % (epoch + 1, last_itr, top1_acc))
        return top1_acc

    log_str = 'train:' if training else 'test:'

    # -- defined up front so the return below is valid even when the loop
    # -- body never runs (start_epoch >= num_epochs)
    train_top1, val_top1 = None, None
    for epoch in range(start_epoch, num_epochs):

        train_top1 = train_step(epoch)
        with torch.no_grad():
            val_top1 = val_step(epoch)

        logger.info('[%d] (%s: %.3f%%) (val: %.3f%%)'
                    % (epoch + 1, log_str, train_top1, val_top1))

        # -- logging/checkpointing
        # -- only rank 0 writes, and only when validation top-1 improves
        improved = (best_acc is None) or (best_acc < val_top1)
        if training and (rank == 0) and improved:
            best_acc = val_top1
            save_dict = {
                'encoder': encoder.state_dict(),
                'opt': optimizer.state_dict(),
                'sched': scheduler.state_dict(),
                'epoch': epoch + 1,
                'unlabel_prob': unlabeled_frac,
                'world_size': world_size,
                'best_top1_acc': best_acc,
                'batch_size': batch_size,
                'lr': ref_lr,
                'amp': scaler.state_dict()
            }
            torch.save(save_dict, w_enc_path)

    if val_top1 is None:
        logger.info(f'no epochs run (start epoch {start_epoch}, '
                    f'num epochs {num_epochs})')

    return train_top1, val_top1