benchmarking/frameworks/pytorch/pytorch.py (43 lines of code) (raw):
#!/usr/bin/env python
##############################################################################
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
##############################################################################
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from frameworks.caffe2.caffe2 import Caffe2Framework
from utils.custom_logger import getLogger
class PytorchFramework(Caffe2Framework):
    """PyTorch benchmark framework built on top of ``Caffe2Framework``.

    Overrides the framework name, the output-line identifier, and the
    result-collection loop used when running a benchmark on a platform.
    """

    # Marker searched for in benchmark output by the default converter; the
    # trailing space is part of the marker string.
    IDENTIFIER = "PyTorchObserver "
    NET = "NET"

    def getName(self):
        """Return the framework name, ``"pytorch"``."""
        return "pytorch"

    def runOnPlatform(self, total_num, cmd, platform, platform_args, converter):
        """Run ``cmd`` on ``platform`` until ``total_num`` runs are collected.

        The benchmark is invoked repeatedly while fewer than ``total_num`` runs
        have been collected and the latest invocation produced at least one
        run. If more runs than needed were collected, only the newest
        ``total_num`` runs are kept.

        Args:
            total_num: Number of runs wanted. A negative value disables both
                the re-run loop and trimming (a single invocation is used).
            cmd: Command passed to ``platform.runBenchmark``.
            platform: Object providing ``runBenchmark(cmd, platform_args=...)``
                that returns ``(output, meta)``.
            platform_args: Forwarded to ``platform.runBenchmark``.
            converter: Dict with a ``"name"`` key (looked up in
                ``self.converters``) and optional ``"args"`` passed to the
                converter's ``collect``; ``None`` selects the
                ``json_with_identifier_converter`` keyed on ``IDENTIFIER``.

        Returns:
            The dict produced by the converter's ``convert`` on the collected
            entries, with ``"meta"`` set to the meta of the last invocation.
        """
        if converter is None:
            converter = {
                "name": "json_with_identifier_converter",
                "args": {"identifier": self.IDENTIFIER},
            }
        converter_obj = self.converters[converter["name"]]()
        args = converter.get("args")
        results = []
        # Per-run positions into ``results`` reported by the converter,
        # accumulated over every invocation. NOTE(review): presumably the
        # position where each run's entries begin -- confirm against the
        # converter's ``collect`` implementation.
        all_run_idxs = []
        num = 0
        # emulate do...while... loop
        while True:
            output, meta = platform.runBenchmark(cmd, platform_args=platform_args)
            one_result, valid_run_idxs = converter_obj.collect(output, args)
            # The converter reports positions within ``one_result``; rebase
            # them by the number of entries already accumulated (not by the
            # number of runs) so they index into ``results``.
            offset = len(results)
            all_run_idxs.extend(offset + idx for idx in valid_run_idxs)
            num_items = len(valid_run_idxs)
            num += num_items
            results.extend(one_result)
            if num < total_num:
                if num_items > 0:
                    getLogger().info(
                        "%d items collected, Still missing %d "
                        "runs. Collect again." % (num_items, total_num - num)
                    )
                    continue
                getLogger().info("No new items collected, finish collecting...")
            elif total_num >= 0 and num > total_num:
                # if collect more than the needed number, get the
                # latest entries. This may happen when the data in
                # the previous runs are not cleared. e.g. on some
                # android 5 devices. Or, it may happen when multiple
                # runs are needed to collect the desired number of
                # iterations. ``all_run_idxs`` spans every invocation,
                # so ``[-total_num]`` selects the oldest run to keep
                # even when the runs came from several invocations.
                if total_num > 0:
                    start = all_run_idxs[-total_num]
                else:
                    start = len(results)
                results = results[start:]
            break
        metric = converter_obj.convert(results)
        metric["meta"] = meta
        return metric