benchmarking/frameworks/generic/generic.py (29 lines of code) (raw):
#!/usr/bin/env python
##############################################################################
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
##############################################################################
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import os
from frameworks.framework_base import FrameworkBase
class GenericFramework(FrameworkBase):
    """Framework that runs a benchmark command on a platform and converts
    the raw output of that run into a metric structure.
    """

    # Passed as the ``identifier`` argument to the default converter.
    # NOTE(review): presumably the marker that tags metric lines in the
    # benchmark output -- confirm against json_with_identifier_converter.
    IDENTIFIER = "PyTorchObserver "

    def __init__(self, tempdir, args):
        """Initialise the base framework and create the scratch directory.

        Args:
            tempdir: Parent directory. A sub-directory named after
                ``getName()`` is created inside it and stored as
                ``self.tempdir``.
            args: Forwarded unchanged to ``FrameworkBase.__init__``.

        Raises:
            OSError: If the scratch directory does not exist and cannot be
                created.
        """
        super(GenericFramework, self).__init__(args)
        self.tempdir = os.path.join(tempdir, self.getName())
        # Tolerate a directory left behind by an earlier instance or a retry.
        # try/except (rather than exist_ok=) keeps this working on Python 2,
        # which this file still targets via its __future__ imports, and
        # avoids a check-then-create race.
        try:
            os.makedirs(self.tempdir, 0o777)
        except OSError:
            if not os.path.isdir(self.tempdir):
                raise

    def getName(self):
        """Return the framework name, also used as the scratch dir name."""
        return "generic"

    def runOnPlatform(self, total_num, cmd, platform, platform_args, converter):
        """Run ``cmd`` once on ``platform`` and convert its output.

        Args:
            total_num: Not used by this implementation; kept for interface
                compatibility with callers.
            cmd: Command passed to ``platform.runBenchmark``.
            platform: Object exposing ``runBenchmark(cmd, platform_args=...)``
                which returns an ``(output, meta)`` pair.
            platform_args: Forwarded to ``platform.runBenchmark``.
            converter: Dict with a ``"name"`` key selecting the converter and
                an optional ``"args"`` key, or ``None`` to use
                ``json_with_identifier_converter`` with ``IDENTIFIER``.

        Returns:
            The value produced by the converter's ``convert()``, with the
            run's ``meta`` stored under the ``"meta"`` key.

        Raises:
            KeyError: If ``converter`` has no ``"name"`` key or the name is
                not present in ``self.converters``.
        """
        if converter is None:
            converter = {
                "name": "json_with_identifier_converter",
                "args": {"identifier": self.IDENTIFIER},
            }
        # self.converters is not defined in this class; presumably it is
        # populated by FrameworkBase -- verify there. Each entry is called
        # with no arguments to obtain a converter object.
        converter_obj = self.converters[converter["name"]]()
        converter_args = converter.get("args")

        output, meta = platform.runBenchmark(cmd, platform_args=platform_args)
        # collect() returns a pair; the second element is not needed here.
        collected, _ = converter_obj.collect(output, converter_args)

        # convert() receives a fresh list so the converter never aliases the
        # object returned by collect().
        metric = converter_obj.convert(list(collected))
        metric["meta"] = meta
        return metric