benchmarking/frameworks/generic/generic.py (29 lines of code) (raw):

#!/usr/bin/env python ############################################################################## # Copyright 2017-present, Facebook, Inc. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. ############################################################################## from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals import os from frameworks.framework_base import FrameworkBase class GenericFramework(FrameworkBase): IDENTIFIER = "PyTorchObserver " def __init__(self, tempdir, args): super(GenericFramework, self).__init__(args) self.tempdir = os.path.join(tempdir, self.getName()) os.makedirs(self.tempdir, 0o777) def getName(self): return "generic" def runOnPlatform(self, total_num, cmd, platform, platform_args, converter): if converter is None: converter = { "name": "json_with_identifier_converter", "args": {"identifier": self.IDENTIFIER}, } converter_obj = self.converters[converter["name"]]() args = converter.get("args") results = [] output, meta = platform.runBenchmark(cmd, platform_args=platform_args) one_result, valid_run_idxs = converter_obj.collect(output, args) results.extend(one_result) metric = converter_obj.convert(results) metric["meta"] = meta return metric