def main()

in video_retrieval.py [0:0]


def main(args, logger=None):
    """Run the video-to-video ('v-v') retrieval benchmark.

    Steps:
      1. Build the model and the train/test datasets with ``init``, asking for
         the video encoder only.
      2. Load or compute features for the train split (``mode='train'``), then
         for the test split (``mode='test'``).
      3. Pass each split through ``average_features`` (intended to yield one
         mean feature per video).
      4. Hand both averaged splits to ``retrieval`` with ``task='v-v'`` and no
         audio features.

    Args:
        args: Run configuration. It is forwarded to ``init``,
            ``load_or_get_features`` and ``average_features``. Its
            ``get_audio`` attribute is read and forwarded as ``get_audio``.
        logger: Optional logger, forwarded unchanged to the helpers. The
            "Averaging features" progress message is printed to stdout
            regardless of this argument.

    Returns:
        None. The return value of ``retrieval`` is discarded.
    """
    model, dataset, dataset_test = init(
        args, get_video_encoder_only=True, logger=logger
    )

    def _extract(split_data, split_mode):
        # Unpack immediately, so a malformed result fails before the next
        # split is processed.
        feats, vid_idx, lbls = load_or_get_features(
            args, split_data, model,
            logger=logger, mode=split_mode, get_audio=args.get_audio,
        )
        return feats, vid_idx, lbls

    def _average(feats, vid_idx, lbls):
        # NOTE(review): aud_features is always None here, even when
        # args.get_audio is set. Confirm that audio is intentionally unused
        # for the 'v-v' task.
        feats, vid_idx, lbls = average_features(
            args, feats, vid_idx, lbls,
            get_audio=args.get_audio, aud_features=None, logger=logger,
        )
        return feats, vid_idx, lbls

    train_split = _extract(dataset, 'train')
    val_split = _extract(dataset_test, 'test')

    print("Averaging features")
    train_feats, train_idx, train_lbls = _average(*train_split)
    val_feats, val_idx, val_lbls = _average(*val_split)

    # Positional order expected by retrieval: features, labels, video indices.
    retrieval(
        train_feats, train_lbls, train_idx,
        val_feats, val_lbls, val_idx,
        train_aud_features=None,
        val_aud_features=None,
        task='v-v',
    )