libraries/python/classification_compare.py

#!/usr/bin/env python

##############################################################################
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
##############################################################################

# This library compares the benchmark output against the golden labels for
# image classification tasks: a prediction counts as correct when the golden
# class index appears among the benchmark's top-N highest-scoring classes.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import json
import os

import numpy as np

parser = argparse.ArgumentParser(description="Output compare")
parser.add_argument(
    "--benchmark-output", required=True, help="The output of the benchmark."
)
parser.add_argument("--labels", required=True, help="The golden output.")
parser.add_argument(
    "--metric-keyword",
    help="The keyword prefixed to each metric line so that the harness can parse it.",
)
parser.add_argument("--result-file", help="Write the prediction result to a file.")
parser.add_argument(
    "--top",
    type=int,
    default=1,
    help="Number of top predictions to consider (e.g. 1 for top-1, 5 for top-5).",
)
parser.add_argument("--name", required=True, help="Specify the type of the metric.")


class OutputCompare(object):
    def __init__(self):
        self.args = parser.parse_args()
        assert os.path.isfile(
            self.args.benchmark_output
        ), "Benchmark output file {} doesn't exist".format(self.args.benchmark_output)
        assert os.path.isfile(self.args.labels), "Labels file {} doesn't exist".format(
            self.args.labels
        )

    def getData(self, filename):
        # The benchmark output alternates a dims line (comma-separated ints)
        # with a line of comma-separated scores; the dims line must be
        # identical for every entry.
        num_entries = 0
        content_list = []
        with open(filename, "r") as f:
            line = f.readline()
            dim_str = line
            while line != "":
                assert dim_str == line, "The dimensions do not match"
                num_entries += 1
                dims_list = [int(dim.strip()) for dim in line.strip().split(",")]
                line = f.readline().strip()
                content_list.extend([float(entry.strip()) for entry in line.split(",")])
                line = f.readline()
        dims_list.insert(0, num_entries)
        dims = np.asarray(dims_list)
        content = np.asarray(content_list)
        data = np.reshape(content, dims)
        # Flatten to a two-dimensional array: one row of class scores per entry.
        benchmark_data = data.reshape((-1, data.shape[-1]))
        return benchmark_data.tolist(), dims_list

    def writeOneResult(self, values, data, metric, unit):
        # The scalar summary value is replicated across all percentile fields.
        entry = {
            "type": self.args.name,
            "values": values,
            "summary": {
                "num_runs": len(values),
                "p0": data,
                "p10": data,
                "p50": data,
                "p90": data,
                "p100": data,
                "mean": data,
            },
            "unit": unit,
            "metric": metric,
        }
        s = json.dumps(entry, sort_keys=True)
        if self.args.metric_keyword:
            s = self.args.metric_keyword + " " + s
        print(s)
        return entry

    def writeResult(self, results):
        top = "top{}".format(str(self.args.top))
        values = [item["predict"] for item in results]
        num_corrects = sum(values)
        percent = num_corrects * 100.0 / len(values)
        output = {}
        res = self.writeOneResult(
            values, num_corrects, "number_of_{}_corrects".format(top), "number"
        )
        output[res["type"] + "_" + res["metric"]] = res
        res = self.writeOneResult(
            values, percent, "percent_of_{}_corrects".format(top), "percent"
        )
        output[res["type"] + "_" + res["metric"]] = res
        if self.args.result_file:
            s = json.dumps(output, sort_keys=True, indent=2)
            with open(self.args.result_file, "w") as f:
                f.write(s)

    def compare(self):
        benchmark_data, dims_list = self.getData(self.args.benchmark_output)
        # Each labels line is "<class index>,<label>,<image path>".
        with open(self.args.labels, "r") as f:
            content = f.read()
        golden_lines = [
            item.strip().split(",") for item in content.strip().split("\n")
        ]
        golden_data = [
            {"index": int(item[0]), "label": item[1], "path": item[2]}
            for item in golden_lines
        ]
        if len(benchmark_data) != len(golden_data):
            idx = dims_list.index(len(golden_data))
            benchmark_data = np.reshape(
                benchmark_data, (dims_list[idx], dims_list[idx + 1])
            )
        assert len(benchmark_data) == len(golden_data), (
            "Benchmark data has {} entries, but golden data has {} entries".format(
                len(benchmark_data), len(golden_data)
            )
        )

        def sort_key(elem):
            return elem["value"]

        for i in range(len(benchmark_data)):
            benchmark_one_entry = benchmark_data[i]
            golden_one_entry = golden_data[i]
            benchmark_result = [
                {
                    "index": j,
                    "value": benchmark_one_entry[j],
                }
                for j in range(len(benchmark_one_entry))
            ]
            # Sort class scores in descending order and mark the prediction
            # correct if the golden index is among the top N.
            benchmark_result.sort(reverse=True, key=sort_key)
            golden_one_entry["predict"] = (
                1
                if golden_one_entry["index"]
                in [item["index"] for item in benchmark_result[: self.args.top]]
                else 0
            )
        self.writeResult(golden_data)


if __name__ == "__main__":
    app = OutputCompare()
    app.compare()
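A minimal usage sketch, separate from the script above: it writes toy input files in the format getData and compare parse, then shows one way to invoke the script. The file names, score values, labels, and the --name value are made up for illustration; only the flag names and file layout come from the code above.

# Illustrative only (not part of classification_compare.py): build toy inputs.
with open("toy_benchmark.txt", "w") as f:
    # Each entry is a dims line ("1, 5") followed by 1 x 5 = 5 scores;
    # the dims line must be identical for every entry.
    f.write("1, 5\n0.1, 0.2, 0.5, 0.1, 0.1\n")
    f.write("1, 5\n0.7, 0.1, 0.1, 0.05, 0.05\n")
with open("toy_labels.txt", "w") as f:
    # Golden rows: <class index>,<label>,<image path>
    f.write("2,cat,images/cat.jpg\n")
    f.write("0,dog,images/dog.jpg\n")

# Both toy entries are top-1 correct, so the command below prints JSON with
# number_of_top1_corrects = 2 and percent_of_top1_corrects = 100.0:
#   python libraries/python/classification_compare.py \
#       --benchmark-output toy_benchmark.txt --labels toy_labels.txt \
#       --top 1 --name accuracy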