benchmarking/frameworks/tflite/tflite.py (178 lines of code) (raw):

#!/usr/bin/env python ############################################################################## # Copyright 2017-present, Facebook, Inc. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. ############################################################################## from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals import os import re from frameworks.framework_base import FrameworkBase from six import string_types class TFLiteFramework(FrameworkBase): def __init__(self, tempdir, args): super(TFLiteFramework, self).__init__(args) self.tempdir = os.path.join(tempdir, self.getName()) os.makedirs(self.tempdir, 0o777) def getName(self): return "tflite" def runBenchmark(self, info, benchmark, platform): output, output_files = super(TFLiteFramework, self).runBenchmark( info, benchmark, platform ) return output, output_files def verifyBenchmarkFile(self, benchmark, filename, is_post): assert "model" in benchmark, "Field model is missing in benchmark" assert ( "files" in benchmark["model"] ), "Field files is missing in benchmark[model]" assert ( "graph" in benchmark["model"]["files"] ), "Field graph is missing in benchmark[model][files]" assert "tests" in benchmark, "Field tests is missing in benchmark" for test in benchmark["tests"]: assert "warmup" in test, "Field warmup is missing in test" assert "iter" in test, "Field iter is missing in test" def composeRunCommand( self, commands, platform, programs, model, test, model_files, input_files, output_files, shared_libs, preprocess_files=None, main_command=False, ): cmds = super(TFLiteFramework, self).composeRunCommand( commands, platform, programs, model, test, model_files, input_files, output_files, shared_libs, preprocess_files, main_command, ) if cmds: return cmds # the following is for backward compatibility purpose input = None input_shape = None for layer in test["inputs"]: input = layer input_shape = ",".join(str(a) for a in test["inputs"][layer]["shapes"][0]) cmd = [ programs["program"], "--graph={}".format(model_files["graph"]), "--warmup_runs={}".format(test["warmup"]), "--num_runs={}".format(test["iter"]), "--input_layer={}".format(input), "--input_layer_shape={}".format(input_shape), ] cmd = [str(s) for s in cmd] return cmd def runOnPlatform(self, total_num, cmd, platform, platform_args, converter_class): output, meta = platform.runBenchmark( cmd, platform_args=platform_args, log_to_screen_only=True ) result = self._collectData(output) result["meta"] = meta return result def _collectData(self, output): if output is None: return False results = {} rows = output if isinstance(output, string_types): rows = output.split("\n") # only collect one data point for statistics # the actual run data should override the warmup data i = 0 while i < len(rows): i = self._collectNETLatency(results, rows, i) i = self._collectOperatorLatency(results, rows, i) i += 1 return results def _collectNETLatency(self, results, rows, i): row = rows[i] if row[:21] == "Running benchmark for": assert len(rows) > i + 1, "Valid row cannot be found" i = i + 1 data = rows[i] pattern = re.compile( r"^count=([\d|\.]+) first=([\d|\.]+) curr=([\d|\.]+) min=([\d|\.]+) max=([\d|\.]+) avg=([\d|\.]+) std=([\d|\.]+)" ) match = pattern.match(data) if match: r = { "count": int(match.group(1)), "min": float(match.group(4)), "max": float(match.group(5)), "avg": float(match.group(6)), "std": float(match.group(7)), } else: pattern = re.compile(r"^count=(\d+) curr=(\d+)") match = pattern.match(data) assert match, "No data is collected for {}".format(data) r = { "count": int(match.group(1)), "min": float(match.group(2)), "max": float(match.group(2)), "avg": float(match.group(2)), "std": 0, } results["NET latency"] = { "type": "NET", "unit": "us", "metric": "latency", "num_runs": r["count"], "summary": { "p0": r["min"], "p100": r["max"], "mean": r["avg"], "stdev": r["std"], }, } i = i + 1 return i def _collectOperatorLatency(self, results, rows, i): row = rows[i] if ( row[:71] == "============================== Run Order ==============================" ): i = i + 2 types_table = {} pattern = re.compile( r"\s+(\w+)\s+([\d|\.]+)\s+([\d|\.]+)\s+([\d|\.]+)\s+([\d|\.]+)%\s+([\d|\.]+)%\s+([\d|\.]+)\s+([\d|\.]+)\s+\[(.+)\]" ) while i < len(rows): row = rows[i] match = pattern.match(row) if not match: break type = match.group(9) kind = match.group(1) avg = float(match.group(4)) * 1000 results[type + " latency"] = { "type": type, "unit": "us", "metric": "latency", "summary": {"mean": avg}, } if kind in types_table: types_table[kind] += avg else: types_table[kind] = avg i = i + 1 # Write the accumulated operator types for k in types_table: v = types_table[k] results[k + " latency"] = { "type": k, "unit": "us", "metric": "latency", "summary": {"mean": v}, } return i