# libraries/python/aggregate_classification_results.py
#!/usr/bin/env python
##############################################################################
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
##############################################################################
# This script aggregates test results from multiple runs and forms
# the final result.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import argparse
import json
import os
import re
# Command-line interface for the aggregation script.
parser = argparse.ArgumentParser(description="Aggregate output results")
parser.add_argument(
    "--dir", required=True, help="The directory where all the json data are saved."
)
parser.add_argument(
    "--limit",
    required=True,
    type=int,
    # Fixed copy-pasted help text: this flag is the number of runs to read,
    # not a directory (files <prefix>_0.txt ... <prefix>_<limit-1>.txt).
    help="The number of runs whose json outputs should be aggregated.",
)
parser.add_argument(
    "--metric-keyword",
    help="The keyword prefix each metric so that the harness can parse.",
)
parser.add_argument(
    "--prefix",
    required=True,
    help="The prefix of the json data. The files are suffixed with a number "
    "and .txt",
)
parser.add_argument("--result-file", help="Write the prediction result to a file.")
class AggregateOutputs(object):
def __init__(self):
self.args = parser.parse_args()
assert os.path.isdir(self.args.dir), "Directory {} doesn't exist".format(
self.args.dir
)
def _composeFilename(self, index):
return os.path.join(self.args.dir, self.args.prefix + "_" + str(index) + ".txt")
def _getOneOutput(self, index):
filename = self._composeFilename(index)
if not os.path.isfile(filename):
print("File {} does not exist".format(filename))
return None
with open(filename, "r") as f:
content = json.load(f)
return content
def _collectAllOutputs(self):
outputs = []
for index in range(self.args.limit):
output = self._getOneOutput(index)
if output is not None:
outputs.append(output)
return outputs
def _aggregateOutputs(self, outputs):
results = {}
for one_output in outputs:
for key in one_output:
value = one_output[key]
if key not in results:
results[key] = value
else:
results[key]["values"].extend(value["values"])
pattern = re.compile(r"(\w+)_of_top(\d+)_corrects")
# finally patch up the summary
for res in results:
one_result = results[res]
one_result["type"] = one_result["type"]
values = one_result["values"]
match = pattern.match(one_result["metric"])
if not match:
continue
if match.group(1) == "number":
data = sum(values)
one_result["summary"] = {
"num_runs": len(values),
"p0": data,
"p10": data,
"p50": data,
"p90": data,
"p100": data,
"mean": data,
"std": 0,
"MAD": 0,
}
elif match.group(1) == "percent":
data = sum(values) * 100.0 / len(values)
one_result["summary"] = {
"num_runs": len(values),
"p0": data,
"p10": data,
"p50": data,
"p90": data,
"p100": data,
"mean": data,
"std": 0,
"MAD": 0,
}
one_result["metric"] = "total_" + one_result["metric"]
# there may be too many values, only keep the summary
if len(values) > 200:
del one_result["values"]
return results
def aggregate(self):
outputs = self._collectAllOutputs()
results = self._aggregateOutputs(outputs)
for key in results:
result = results[key]
s = json.dumps(result, sort_keys=True)
if self.args.metric_keyword:
s = self.args.metric_keyword + " " + s
print(s)
if self.args.result_file:
s = json.dumps(results, sort_keys=True, indent=2)
with open(self.args.result_file, "w") as f:
f.write(s)
if __name__ == "__main__":
    # Script entry point: parse CLI args, aggregate all runs, emit results.
    AggregateOutputs().aggregate()