libraries/python/aggregate_classification_results.py (115 lines of code) (raw):

#!/usr/bin/env python

##############################################################################
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
##############################################################################

"""Aggregate test results from multiple runs and form the final result.

Reads ``<dir>/<prefix>_<index>.txt`` for every index in ``range(limit)``,
parses each existing file as JSON, concatenates the ``values`` lists of
entries that share a key, and patches up a summary for entries whose
``metric`` starts with ``number_of_top<N>_corrects`` or
``percent_of_top<N>_corrects``.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import json
import os
import re


parser = argparse.ArgumentParser(description="Aggregate output results")
parser.add_argument(
    "--dir", required=True, help="The directory of all the json data are saved."
)
parser.add_argument(
    "--limit",
    required=True,
    type=int,
    # The original help text was a copy of the --dir help.
    help="The number of json data files to aggregate. "
    "Files with index 0 to limit-1 are read.",
)
parser.add_argument(
    "--metric-keyword",
    help="The keyword prefix each metric so that the harness can parse.",
)
parser.add_argument(
    "--prefix",
    required=True,
    help="The prefix of the json data. The files are suffixed with a number "
    "and .txt",
)
parser.add_argument("--result-file", help="Write the prediction result to a file.")

# Metrics whose name starts with "<kind>_of_top<N>_corrects"; only the kinds
# "number" and "percent" get a summary.
_CORRECTS_PATTERN = re.compile(r"(\w+)_of_top(\d+)_corrects")


class AggregateOutputs(object):
    """Collect per-run JSON outputs and merge them into aggregated results."""

    # Entries with more merged values than this keep only their summary.
    MAX_VALUES_KEPT = 200

    def __init__(self):
        """Parse the command line and check that ``--dir`` is a directory.

        Raises:
            AssertionError: if the directory given by ``--dir`` does not exist.
        """
        self.args = parser.parse_args()
        # Explicit check rather than an ``assert`` statement so the validation
        # is not stripped when Python runs with -O. The exception type is kept.
        if not os.path.isdir(self.args.dir):
            raise AssertionError(
                "Directory {} doesn't exist".format(self.args.dir)
            )

    def _composeFilename(self, index):
        """Return the path ``<dir>/<prefix>_<index>.txt`` for one run."""
        return os.path.join(self.args.dir, self.args.prefix + "_" + str(index) + ".txt")

    def _getOneOutput(self, index):
        """Load the JSON content of run ``index``.

        Returns:
            The parsed JSON content, or None (after printing a message) when
            the file does not exist.
        """
        filename = self._composeFilename(index)
        if not os.path.isfile(filename):
            print("File {} does not exist".format(filename))
            return None
        with open(filename, "r") as f:
            return json.load(f)

    def _collectAllOutputs(self):
        """Return the outputs of runs ``0 .. limit-1``, skipping missing files."""
        outputs = []
        for index in range(self.args.limit):
            output = self._getOneOutput(index)
            if output is not None:
                outputs.append(output)
        return outputs

    @staticmethod
    def _flatSummary(num_runs, data):
        """Build a summary in which every statistic equals the total ``data``."""
        return {
            "num_runs": num_runs,
            "p0": data,
            "p10": data,
            "p50": data,
            "p90": data,
            "p100": data,
            "mean": data,
            "std": 0,
            "MAD": 0,
        }

    def _aggregateOutputs(self, outputs):
        """Merge per-run outputs and patch up the summary of "corrects" metrics.

        Args:
            outputs: list of dicts, one per run, mapping a key to an entry
                dict that has at least a "metric" name and a "values" list.

        Returns:
            A dict mapping each key to its merged entry. The input dicts and
            their "values" lists are left unmodified.

        Raises:
            KeyError: if an entry lacks "values" or "metric".
        """
        results = {}
        for one_output in outputs:
            for key, value in one_output.items():
                if key in results:
                    results[key]["values"].extend(value["values"])
                else:
                    # Copy the entry and its value list so that merging later
                    # runs does not mutate the caller's parsed data.
                    merged = dict(value)
                    merged["values"] = list(value["values"])
                    results[key] = merged

        # finally patch up the summary
        for one_result in results.values():
            values = one_result["values"]
            match = _CORRECTS_PATTERN.match(one_result["metric"])
            if not match:
                continue
            kind = match.group(1)
            if kind == "number":
                one_result["summary"] = self._flatSummary(len(values), sum(values))
            elif kind == "percent":
                # An empty value list would divide by zero; report 0.0 with
                # num_runs == 0, mirroring the "number" case (sum([]) == 0).
                percent = sum(values) * 100.0 / len(values) if values else 0.0
                one_result["summary"] = self._flatSummary(len(values), percent)
            one_result["metric"] = "total_" + one_result["metric"]
            # there may be too many values, only keep the summary
            if len(values) > self.MAX_VALUES_KEPT:
                del one_result["values"]
        return results

    def aggregate(self):
        """Aggregate all runs, print each entry, and optionally save to a file.

        Each aggregated entry is printed as one JSON line, prefixed with
        ``--metric-keyword`` when given. When ``--result-file`` is given, the
        whole result dict is also written there as indented JSON.
        """
        outputs = self._collectAllOutputs()
        results = self._aggregateOutputs(outputs)
        for result in results.values():
            s = json.dumps(result, sort_keys=True)
            if self.args.metric_keyword:
                s = self.args.metric_keyword + " " + s
            print(s)
        if self.args.result_file:
            s = json.dumps(results, sort_keys=True, indent=2)
            with open(self.args.result_file, "w") as f:
                f.write(s)


if __name__ == "__main__":
    app = AggregateOutputs()
    app.aggregate()