def main()

in main_stica.py [0:0]


def main():

    # parse arguments
    global args
    parser = parse_arguments()
    args = parser.parse_args()

    # exp setup: logger, distributed mode and seeds
    init_distributed_mode(args)
    init_signal_handler()
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss")
    if args.rank == 0:
        writer = SummaryWriter(args.dump_path)
        writer.add_text(
            'args',
            " \n".join(['%s : %s' % (arg, getattr(args, arg)) for arg in vars(args)]), 
            0
        )
    else:
        writer = None

    # SpecAugment presets (defaults to [], i.e. no audio augmentation)
    if args.audio_augtype == 'mild':
        aug_audio = [1, 1, 2, 5]
    elif args.audio_augtype == 'medium':
        aug_audio = [1, 1, 3, 6]
    elif args.audio_augtype == 'heavy':
        aug_audio = [2, 2, 3, 6]
    else:
        aug_audio = []

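    # build the audio-video training dataset (audio is decoded alongside the clips)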
    train_dataset = AVideoDataset(
        ds_name=args.dataset_name,
        mode='train',
        root_dir=args.root_dir,
        decode_audio=True,
        args=args
    )

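    # distributed sampler: each process draws a disjoint shard of the dataset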
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True
    )
    logger.info("Building data done with {} images loaded.".format(
        len(train_dataset)))

    # build model
    model = Stica_TransformerFMCrop(
        vid_base_arch='r2plus1d_18',
        aud_base_arch='resnet9',
        pretrained=False,
        norm_feat=True,
        use_mlp=True,
        num_classes=256, # embedding dimension
        args=args
    )

    # synchronize batch norm layers
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)

    # copy model to GPU
    model = model.cuda()
    if args.rank == 0:
        logger.info(model)
    logger.info("Building model done.")

    # build optimizer
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.base_lr,
        momentum=0.9,
        weight_decay=args.wd,
    )
    if args.use_warmup_scheduler:
        warmup_lr_schedule = np.linspace(
            args.start_warmup, args.base_lr, len(train_loader) * args.warmup_epochs)
        iters = np.arange(len(train_loader) * (args.epochs - args.warmup_epochs))
        if args.use_lr_scheduler:
            cosine_lr_schedule = np.array(
                [args.final_lr + 0.5 * (args.base_lr - args.final_lr) * (1 + \
                    math.cos(math.pi * t / (len(train_loader) * (args.epochs - args.warmup_epochs)))) 
                    for t in iters
                ])
            lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
        else:
            constant_schedule = np.array([args.base_lr for t in iters])
            lr_schedule = np.concatenate((warmup_lr_schedule, constant_schedule))
    else:
        # no warmup: use a constant per-iteration schedule so that `lr_schedule`
        # is always defined before it is passed to train()
        lr_schedule = np.array([args.base_lr] * (len(train_loader) * args.epochs))
    logger.info("Building optimizer done.")

    # init mixed precision
    if args.use_fp16:
        model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O1")
        logger.info("Initializing mixed precision done.")

    # wrap model
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu_to_work_on],
        find_unused_parameters=True,
    )

    # optionally resume from a checkpoint
    to_restore = {"epoch": 0}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        state_dict=model,
        optimizer=optimizer,
        amp=apex.amp if args.use_fp16 else None,
    )
    start_epoch = to_restore["epoch"]

    # Set CuDNN benchmark
    cudnn.benchmark = True

    for epoch in range(start_epoch, args.epochs):

        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)

        # set sampler
        train_loader.sampler.set_epoch(epoch)

        # train the network
        scores = train(
            train_loader, model, optimizer, epoch, lr_schedule, writer)
        training_stats.update(scores)
        if args.rank == 0 and writer:
            writer.add_scalar('pretrain/epoch', epoch, epoch)

        # save checkpoints
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "state_dict": model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            if args.use_fp16:
                save_dict["amp"] = apex.amp.state_dict()
            torch.save(
                save_dict,
                os.path.join(
                    args.dump_path, 
                    "checkpoint.pth.tar"
                ),
            )
            if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
                shutil.copyfile(
                    os.path.join(
                        args.dump_path, 
                        "checkpoint.pth.tar"
                    ),
                    os.path.join(
                        args.dump_checkpoints, 
                        "ckp-" + str(epoch) + ".pth"
                    ),
                )
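
The train() function called in the epoch loop is not shown in this section. The sketch below is only an illustration, under the assumption that train() follows the common pattern of indexing the precomputed lr_schedule array once per iteration and writing the value into the optimizer's parameter groups; the function name train_sketch and its body are assumptions, not the repository's actual implementation.

def train_sketch(train_loader, model, optimizer, epoch, lr_schedule):
    # Sketch only: shows how a per-iteration LR array of length
    # len(train_loader) * args.epochs would typically be consumed.
    model.train()
    for it, batch in enumerate(train_loader):
        # global iteration index: completed epochs * iters per epoch + current iter
        iteration = epoch * len(train_loader) + it
        # set this iteration's learning rate on every parameter group
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr_schedule[iteration]
        # ... forward pass, loss computation, backward pass, and
        # optimizer.step() would follow here ...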