def main()

in eval/eval_ABX.py [0:0]


def _build_feature_function(args):
    """Return the callable that turns one sequence file path into features.

    Without a checkpoint, precomputed features are read from disk using the
    loader that matches ``args.file_extension`` ('.pt' -> ``load_pt``,
    '.npy' -> ``load_npy``). With a checkpoint, a feature extractor is built
    from it via ``load_cpc_features`` and applied to each file through
    ``build_feature_from_file``.

    Raises:
        ValueError: if no checkpoint is given and ``args.file_extension`` is
            neither '.pt' nor '.npy'.
    """
    if args.path_checkpoint is None:
        if args.file_extension == '.pt':
            return load_pt
        if args.file_extension == '.npy':
            return load_npy
        # Previously this fell through and left the feature function unbound,
        # surfacing later as an UnboundLocalError.
        raise ValueError(
            f"Unsupported feature file extension {args.file_extension!r} "
            "when no checkpoint is given; expected '.pt' or '.npy'")

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(args.path_checkpoint)
    feature_maker = load_cpc_features(state_dict)
    # NOTE(review): the model is always moved to the GPU, independently of
    # args.cuda -- confirm this is intended for CPU-only runs.
    feature_maker.cuda()

    def feature_function(x):
        return build_feature_from_file(x, feature_maker)

    return feature_function


def _resolve_output_dir(args):
    """Return the directory in which the JSON results are written.

    Uses ``args.out`` when set, otherwise the directory that contains the
    checkpoint file.

    Raises:
        ValueError: if neither ``args.out`` nor ``args.path_checkpoint`` is set.
    """
    if args.out is not None:
        return Path(args.out)
    if args.path_checkpoint is None:
        raise ValueError(
            "No output directory could be determined: an explicit output "
            "path (args.out) is required when no checkpoint is given")
    return Path(args.path_checkpoint).parent


def main(argv):
    """Run the ABX evaluation configured by ``argv`` and save the results.

    Parses ``argv``, builds the feature function, runs ``ABX`` over every file
    with ``args.file_extension`` found under ``args.path_data``, and writes the
    returned scores to ``ABX_scores.json`` and the parsed arguments to
    ``ABX_args.json`` in the output directory.

    Args:
        argv: command-line arguments, forwarded to ``parse_args``.

    Raises:
        ValueError: if ``args.feature_size`` is not positive, the feature file
            extension is unsupported without a checkpoint, or no output
            directory can be determined.
    """
    args = parse_args(argv)

    # Validate cheap arguments before loading any model or scoring anything.
    if args.feature_size <= 0:
        raise ValueError(
            f"feature_size must be positive, got {args.feature_size!r}")
    # Resolved up front so a missing output path fails immediately instead of
    # after the (potentially long) ABX computation.
    out_dir = _resolve_output_dir(args)

    feature_function = _build_feature_function(args)

    # 'all' expands to both evaluation modes; any other value is used as-is.
    modes = ["within", "across"] if args.mode == 'all' else [args.mode]

    step_feature = 1 / args.feature_size

    # Get the list of sequences
    seq_list = find_all_files(args.path_data, args.file_extension)

    scores = ABX(feature_function, args.path_item_file,
                 seq_list, args.distance_mode,
                 step_feature, modes,
                 cuda=args.cuda,
                 max_x_across=args.max_x_across,
                 max_size_group=args.max_size_group)

    out_dir.mkdir(parents=True, exist_ok=True)

    path_score = out_dir / 'ABX_scores.json'
    with open(path_score, 'w', encoding='utf-8') as file:
        json.dump(scores, file, indent=2)

    path_args = out_dir / 'ABX_args.json'
    with open(path_args, 'w', encoding='utf-8') as file:
        json.dump(vars(args), file, indent=2)