in eval/eval_ABX.py [0:0]
def main(argv):
    """Run the ABX evaluation described by command-line arguments and save results.

    Features come either from pre-computed files (when no checkpoint is
    given, loaded with ``load_pt`` / ``load_npy`` depending on
    ``args.file_extension``) or from a CPC model restored from
    ``args.path_checkpoint``. The ABX scores and the parsed arguments are
    written as JSON to ``ABX_scores.json`` and ``ABX_args.json`` in the
    output directory: ``args.out`` if set, otherwise the directory that
    contains the checkpoint.

    Args:
        argv: Command-line arguments, forwarded to ``parse_args``.

    Raises:
        ValueError: If neither ``args.out`` nor ``args.path_checkpoint`` is
            set (there is no directory to write results to), or if no
            checkpoint is given and ``args.file_extension`` is neither
            ``'.pt'`` nor ``'.npy'``.
    """
    args = parse_args(argv)

    # Resolve the output directory before any heavy work, so a missing
    # destination is reported immediately instead of as a TypeError from
    # Path(None) after the whole ABX computation has finished.
    if args.out is not None:
        out_dir = Path(args.out)
    elif args.path_checkpoint is not None:
        out_dir = Path(args.path_checkpoint).parent
    else:
        raise ValueError(
            "No output directory: an output path (args.out) is required "
            "when evaluating pre-computed features without a checkpoint.")

    if args.path_checkpoint is None:
        # Pre-computed features: pick the loader matching the file type.
        if args.file_extension == '.pt':
            feature_function = load_pt
        elif args.file_extension == '.npy':
            feature_function = load_npy
        else:
            raise ValueError(
                f"Unsupported file extension {args.file_extension!r} for "
                "pre-computed features; expected '.pt' or '.npy'.")
    else:
        state_dict = torch.load(args.path_checkpoint)
        feature_maker = load_cpc_features(state_dict)
        # NOTE(review): the model is moved to the GPU regardless of
        # args.cuda (unchanged from the original) - confirm this is intended.
        feature_maker.cuda()

        def feature_function(x):
            return build_feature_from_file(x, feature_maker)

    # 'all' expands to both ABX evaluation modes.
    modes = ["within", "across"] if args.mode == 'all' else [args.mode]

    step_feature = 1 / args.feature_size

    # Get the list of sequences
    seq_list = find_all_files(args.path_data, args.file_extension)

    scores = ABX(feature_function, args.path_item_file,
                 seq_list, args.distance_mode,
                 step_feature, modes,
                 cuda=args.cuda,
                 max_x_across=args.max_x_across,
                 max_size_group=args.max_size_group)

    # parents=True so a nested, not-yet-existing output path also works.
    out_dir.mkdir(parents=True, exist_ok=True)

    path_score = out_dir / 'ABX_scores.json'
    with open(path_score, 'w') as file:
        json.dump(scores, file, indent=2)

    path_args = out_dir / 'ABX_args.json'
    with open(path_args, 'w') as file:
        json.dump(vars(args), file, indent=2)