def main()

in src/snn_fine_tune.py [0:0]
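
# (excerpt note: module-level names used below -- os, copy, torch, logger,
#  log_freq, init_distributed, init_suncet_loss, make_labels_matrix,
#  make_transforms, init_data, init_model, load_from_path, val_run,
#  nostdout -- are imported/defined elsewhere in src/snn_fine_tune.py)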


def main(args):

    # -- META
    model_name = args['meta']['model_name']
    load_checkpoint = args['meta']['load_checkpoint']
    copy_data = args['meta']['copy_data']
    output_dim = args['meta']['output_dim']
    use_pred_head = args['meta']['use_pred_head']
    use_fp16 = args['meta']['use_fp16']
    device = torch.device(args['meta']['device'])
    torch.cuda.set_device(device)

    # -- DATA
    unlabeled_frac = args['data']['unlabeled_frac']
    label_smoothing = args['data']['label_smoothing']
    normalize = args['data']['normalize']
    root_path = args['data']['root_path']
    image_folder = args['data']['image_folder']
    dataset_name = args['data']['dataset']
    subset_path = args['data']['subset_path']
    unique_classes = args['data']['unique_classes_per_rank']
    data_seed = args['data']['data_seed']

    # -- CRITERION
    classes_per_batch = args['criterion']['classes_per_batch']
    supervised_views = args['criterion']['supervised_views']
    batch_size = args['criterion']['supervised_batch_size']
    temperature = args['criterion']['temperature']

    # -- OPTIMIZATION
    wd = float(args['optimization']['weight_decay'])
    num_epochs = args['optimization']['epochs']
    use_lars = args['optimization']['use_lars']
    warmup = args['optimization']['warmup']
    start_lr = args['optimization']['start_lr']
    ref_lr = args['optimization']['lr']
    final_lr = args['optimization']['final_lr']
    momentum = args['optimization']['momentum']
    nesterov = args['optimization']['nesterov']

    # -- LOGGING
    folder = args['logging']['folder']
    tag = args['logging']['write_tag']
    r_file_enc = args['logging']['pretrain_path']

    # -- log/checkpointing paths
    r_enc_path = os.path.join(folder, r_file_enc)
    w_enc_path = os.path.join(folder, f'{tag}-fine-tune-SNN.pth.tar')

    # -- init distributed
    world_size, rank = init_distributed()
    logger.info(f'initialized rank/world-size: {rank}/{world_size}')

    # -- init loss
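    # (all supervised views are concatenated into a single batch in the
    #  training loop below, so the loss is sized for
    #  batch_size*supervised_views samples per rank)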
    suncet = init_suncet_loss(
        num_classes=classes_per_batch,
        batch_size=batch_size*supervised_views,
        world_size=world_size,
        rank=rank,
        temperature=temperature,
        device=device)
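    # labels_matrix holds the soft targets (one-hot over the classes_per_batch
    # sampled classes, with optional label smoothing); it is computed once and
    # re-tiled every iteration, since the stratified batches share the same
    # class layout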
    labels_matrix = make_labels_matrix(
        num_classes=classes_per_batch,
        s_batch_size=batch_size,
        world_size=world_size,
        device=device,
        unique_classes=unique_classes,
        smoothing=label_smoothing)

    # -- make data transforms
    transform, init_transform = make_transforms(
        dataset_name=dataset_name,
        subset_path=subset_path,
        unlabeled_frac=unlabeled_frac,
        training=True,
        split_seed=data_seed,
        basic_augmentations=True,
        normalize=normalize)
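    # (subset_path, unlabeled_frac, and split_seed deterministically select
    #  the labeled subset used for fine-tuning)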
    (data_loader,
     dist_sampler) = init_data(
         dataset_name=dataset_name,
         transform=transform,
         init_transform=init_transform,
         supervised_views=supervised_views,
         u_batch_size=None,
         stratify=True,
         s_batch_size=batch_size,
         classes_per_batch=classes_per_batch,
         unique_classes=unique_classes,
         world_size=world_size,
         rank=rank,
         root_path=root_path,
         image_folder=image_folder,
         training=True,
         copy_data=copy_data)
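    # (stratify=True: each rank draws class-balanced supervised batches of
    #  batch_size images from each of classes_per_batch classes;
    #  u_batch_size=None because fine-tuning uses labeled data only)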

    # -- rough estimate of labeled imgs per class used to set the number of
    #    fine-tuning iterations
    imgs_per_class = int(1300*(1.-unlabeled_frac)) if 'imagenet' in dataset_name else int(5000*(1.-unlabeled_frac))
    dist_sampler.set_inner_epochs(imgs_per_class//batch_size)
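    # (the sampler repeats imgs_per_class//batch_size times per epoch, so one
    #  logical epoch approximates a full pass over each class's labeled images)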

    ipe = len(data_loader)
    logger.info(f'initialized data-loader (ipe {ipe})')

    # -- init model and optimizer
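    # grad scaler for mixed-precision training (a no-op when use_fp16 is False)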
    scaler = torch.cuda.amp.GradScaler(enabled=use_fp16)
    encoder, optimizer, scheduler = init_model(
        device=device,
        training=True,
        r_enc_path=r_enc_path,
        iterations_per_epoch=ipe,
        world_size=world_size,
        start_lr=start_lr,
        ref_lr=ref_lr,
        num_epochs=num_epochs,
        output_dim=output_dim,
        model_name=model_name,
        warmup_epochs=warmup,
        use_pred_head=use_pred_head,
        use_fp16=use_fp16,
        wd=wd,
        final_lr=final_lr,
        momentum=momentum,
        nesterov=nesterov,
        use_lars=use_lars)

    best_acc, val_top1 = None, None
    start_epoch = 0
    # -- load checkpoint
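    # (note: this resumes from the fine-tune checkpoint at w_enc_path; the
    #  pretrained encoder path r_enc_path is handled by init_model above)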
    if load_checkpoint:
        encoder, optimizer, scaler, scheduler, start_epoch, best_acc = load_from_path(
            r_path=w_enc_path,
            encoder=encoder,
            opt=optimizer,
            scaler=scaler,
            sched=scheduler,
            device=device,
            use_fp16=use_fp16,
            ckp=True)

    for epoch in range(start_epoch, num_epochs):

        def train_step():
            # -- update distributed-data-loader epoch
            dist_sampler.set_epoch(epoch)

            for i, data in enumerate(data_loader):
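                # data[:-1] holds the augmented views: concatenate them into
                # one batch and tile the soft-label matrix to match (the
                # trailing element of data, the raw labels, is unused here)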
                imgs = torch.cat([s.to(device) for s in data[:-1]], 0)
                labels = torch.cat([labels_matrix for _ in range(supervised_views)])
                with torch.cuda.amp.autocast(enabled=use_fp16):
                    optimizer.zero_grad()
                    z = encoder(imgs)
                    loss = suncet(z, labels)
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                scheduler.step()
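                # the lr schedule is stepped once per iteration (it was built
                # from iterations_per_epoch in init_model)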
                if i % log_freq == 0:
                    logger.info('[%d, %5d] (loss: %.3f)' % (epoch + 1, i, loss))

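        # evaluate a deep copy of the current encoder before this epoch's
        # training updates; rank 0 uses val_top1 below to checkpoint the best
        # model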
        with torch.no_grad():
            with nostdout():
                val_top1, _ = val_run(
                    pretrained=copy.deepcopy(encoder),
                    subset_path=subset_path,
                    unlabeled_frac=unlabeled_frac,
                    dataset_name=dataset_name,
                    root_path=root_path,
                    image_folder=image_folder,
                    use_pred=use_pred_head,
                    normalize=normalize,
                    split_seed=data_seed)
        logger.info('[%d] (val: %.3f%%)' % (epoch + 1, val_top1))
        train_step()

        # -- logging/checkpointing
        if (rank == 0) and ((best_acc is None) or (best_acc < val_top1)):
            best_acc = val_top1
            save_dict = {
                'encoder': encoder.state_dict(),
                'opt': optimizer.state_dict(),
                'sched': scheduler.state_dict(),
                'epoch': epoch + 1,
                'unlabel_prob': unlabeled_frac,
                'world_size': world_size,
                'batch_size': batch_size,
                'best_top1_acc': best_acc,
                'lr': ref_lr,
                'amp': scaler.state_dict()
            }
            torch.save(save_dict, w_enc_path)

    if rank == 0:
        logger.info('[%d] (best-val: %.3f%%)' % (epoch + 1, best_acc))
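

# Usage sketch (illustrative; not part of src/snn_fine_tune.py). The repo's
# launcher normally parses a YAML config into the nested `args` dict that
# main() reads above; a minimal stand-in, assuming a config whose keys match
# the args['meta'] / ['data'] / ['criterion'] / ['optimization'] /
# ['logging'] fields accessed at the top of main:
if __name__ == '__main__':
    import yaml
    # 'configs/fine_tune.yaml' is a hypothetical path, not a file shipped here
    with open('configs/fine_tune.yaml', 'r') as f:
        args = yaml.load(f, Loader=yaml.FullLoader)
    main(args)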