def main()

in eval_video.py


def main(args, writer):

    # Create Logger
    logger, training_stats = initialize_exp(
        args, "epoch", "loss", "prec1", "prec5", "loss_val", "prec1_val", "prec5_val"
    )
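    # training_stats collects one row of the metric columns declared above per
    # evaluated epoch (see training_stats.update() in the training loop below).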

    # Set CudNN benchmark
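    # benchmark=True lets cuDNN auto-tune convolution algorithms for the fixed
    # input shapes used here, which speeds things up after the first batches.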
    torch.backends.cudnn.benchmark = True

    # Load model
    logger.info("Loading model")
    model = load_model(
        model_type=args.model,
        vid_base_arch=args.vid_base_arch,
        aud_base_arch=args.aud_base_arch,
        pretrained=args.pretrained,
        norm_feat=False,
        use_mlp=args.use_mlp,
        num_classes=256,
        args=args,
    )
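    # The full audio-video model is instantiated here; only its video branch
    # (plus the video pooling module when args.agg_model is set) is reused below.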

    # Load model weights
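    # args.weights_path may arrive as the literal string 'None' (e.g. from a
    # command-line default) or as a real None; both mean "no checkpoint to load".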
    weight_path_type = type(args.weights_path)
    if weight_path_type == str:
        weight_path_not_none = args.weights_path != 'None'
    else:
        weight_path_not_none = args.weights_path is not None
    if not args.pretrained and weight_path_not_none:
        logger.info("Loading model weights")
        if os.path.exists(args.weights_path):
            ckpt_dict = torch.load(args.weights_path)
            # Checkpoints store the weights under either "state_dict" or "model"
            try:
                model_weights = ckpt_dict["state_dict"]
            except KeyError:
                model_weights = ckpt_dict["model"]
            epoch = ckpt_dict["epoch"]
            logger.info(f"Epoch checkpoint: {epoch}")
            load_model_parameters(model, model_weights)
    logger.info(f"Loading model done")

    # Add FC layer to model for fine-tuning or feature extracting
    model = load_model_finetune(
        args,
        model.video_network.base,
        pooling_arch=model.video_pooling if args.agg_model else None,
        num_ftrs=model.encoder_dim,
        num_classes=NUM_CLASSES[args.dataset],
        use_dropout=args.use_dropout, 
        use_bn=args.use_bn,
        use_l2_norm=args.use_l2_norm,
        dropout=0.9,
        agg_model=args.agg_model,
    )
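    # Note: the dropout rate is hard-coded to 0.9 here; args.use_dropout presumably
    # only toggles whether a dropout layer is inserted at all.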

    # Create DataParallel model
    model = model.cuda()
    model = torch.nn.DataParallel(model)
    model_without_ddp = model.module
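    # Keep a handle to the unwrapped model: parameter groups and checkpoints below
    # use model_without_ddp so names are not prefixed with "module." by DataParallel.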

    # Get params for optimization
    params = []
    if args.feature_extract: # feature extraction: optimize only the classifier head
        logger.info("Getting params for feature-extracting")
        for name, param in model_without_ddp.classifier.named_parameters():
            logger.info((name, param.shape))
            params.append(
                {
                    'params': param, 
                    'lr': args.head_lr, 
                    'weight_decay': args.weight_decay
                })
    else: # finetune
        logger.info("Getting params for finetuning")
        for name, param in model_without_ddp.classifier.named_parameters():
            logger.info((name, param.shape))
            params.append(
                {
                    'params': param, 
                    'lr': args.head_lr, 
                    'weight_decay': args.weight_decay
                })
        for name, param in model_without_ddp.base.named_parameters():
            logger.info((name, param.shape))
            params.append(
                {   
                    'params': param, 
                    'lr': args.base_lr, 
                    'weight_decay': args.wd_base
                })
        if args.agg_model:
            logger.info("Adding pooling arch params to be optimized")
            for name, param in model_without_ddp.pooling_arch.named_parameters():
                if param.requires_grad and param.dim() >= 1:
                    logger.info(f"Adding {name}({param.shape}), wd: {args.wd_tsf}")
                    params.append(
                        {
                            'params': param, 
                            'lr': args.tsf_lr, 
                            'weight_decay': args.wd_tsf
                        })
                else:
                    logger.info(f"Not adding {name} to be optimized")


    logger.info('\n===========Check Grad============')
    for name, param in model_without_ddp.named_parameters():
        logger.info((name, param.requires_grad))
    logger.info('=================================\n')

    logger.info("Creating AV Datasets")
    dataset = AVideoDataset(
        ds_name=args.dataset,
        root_dir=args.root_dir,
        mode='train',
        num_train_clips=args.train_clips_per_video,
        decode_audio=False,
        center_crop=False,
        fold=args.fold,
        ucf101_annotation_path=args.ucf101_annotation_path,
        hmdb51_annotation_path=args.hmdb51_annotation_path,
        args=args,
    )
    dataset_test = AVideoDataset(
        ds_name=args.dataset,
        root_dir=args.root_dir,
        mode='test',
        decode_audio=False,
        num_spatial_crops=args.num_spatial_crops,
        num_ensemble_views=args.val_clips_per_video,
        ucf101_annotation_path=args.ucf101_annotation_path,
        hmdb51_annotation_path=args.hmdb51_annotation_path,
        fold=args.fold,
        args=args,
    )
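    # The train split samples args.train_clips_per_video random clips per video; the
    # test split produces num_ensemble_views x num_spatial_crops clips per video so
    # that clip predictions can be aggregated into the video-level accuracies
    # reported by evaluate().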

    # Creating dataloaders
    logger.info("Creating data loaders")
    data_loader = torch.utils.data.DataLoader(
        dataset, 
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True, 
        drop_last=True,
        shuffle=True
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, 
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True, 
        drop_last=False
    )
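    # drop_last=True only for training, to keep batch sizes constant; the test
    # loader keeps every clip so no video is dropped from evaluation.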

    # linearly scale LR and set up optimizer
    logger.info(f"Using SGD with lr: {args.head_lr}, wd: {args.weight_decay}")
    optimizer = torch.optim.SGD(
        params,
        lr=args.head_lr, 
        momentum=args.momentum, 
        weight_decay=args.weight_decay
    )
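    # The per-group lr / weight_decay values set above take precedence; the values
    # passed here only act as defaults for groups that do not specify them.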

    # Multi-step LR scheduler
    if args.use_scheduler:
        milestones = [int(m) - args.lr_warmup_epochs for m in args.lr_milestones.split(',')]
        logger.info(f"Num. of Epochs: {args.epochs}, Milestones: {milestones}")
        if args.lr_warmup_epochs > 0:
            logger.info(f"Using scheduler with {args.lr_warmup_epochs} warmup epochs")
            scheduler_step = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, 
                milestones=milestones, 
                gamma=args.lr_gamma
            )
            lr_scheduler = GradualWarmupScheduler(
                optimizer, 
                multiplier=8,
                total_epoch=args.lr_warmup_epochs, 
                after_scheduler=scheduler_step
            )
        else: # no warmup, just multi-step
            logger.info("Using scheduler w/out warmup")
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, 
                milestones=milestones, 
                gamma=args.lr_gamma
            )
    else:
        lr_scheduler = None

    # Checkpointing
    if args.resume:
        ckpt_path = os.path.join(args.output_dir, 'checkpoints', 'checkpoint.pth')
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if lr_scheduler is not None:
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch']
        logger.info(f"Resuming from epoch: {args.start_epoch}")

    # Only perform evaluation
    if args.test_only:
        scores_val = evaluate(
            model, 
            data_loader_test,
            epoch=args.start_epoch, 
            writer=writer,
            ds=args.dataset,
        )
        _, vid_acc1, vid_acc5 = scores_val
        return vid_acc1, vid_acc5, args.start_epoch

    start_time = time.time()
    best_vid_acc_1 = -1
    best_vid_acc_5 = -1
    best_epoch = 0
    for epoch in range(args.start_epoch, args.epochs):
        logger.info(f'Start training epoch: {epoch}')
        scores = train(
            model, 
            optimizer, 
            data_loader,
            epoch, 
            writer=writer,
            ds=args.dataset,
        )
        logger.info(f'Start evaluating epoch: {epoch}')
        if lr_scheduler is not None:
            lr_scheduler.step()
        if epoch > 6:  # evaluate every epoch once the first 7 epochs are done
            scores_val = evaluate(
                model, 
                data_loader_test,
                epoch=epoch,
                writer=writer,
                ds=args.dataset,
            )
            _, vid_acc1, vid_acc5 = scores_val
            training_stats.update(scores + scores_val)
            if vid_acc1 > best_vid_acc_1:
                best_vid_acc_1 = vid_acc1
                best_vid_acc_5 = vid_acc5
                best_epoch = epoch
        if args.output_dir:
            logger.info(f'Saving checkpoint to: {args.output_dir}')
            save_checkpoint(args, epoch, model, optimizer, lr_scheduler, ckpt_freq=1)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info(f'Training time {total_time_str}')
    return best_vid_acc_1, best_vid_acc_5, best_epoch
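
# Hypothetical invocation sketch (not part of the function above; the real
# eval_video.py builds `args` with its own argument parser and constructs the
# summary writer itself):
#
#   args = parse_args()                               # assumed argparse helper
#   writer = SummaryWriter(log_dir=args.output_dir)   # e.g. torch.utils.tensorboard.SummaryWriter
#   best_acc1, best_acc5, best_epoch = main(args, writer)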