#!/usr/bin/env python

##############################################################################
# Copyright 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
##############################################################################

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import os
import re

from frameworks.framework_base import FrameworkBase
from six import string_types


class TFLiteFramework(FrameworkBase):
    def __init__(self, tempdir, args):
        super(TFLiteFramework, self).__init__(args)
        self.tempdir = os.path.join(tempdir, self.getName())
        os.makedirs(self.tempdir, 0o777)

    def getName(self):
        return "tflite"

    def runBenchmark(self, info, benchmark, platform):
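        # Pass-through to the base class implementation.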
        output, output_files = super(TFLiteFramework, self).runBenchmark(
            info, benchmark, platform
        )
        return output, output_files

    def verifyBenchmarkFile(self, benchmark, filename, is_post):
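        """Check that the benchmark spec contains the fields used by TFLite.

        A minimal spec has this shape (values illustrative; other fields
        may be consumed elsewhere):

            {
                "model": {"files": {"graph": ...}},
                "tests": [{"warmup": 1, "iter": 50}]
            }
        """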
        assert "model" in benchmark, "Field model is missing in benchmark"
        assert (
            "files" in benchmark["model"]
        ), "Field files is missing in benchmark[model]"
        assert (
            "graph" in benchmark["model"]["files"]
        ), "Field graph is missing in benchmark[model][files]"

        assert "tests" in benchmark, "Field tests is missing in benchmark"
        for test in benchmark["tests"]:
            assert "warmup" in test, "Field warmup is missing in test"
            assert "iter" in test, "Field iter is missing in test"

    def composeRunCommand(
        self,
        commands,
        platform,
        programs,
        model,
        test,
        model_files,
        input_files,
        output_files,
        shared_libs,
        preprocess_files=None,
        main_command=False,
    ):
        cmds = super(TFLiteFramework, self).composeRunCommand(
            commands,
            platform,
            programs,
            model,
            test,
            model_files,
            input_files,
            output_files,
            shared_libs,
            preprocess_files,
            main_command,
        )
        if cmds:
            return cmds

        # The following is kept for backward compatibility; if the test
        # specifies multiple inputs, only the last layer visited (and its
        # first shape) is passed to the binary.
        input_layer = None
        input_shape = None
        for layer in test["inputs"]:
            input_layer = layer
            input_shape = ",".join(str(a) for a in test["inputs"][layer]["shapes"][0])
        cmd = [
            programs["program"],
            "--graph={}".format(model_files["graph"]),
            "--warmup_runs={}".format(test["warmup"]),
            "--num_runs={}".format(test["iter"]),
            "--input_layer={}".format(input_layer),
            "--input_layer_shape={}".format(input_shape),
        ]
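        # The resulting command looks like (binary name and values
        # illustrative):
        #   benchmark_model --graph=mobilenet.tflite --warmup_runs=1 \
        #     --num_runs=50 --input_layer=input --input_layer_shape=1,224,224,3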

        cmd = [str(s) for s in cmd]
        return cmd

    def runOnPlatform(self, total_num, cmd, platform, platform_args, converter_class):
        output, meta = platform.runBenchmark(
            cmd, platform_args=platform_args, log_to_screen_only=True
        )
        result = self._collectData(output)
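        # result maps metric names such as "NET latency" to entries with
        # "type", "unit", "metric", and a "summary" dict built in _collectData.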
        result["meta"] = meta
        return result

    def _collectData(self, output):
        if output is None:
            # Return an empty dict so callers can still attach metadata.
            return {}
        results = {}
        rows = output
        if isinstance(output, string_types):
            rows = output.split("\n")
        # Only one data point is collected per metric; the actual run data
        # overrides the warmup data because it appears later in the output.
        i = 0
        while i < len(rows):
            i = self._collectNETLatency(results, rows, i)
            i = self._collectOperatorLatency(results, rows, i)
            i += 1
        return results

    def _collectNETLatency(self, results, rows, i):
        row = rows[i]
        if row.startswith("Running benchmark for"):
            assert len(rows) > i + 1, "Valid row cannot be found"
            i = i + 1
            data = rows[i]
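            # The summary line has the form (values illustrative):
            #   count=50 first=1234 curr=1180 min=1102 max=1398 avg=1215.3 std=89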
            pattern = re.compile(
                r"^count=([\d.]+) first=([\d.]+) curr=([\d.]+) min=([\d.]+)"
                r" max=([\d.]+) avg=([\d.]+) std=([\d.]+)"
            )
            match = pattern.match(data)
            if match:
                r = {
                    "count": int(match.group(1)),
                    "min": float(match.group(4)),
                    "max": float(match.group(5)),
                    "avg": float(match.group(6)),
                    "std": float(match.group(7)),
                }
            else:
                # Fallback when the output only reports count and curr.
                pattern = re.compile(r"^count=(\d+) curr=(\d+)")
                match = pattern.match(data)
                assert match, "No data is collected for {}".format(data)

                r = {
                    "count": int(match.group(1)),
                    "min": float(match.group(2)),
                    "max": float(match.group(2)),
                    "avg": float(match.group(2)),
                    "std": 0,
                }
            results["NET latency"] = {
                "type": "NET",
                "unit": "us",
                "metric": "latency",
                "num_runs": r["count"],
                "summary": {
                    "p0": r["min"],
                    "p100": r["max"],
                    "mean": r["avg"],
                    "stdev": r["std"],
                },
            }
            i = i + 1
        return i

    def _collectOperatorLatency(self, results, rows, i):
        row = rows[i]
        if row.startswith(
            "============================== Run Order =============================="
        ):
            i = i + 2  # skip the banner and the column header line
            types_table = {}
            pattern = re.compile(
                r"\s+(\w+)\s+([\d.]+)\s+([\d.]+)\s+([\d.]+)\s+([\d.]+)%"
                r"\s+([\d.]+)%\s+([\d.]+)\s+([\d.]+)\s+\[(.+)\]"
            )
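            # Matches profile rows of the form (values illustrative):
            #   CONV_2D  0.000  12.345  11.987  35.2%  35.2%  0.000  1  [conv1]
            # group 1 is the operator type, group 4 the average time in ms,
            # and group 9 the node name.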
            while i < len(rows):
                row = rows[i]
                match = pattern.match(row)
                if not match:
                    break
                name = match.group(9)
                kind = match.group(1)
                # The table reports milliseconds; convert to microseconds.
                avg = float(match.group(4)) * 1000
                results[name + " latency"] = {
                    "type": name,
                    "unit": "us",
                    "metric": "latency",
                    "summary": {"mean": avg},
                }
                types_table[kind] = types_table.get(kind, 0) + avg
                i = i + 1
            # Record the latency accumulated per operator type
            for k, v in types_table.items():
                results[k + " latency"] = {
                    "type": k,
                    "unit": "us",
                    "metric": "latency",
                    "summary": {"mean": v},
                }
        return i
