# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pyre-unsafe

"""
Fuzz Testing for Thrift Services
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import argparse
import collections
import inspect
import json
import logging
import os
import pprint
import sys
import time

import six
import six.moves as sm
from six.moves.urllib.parse import urlparse

try:
    # pyre-fixme[21]: Could not find module `ServiceRouter`.
    from ServiceRouter import ConnConfigs, ServiceOptions, ServiceRouter  # @manual

    SR_AVAILABLE = True
except ImportError:
    SR_AVAILABLE = False

from thrift import Thrift
from thrift.protocol import TBinaryProtocol, TCompactProtocol, THeaderProtocol
from thrift.transport import TTransport, TSocket, TSSLSocket, THttpClient
from thrift.util import randomizer

if six.PY3:
    from importlib.machinery import SourceFileLoader

    def load_source(name, pathname):
        """Execute the Python source file at ``pathname`` as module ``name``.

        Returns the resulting module object.
        """
        # NOTE: Loader.load_module() is deprecated in favour of
        # exec_module(); it is kept because it also handles registering the
        # module in sys.modules, which callers of load_source rely on.
        loader = SourceFileLoader(name, pathname)
        module = loader.load_module()
        return module


else:
    import imp

    def load_source(name, pathname):
        """Python 2 fallback: delegate to ``imp.load_source``."""
        module = imp.load_source(name, pathname)
        return module


def positive_int(s) -> int:
    """Parse ``s`` as an integer that must be strictly greater than zero.

    Intended as an argparse ``type=`` callable: every failure is reported as
    ``argparse.ArgumentTypeError`` so argparse prints a usage error.
    """
    try:
        value = int(s)
    except ValueError:
        raise argparse.ArgumentTypeError("Cannot convert %s to an integer." % s)
    if value <= 0:
        raise argparse.ArgumentTypeError("%s is not positive." % s)
    return value


def prob_float(s) -> float:
    """Parse ``s`` as a probability, i.e. a float in the closed range [0, 1].

    Intended as an argparse ``type=`` callable; invalid input (including
    NaN, which fails the range check) raises ``argparse.ArgumentTypeError``.
    """
    try:
        prob = float(s)
    except ValueError:
        raise argparse.ArgumentTypeError("Cannot convert %s to a float." % s)
    # Chained comparison on purpose: it is False for NaN, so NaN is rejected.
    in_range = 0 <= prob <= 1
    if not in_range:
        raise argparse.ArgumentTypeError("%s is not a valid probability." % prob)
    return prob


class FuzzerConfiguration(object):
    """Container for Fuzzer configuration options.

    Every option described in ``argspec`` becomes an instance attribute.
    Values are resolved in increasing order of precedence:

    1. the option's ``default`` from ``argspec``;
    2. the JSON configuration file given with ``-C/--config``;
    3. command-line arguments.

    The configuration is validated on construction; an invalid configuration
    prints a message to stderr and exits with ``os.EX_USAGE``.
    """

    # Option name -> specification. Specification keys used below:
    #   description      help text shown by argparse
    #   type             callable used to convert/validate a value
    #   flag             short command-line flag ("--<name>" is added too)
    #   argparse_kwargs  extra keyword arguments for add_argument
    #   attr_name        attribute the value is stored under (default: name)
    #   default          value used when the option is not supplied
    #   is_json          value is JSON: decoded with json.loads when given on
    #                    the command line, taken as-is from the config file
    argspec = {
        "allow_application_exceptions": {
            "description": "Do not flag TApplicationExceptions as errors",
            "type": bool,
            "flag": "-a",
            "argparse_kwargs": {"action": "store_const", "const": True},
            "default": False,
        },
        "compact": {
            "description": "Use TCompactProtocol",
            "type": bool,
            "flag": "-c",
            "argparse_kwargs": {"action": "store_const", "const": True},
            "default": False,
        },
        "constraints": {
            "description": "JSON Constraint dictionary",
            "type": str,
            "flag": "-Con",
            "default": {},
            "is_json": True,
        },
        "framed": {
            "description": "Use framed transport.",
            "type": bool,
            "flag": "-f",
            "argparse_kwargs": {"action": "store_const", "const": True},
            "default": False,
        },
        "functions": {
            "description": "Which functions to test. If excluded, test all",
            "type": str,
            "flag": "-F",
            "argparse_kwargs": {
                "nargs": "*",
            },
            "default": None,
        },
        "host": {
            "description": "The host and port to connect to",
            "type": str,
            "flag": "-h",
            "argparse_kwargs": {"metavar": "HOST[:PORT]"},
            "default": None,
        },
        "iterations": {
            "description": "Number of calls per method.",
            "type": positive_int,
            "flag": "-n",
            "attr_name": "n_iterations",
            "default": 1000,
        },
        "logfile": {
            "description": "File to write output logs.",
            "type": str,
            "flag": "-l",
            "default": None,
        },
        "loglevel": {
            "description": "Level of verbosity to write logs.",
            "type": str,
            "flag": "-L",
            "argparse_kwargs": {
                "choices": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
            },
            "default": "INFO",
        },
        "service": {
            "description": "Path to file of Python service module.",
            "type": str,
            "flag": "-S",
            "attr_name": "service_path",
            "default": None,
        },
        "ssl": {
            "description": "Use SSL socket.",
            "type": bool,
            "flag": "-s",
            "argparse_kwargs": {"action": "store_const", "const": True},
            "default": False,
        },
        "unframed": {
            "description": "Use unframed transport.",
            "type": bool,
            "flag": "-U",
            "argparse_kwargs": {"action": "store_const", "const": True},
            "default": False,
        },
        "url": {
            "description": "The URL to connect to for HTTP transport",
            "type": str,
            "flag": "-u",
            "default": None,
        },
    }
    if SR_AVAILABLE:
        argspec["tier"] = {
            "description": "The SMC tier to connect to",
            "type": str,
            "flag": "-t",
            "default": None,
        }
        argspec["conn_configs"] = {
            "description": "ConnConfigs to use for ServiceRouter connection",
            "type": str,
            "flag": "-Conn",
            "default": {},
            "is_json": True,
        }
        argspec["service_options"] = {
            "description": "ServiceOptions to use for ServiceRouter connection",
            "type": str,
            "flag": "-SO",
            "default": {},
            "is_json": True,
        }

    def __init__(self, service=None):
        """Build the configuration from defaults, config file and argv.

        :param service: optional, already-constructed ``Service`` to fuzz.
            ``load_service`` refuses to also accept ``--service`` when one is
            given.
        """
        import copy  # only needed to copy mutable defaults per instance

        cls = self.__class__

        # Always define the attribute: load_service() reads self.service even
        # when no Service object was supplied.
        self.service = service

        parser = argparse.ArgumentParser(
            description="Fuzzer Configuration", add_help=False
        )
        parser.add_argument(
            "-C",
            "--config",
            dest="config_filename",
            help="JSON Configuration file. "
            "All settings can be specified as commandline "
            "args and config file settings. Commandline args "
            "override config file settings.",
        )

        parser.add_argument(
            "-?", "--help", action="help", help="Show this help message and exit."
        )

        for name, arg in six.iteritems(cls.argspec):
            # Copy so the keys added below do not leak into the shared,
            # class-level argspec.
            kwargs = dict(arg.get("argparse_kwargs", {}))

            if kwargs.get("action", None) != "store_const":
                # Pass type to argparse. With store_const, type can be inferred
                kwargs["type"] = arg["type"]

            # If an argument is not passed, don't put a value in the namespace
            kwargs["default"] = argparse.SUPPRESS

            # Use the argument's description and default as a help message
            kwargs["help"] = "%s Default: %s" % (
                arg.get("description", ""),
                arg["default"],
            )

            dest = arg.get("attr_name", name)
            kwargs["dest"] = dest

            if hasattr(self, dest):
                # The attribute is already set on the instance; leave it alone.
                continue

            parser.add_argument(arg["flag"], "--%s" % name, **kwargs)

            # Assign a private copy of the default so mutable defaults (the {}
            # of the JSON options) are not shared with argspec or other
            # instances.
            setattr(self, dest, copy.deepcopy(arg["default"]))

        args = parser.parse_args()

        # Read settings in config file
        self.__dict__.update(cls._config_file_settings(args))

        # Read settings in args (these override the config file)
        self.__dict__.update(cls._args_settings(args))

        valid, message = self._validate_config()
        if not valid:
            print(message, file=sys.stderr)
            sys.exit(os.EX_USAGE)

    @classmethod
    def _try_parse_type(cls, name, type_, val):
        """Convert ``val`` with ``type_``, reporting failure as TypeError.

        Both ValueError and TypeError raised by the conversion (e.g. ``int``
        applied to None) are turned into a TypeError naming the setting.
        """
        try:
            return type_(val)
        except (TypeError, ValueError):
            raise TypeError(
                ("Expected type %s for setting %s, " "but got type %s (%s)")
                % (type_, name, type(val), val)
            )

    @classmethod
    def _try_parse(cls, name, arg, val):
        """Validate a config-file value against its argspec entry.

        JSON options are returned unchanged. Options with ``nargs`` must be
        lists, and each element is converted; other options are converted
        directly with the option's ``type``.
        """
        if arg.get("is_json", False):
            return val

        type_ = arg["type"]

        nargs = arg.get("argparse_kwargs", {}).get("nargs", None)

        if nargs is None:
            return cls._try_parse_type(name, type_, val)

        if not isinstance(val, list):
            raise TypeError(
                (
                    "Expected list of length %s "
                    "for setting %s, but got type %s (%s)"
                )
                % (nargs, name, type(val), val)
            )
        return [cls._try_parse_type(name, type_, elem) for elem in val]

    @classmethod
    def _config_file_settings(cls, args):
        """Read settings from a configuration file.

        :returns: dict mapping attribute names to parsed values; empty when
            no ``--config`` was given.
        :raises OSError: if the config file does not exist.
        :raises ValueError: on malformed JSON or an unknown option name.
        :raises TypeError: if the top level is not a JSON object or a value
            has the wrong type.
        """
        if args.config_filename is None:
            return {}  # No config file
        if not os.path.exists(args.config_filename):
            raise OSError(
                os.EX_NOINPUT, "Config file does not exist: %s" % args.config_filename
            )
        with open(args.config_filename, "r") as fd:
            try:
                settings = json.load(fd)
            except ValueError as e:
                raise ValueError("Error parsing config file: %s" % e)

        # Make sure settings are well-formatted
        renamed_settings = {}
        if not isinstance(settings, dict):
            raise TypeError("Invalid config file. Top-level must be Object.")
        for name, val in six.iteritems(settings):
            if name not in cls.argspec:
                raise ValueError(("Unrecognized configuration " "option: %s") % name)
            arg = cls.argspec[name]
            val = cls._try_parse(name, arg, val)
            attr_name = arg.get("attr_name", name)
            renamed_settings[attr_name] = val
        return renamed_settings

    @classmethod
    def _args_settings(cls, args):
        """Read settings from the args namespace returned by argparse.

        Only options actually given on the command line are present (the
        parser uses ``argparse.SUPPRESS`` defaults). JSON options are decoded.

        :raises ValueError: if a JSON option is not valid JSON.
        """
        settings = {}
        for name, arg in six.iteritems(cls.argspec):
            attr_name = arg.get("attr_name", name)
            if not hasattr(args, attr_name):
                continue
            value = getattr(args, attr_name)
            if arg.get("is_json", False):
                try:
                    settings[attr_name] = json.loads(value)
                except ValueError as e:
                    raise ValueError("Invalid JSON for --%s: %s" % (name, e))
            else:
                settings[attr_name] = value
        return settings

    def __str__(self):
        return "Configuration(\n%s\n)" % pprint.pformat(self.__dict__)

    def load_service(self):
        """Return the Service to fuzz, with its methods loaded.

        If a Service object was passed to the constructor it is used;
        otherwise the Python service module at ``service_path`` is loaded
        together with ``ttypes.py`` and ``constants.py`` from the same
        directory.

        :raises ValueError: if both, or neither, of a Service object and a
            service path were given.
        :raises OSError: if the service path does not exist or does not end
            in ``.py``.
        """
        if self.service is not None:
            if self.service_path is not None:
                raise ValueError(
                    "Cannot specify a service path when the "
                    "service is input programmatically"
                )
            # Service already loaded programmatically. Just load methods.
            self.service.load_methods()
            return self.service

        if self.service_path is None:
            raise ValueError("Error: No service specified")

        service_path = self.service_path

        if not os.path.exists(service_path):
            raise OSError("Service module does not exist: %s" % service_path)

        if not service_path.endswith(".py"):
            raise OSError("Service module is not a Python module: %s" % service_path)

        parent_path, service_filename = os.path.split(service_path)
        service_name = service_filename[:-3]  # Truncate ".py"

        logging.info("Service name: %s", service_name)

        ttypes_path = os.path.join(parent_path, "ttypes.py")
        constants_path = os.path.join(parent_path, "constants.py")

        # Modules are registered under the synthetic package name "module",
        # presumably so the generated files can import one another.
        # NOTE(review): the first call passes the directory path itself to
        # load_source; confirm this works on the Python versions in use.
        load_source("module", parent_path)
        ttypes_module = load_source("module.ttypes", ttypes_path)
        constants_module = load_source("module.constants", constants_path)
        service_module = load_source("module.%s" % (service_name), service_path)

        service = Service(ttypes_module, constants_module, service_module)
        service.load_methods()
        return service

    def _validate_config(self):
        """Check constraints between options that argparse cannot express.

        On success, sets ``self.connection_method`` to the single connection
        option that was specified.

        :returns: ``(True, None)`` if valid, otherwise ``(False, message)``.
        """
        # Verify there is one valid connection flag
        connection_flags = FuzzerClient.connection_flags
        specified_flags = [
            flag for flag in connection_flags if getattr(self, flag, None) is not None
        ]

        if len(specified_flags) != 1:
            message = "Exactly one of [%s] must be specified. Got [%s]." % (
                (", ".join("--%s" % flag for flag in connection_flags)),
                (", ".join("--%s" % flag for flag in specified_flags)),
            )
            return False, message

        connection_method = specified_flags[0]
        self.connection_method = connection_method

        if connection_method == "url":
            if not (self.compact or self.framed or self.unframed):
                message = (
                    "A protocol (compact, framed, or unframed) "
                    "must be specified for HTTP Transport."
                )
                return False, message

        if connection_method in {"url", "host"}:
            if connection_method == "url":
                try:
                    url = urlparse(self.url)
                except Exception:
                    return False, "Unable to parse url %s" % self.url
                connection_str = url.netloc
            else:
                connection_str = self.host
            if ":" in connection_str:
                # Take the text after the LAST colon, matching how
                # FuzzerClient._parse_host_port splits host and port.
                port = connection_str.rpartition(":")[2]
                try:
                    int(port)
                except ValueError:
                    message = "Port is not an integer: %s" % port
                    return False, message

        return True, None


class Service(object):
    """Wrapper around a generated thrift service module.

    Holds the generated ttypes, constants and service modules and, once
    ``load_methods()`` has run, an index of the service's methods.
    """

    def __init__(self, ttypes_module, constants_module, service_module):
        self.ttypes = ttypes_module
        self.constants = constants_module
        self.service = service_module
        # None until load_methods() has been called.
        self.methods = None

    def __str__(self):
        return "Service(%s)" % self.service.__name__

    @staticmethod
    def _declared_exceptions(result_class):
        """Return the exception types listed in a ``<name>_result`` thrift_spec.

        Every non-None spec entry whose field name is not "success" is taken
        to be an exception field; its type is element 0 of the entry's spec
        args (entry index 3).
        """
        if result_class is None:
            return ()
        return tuple(
            field_spec[3][0]
            for field_spec in result_class.thrift_spec
            if field_spec is not None and field_spec[2] != "success"
        )

    def load_methods(self, exclude_ifaces=None):
        """Index the methods of ``self.service.Iface``.

        A method is kept only when the module defining it also defines a
        ``<name>_args`` class. Each entry of the resulting ``self.methods``
        dict maps the method name to::

            {"args_class": <name>_args,
             "result_spec": <name>_result or None,
             "thrift_exceptions": tuple of declared exception types}

        :param exclude_ifaces: optional collection of interface classes; a
            method whose (name, function) pair is also found on any of them is
            skipped.
        """
        is_method = inspect.isfunction if six.PY3 else inspect.ismethod

        excluded = set()
        for iface in exclude_ifaces or []:
            excluded.update(inspect.getmembers(iface, predicate=is_method))

        index = {}
        for method_name, method in inspect.getmembers(
            self.service.Iface, predicate=is_method
        ):
            if (method_name, method) in excluded:
                continue

            owner_module = inspect.getmodule(method)
            args_class = getattr(owner_module, method_name + "_args", None)
            if args_class is None:
                continue
            result_class = getattr(owner_module, method_name + "_result", None)

            index[method_name] = {
                "args_class": args_class,
                "result_spec": result_class,
                "thrift_exceptions": self._declared_exceptions(result_class),
            }

        self.methods = index

    @property
    def client_class(self):
        """The generated ``Client`` class of the service module."""
        return self.service.Client

    def get_methods(self, include=None):
        """Return the loaded method index, optionally restricted by name.

        :param include: optional collection of method names to keep.
        :raises ValueError: if load_methods() has not been called yet.
        :raises NameError: if a name in ``include`` is not a known method.
        """
        if self.methods is None:
            raise ValueError(
                "Service.load_methods must be " "called before Service.get_methods"
            )

        if include is None:
            return self.methods

        selected = {}
        for wanted in include:
            if wanted not in self.methods:
                raise NameError("Function does not exist: %s" % wanted)
            selected[wanted] = self.methods[wanted]
        return selected


class FuzzerClient(object):
    """Client wrapper used to make calls based on configuration settings.

    Intended to be used as a context manager: entering connects using
    ``config.connection_method`` ("host", "url" or "tier") and exiting closes
    the connection.
    """

    # Configuration attributes that select how to connect; the configuration
    # must set exactly one of them.
    connection_flags = ["host", "url", "tier"]
    # Port used when --host does not include one.
    default_port = 9090

    def __init__(self, config, client_class):
        self.config = config
        self.client_class = client_class
        # Transport opened for "host"/"url" connections; None until a
        # connection has been opened.
        self._transport = None
        self.client = None

    def _get_client_by_transport(self, config, transport, socket=None):
        """Pick a protocol, open the transport and build a client.

        Protocol choice: TCompactProtocol if ``config.compact``; otherwise the
        accelerated TBinaryProtocol if framed/unframed was requested;
        otherwise, when a raw ``socket`` is given, THeaderProtocol over that
        socket (whose own transport then replaces ``transport``).

        :raises ValueError: if no protocol can be chosen (HTTP transport
            without an explicit protocol option).
        """
        if config.compact:
            protocol = TCompactProtocol.TCompactProtocol(transport)
        # No explicit option about protocol is specified. Try to infer.
        elif config.framed or config.unframed:
            protocol = TBinaryProtocol.TBinaryProtocolAccelerated(transport)
        elif socket is not None:
            protocol = THeaderProtocol.THeaderProtocol(socket)
            transport = protocol.trans
        else:
            raise ValueError("No protocol specified for HTTP Transport")
        transport.open()
        # Remember the opened transport so _close_client() can close it.
        self._transport = transport

        return self.client_class(protocol)

    def _parse_host_port(self, value, default_port):
        """Split "host[:port]" at the last colon into ``(host, port)``.

        ``default_port`` is used when there is no colon.
        """
        parts = value.rsplit(":", 1)
        if len(parts) == 1:
            return (parts[0], default_port)
        # FuzzerConfiguration validates that the port is an integer
        return (parts[0], int(parts[1]))

    def _get_client_by_host(self):
        """Connect over a (optionally SSL) socket to ``config.host``."""
        config = self.config
        host, port = self._parse_host_port(config.host, self.default_port)
        socket = (
            TSSLSocket.TSSLSocket(host, port)
            if config.ssl
            else TSocket.TSocket(host, port)
        )
        if config.framed:
            transport = TTransport.TFramedTransport(socket)
        else:
            transport = TTransport.TBufferedTransport(socket)
        return self._get_client_by_transport(config, transport, socket=socket)

    def _get_client_by_url(self):
        """Connect over HTTP to ``config.url``."""
        config = self.config
        transport = THttpClient.THttpClient(config.url)
        return self._get_client_by_transport(config, transport)

    @staticmethod
    def _to_bytes(value):
        """Coerce a JSON-decoded setting (text, bytes or scalar) to bytes.

        Text is UTF-8 encoded; other non-bytes values are converted to text
        first, so ``5`` becomes ``b"5"``.
        """
        if isinstance(value, six.binary_type):
            return value
        if not isinstance(value, six.text_type):
            value = six.text_type(value)
        return value.encode("utf-8")

    def _get_client_by_tier(self):
        """Get a client that uses ServiceRouter.

        Keys and values from ``config.conn_configs`` and
        ``config.service_options`` are converted to byte strings first.

        :raises TypeError: if a service option value is not a list.
        :raises NameError: if ServiceRouter returns no client for the tier.
        """
        config = self.config
        serviceRouter = ServiceRouter()

        overrides = ConnConfigs()
        for key, val in six.iteritems(config.conn_configs):
            # six.binary_type(text) is bytes(str) on Python 3, which raises
            # TypeError, so encode explicitly.
            overrides[self._to_bytes(key)] = self._to_bytes(val)

        sr_options = ServiceOptions()
        for key, val in six.iteritems(config.service_options):
            key = self._to_bytes(key)
            if not isinstance(val, list):
                raise TypeError(
                    "Service option %s expected list; got %s (%s)"
                    % (key, val, type(val))
                )
            sr_options[key] = [self._to_bytes(elem) for elem in val]

        service_name = config.tier

        # Obtain a normal client connection using SR2
        client = serviceRouter.getClient2(
            self.client_class, service_name, sr_options, overrides
        )

        if client is None:
            raise NameError("Failed to lookup host for tier %s" % service_name)

        return client

    def _get_client(self):
        """Create a client for ``config.connection_method``.

        :raises NameError: for an unknown connection method.
        """
        if self.config.connection_method == "host":
            client = self._get_client_by_host()
        elif self.config.connection_method == "url":
            client = self._get_client_by_url()
        elif self.config.connection_method == "tier":
            client = self._get_client_by_tier()
        else:
            raise NameError(
                "Unknown connection type: %s" % self.config.connection_method
            )
        return client

    def _close_client(self):
        """Close the transport of a host/url connection, if one was opened."""
        if self.config.connection_method not in {"host", "url"}:
            return
        transport = getattr(self, "_transport", None)
        if transport is not None:
            transport.close()

    def __enter__(self):
        self.client = self._get_client()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._close_client()
        self.client = None

    def reset(self):
        """Close and re-open the connection.

        :returns: True on success; False (after logging) if reconnecting
            raised a TTransportException.
        """
        self._close_client()
        try:
            self.client = self._get_client()
            return True
        except TTransport.TTransportException as e:
            logging.error("Unable to reset connection: %r" % e)
            return False

    def make_call(self, method_name, kwargs, is_oneway=False):
        """Call ``method_name`` on the client with ``kwargs`` and return its result.

        For oneway calls the connection is reset afterwards (presumably
        because no response is read for them).
        """
        method = getattr(self.client, method_name)
        ret = method(**kwargs)

        if is_oneway:
            self.reset()

        return ret


class Timer(object):
    """Context manager that reports the time spent in its body.

    On exit, whether or not the body raised, the elapsed seconds are passed
    to ``aggregator.add(category, action, seconds)``. Exceptions are not
    suppressed.
    """

    # Use a monotonic clock so elapsed times are not skewed by system clock
    # adjustments; fall back to time.time where it is unavailable.
    _clock = staticmethod(getattr(time, "monotonic", time.time))

    def __init__(self, aggregator, category, action):
        self.aggregator = aggregator
        self.category = category
        self.action = action
        self.start_time = None

    def __enter__(self):
        self.start_time = self._clock()
        # Returning self allows "with ... as timer" (previously bound None).
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        time_elapsed = self._clock() - self.start_time
        self.aggregator.add(self.category, self.action, time_elapsed)


class TimeAggregator(object):
    """Accumulates elapsed seconds per (category, action) pair."""

    def __init__(self):
        # category -> action -> total seconds
        self.total_time = collections.defaultdict(
            lambda: collections.defaultdict(float)
        )

    def time(self, category, action):
        """Return a Timer context manager that records into this aggregator."""
        return Timer(self, category, action)

    def add(self, category, action, time_elapsed):
        """Add ``time_elapsed`` seconds to the total for (category, action)."""
        self.total_time[category][action] += time_elapsed

    def summarize(self):
        """Log one line per category with the total time of each action.

        Categories and actions are sorted by name, and names are
        right-aligned to the longest one. If nothing was recorded, a short
        notice is logged instead.
        """
        # Skip categories without actions (e.g. created by a bare lookup on
        # the defaultdict) so the max() calls below never see an empty
        # sequence.
        categories = {
            name: actions for name, actions in self.total_time.items() if actions
        }
        if not categories:
            logging.info("Timing Summary: no timings recorded")
            return

        category_width = max(len(name) for name in categories)
        action_width = max(
            len(action_name)
            for actions in categories.values()
            for action_name in actions
        )
        category_format = "%%%ds: %%s" % category_width
        action_format = "%%%ds: %%4.3fs" % action_width

        category_summaries = []
        for category_name, actions in sorted(categories.items()):
            all_actions = " | ".join(
                action_format % (action_name, action_time)
                for action_name, action_time in sorted(actions.items())
            )
            category_summaries.append(category_format % (category_name, all_actions))
        logging.info("Timing Summary:\n%s", "\n".join(category_summaries))


class FuzzTester(object):
    summary_interval = 1  # Seconds between summary logs

    class Result:
        Success = 0
        TransportException = 1
        ApplicationException = 2
        UserDefinedException = 3
        OtherException = 4
        Crash = 5

    def __init__(self, config):
        """Store the fuzzer configuration; service, randomizer and client start as None."""
        self.config = config
        self.service = self.randomizer = self.client = None

    def start_logging(self):
        """Configure the root logger from ``config.logfile`` and ``config.loglevel``.

        ``logfile`` may be a file path, the string "stdout", or None. With
        None, ``logging.basicConfig`` uses its default stream (stderr). As
        with any ``basicConfig`` call, nothing changes if the root logger
        already has handlers.
        """
        logfile = self.config.logfile
        log_level = getattr(logging, self.config.loglevel)

        datefmt = "%Y-%m-%d %H:%M:%S"
        fmt = "[%(asctime)s] [%(levelname)s] %(message)s"

        # Pass the format to basicConfig directly instead of restyling
        # handlers[0] afterwards, which could be a pre-existing handler that
        # basicConfig did not create.
        if logfile == "stdout":
            logging.basicConfig(
                stream=sys.stdout, level=log_level, format=fmt, datefmt=datefmt
            )
        else:
            # filename=None selects basicConfig's default stream (stderr).
            logging.basicConfig(
                filename=logfile, level=log_level, format=fmt, datefmt=datefmt
            )

    def start_timing(self):
        """Create a fresh TimeAggregator and schedule the first summary log."""
        interval = type(self).summary_interval
        self.timer = TimeAggregator()
        self.next_summary_time = time.time() + interval

    def _call_string(self, method_name, kwargs):
        """Render a call as ``name(k1=v1, k2=v2)`` for log messages."""
        rendered_args = ["%s=%s" % pair for pair in six.iteritems(kwargs)]
        return "%s(%s)" % (method_name, ", ".join(rendered_args))

    def run_test(
        self, method_name, kwargs, expected_output, is_oneway, thrift_exceptions
    ):
        """
        Make an RPC with given arguments and check for exceptions.

        Every outcome is recorded with ``record_result`` and ``n_tests`` is
        incremented on every path. Transport errors, TApplicationExceptions
        (unless allowed by the configuration) and any other unexpected
        exception also increment ``n_exceptions`` and log the offending call.

        :param method_name: name of the client method to call.
        :param kwargs: keyword arguments for the call.
        :param expected_output: not used by this method.
        :param is_oneway: passed through to ``FuzzerClient.make_call``.
        :param thrift_exceptions: tuple of exception classes the method
            declares; these are recorded as UserDefinedException and are not
            treated as errors.
        :returns: False if the server appears to have crashed (connection
            refused, or the client could not reconnect), otherwise True.
        """
        try:
            with self.timer.time(method_name, "Thrift"):
                self.client.make_call(method_name, kwargs, is_oneway)
        except thrift_exceptions as e:
            self.record_result(method_name, FuzzTester.Result.UserDefinedException)
            if self.config.loglevel == "DEBUG":
                with self.timer.time(method_name, "Logging"):
                    logging.debug("Got thrift exception: %r" % e)
                    logging.debug(
                        "Exception thrown by call: %s"
                        % (self._call_string(method_name, kwargs))
                    )

        except Thrift.TApplicationException as e:
            self.record_result(method_name, FuzzTester.Result.ApplicationException)
            if self.config.allow_application_exceptions:
                if self.config.loglevel == "DEBUG":
                    with self.timer.time(method_name, "Logging"):
                        logging.debug("Got TApplication exception %s" % e)
                        logging.debug(
                            "Exception thrown by call: %s"
                            % (self._call_string(method_name, kwargs))
                        )
            else:
                with self.timer.time(method_name, "Logging"):
                    self.n_exceptions += 1
                    logging.error("Got application exception: %s" % e)
                    logging.error(
                        "Offending call: %s" % (self._call_string(method_name, kwargs))
                    )

        except TTransport.TTransportException as e:
            self.n_exceptions += 1

            with self.timer.time(method_name, "Logging"):
                logging.error("Got TTransportException: (%s, %r)" % (e, e))
                logging.error(
                    "Offending call: %s" % (self._call_string(method_name, kwargs))
                )

            # The exception may carry no args, or a non-string message such as
            # None; normalise before searching for the refused-connection text.
            message = e.args[0] if e.args else ""
            if "errno = 111: Connection refused" in str(message):
                # Unable to connect to server - server may be down
                self.record_result(method_name, FuzzTester.Result.Crash)
                return False

            if not self.client.reset():
                logging.error("Inferring server crash.")
                self.record_result(method_name, FuzzTester.Result.Crash)
                return False

            self.record_result(method_name, FuzzTester.Result.TransportException)

        except Exception as e:
            self.record_result(method_name, FuzzTester.Result.OtherException)
            with self.timer.time(method_name, "Logging"):
                self.n_exceptions += 1
                logging.error("Got exception %s (%r)" % (e, e))
                logging.error(
                    "Offending call: %s" % (self._call_string(method_name, kwargs))
                )
                if hasattr(self, "previous_kwargs"):
                    logging.error(
                        "Previous call: %s"
                        % (self._call_string(method_name, self.previous_kwargs))
                    )

        else:
            self.record_result(method_name, FuzzTester.Result.Success)
            if self.config.loglevel == "DEBUG":
                with self.timer.time(method_name, "Logging"):
                    logging.debug(
                        "Successful call: %s" % (self._call_string(method_name, kwargs))
                    )
        finally:
            self.n_tests += 1

        return True

    def fuzz_kwargs(self, method_name, n_iterations):
        """Yield up to ``n_iterations`` random keyword-argument dicts for a method.

        Each yielded dict is the ``__dict__`` of a freshly generated args
        struct (the struct's own dict, not a copy). When the randomizer
        returns None, an error is logged and that iteration yields nothing.
        """
        # For now this is purely random generation. Future versions could
        # fuzz fields more methodically based on feedback and seeds.
        for _ in sm.xrange(n_iterations):
            with self.timer.time(method_name, "Randomizing"):
                args_struct = self.method_randomizers[method_name].generate()
            if args_struct is not None:
                yield args_struct.__dict__  # Members of the args struct
                continue
            logging.error("Unable to produce valid arguments for %s" % method_name)

    def get_method_randomizers(self, methods, constraints):
        """Create a StructRandomizer for each method.

        All randomizers share one RandomizerState, onto which ``constraints``
        is pushed as type constraints. Each method's randomizer additionally
        receives ``constraints[method_name]`` (or ``{}``).

        :param methods: mapping of method name -> info dict with an
            ``"args_class"`` entry.
        :returns: dict of method name -> randomizer.
        """
        state = randomizer.RandomizerState()
        state.push_type_constraints(constraints)

        randomizers = {}
        for method_name, method_info in methods.items():
            args_class = method_info["args_class"]
            # Describe the args struct the way a STRUCT field spec does:
            # (class, thrift_spec, isUnion)
            struct_spec_args = (args_class, args_class.thrift_spec, False)
            randomizers[method_name] = state.get_randomizer(
                Thrift.TType.STRUCT,
                struct_spec_args,
                constraints.get(method_name, {}),
            )

        return randomizers

    def _split_key(self, key):
        """Split a constraint rule key such as a.b|c into ['a', 'b', '|c'].

        Dots separate hierarchical field names and property names and are
        dropped. Pipes ('|', type names) and hashes ('#', field names) start a
        new component and stay attached to it, though these rules are not yet
        supported.
        """
        separators = {".", "|", "#"}
        pieces = []
        begin = 0
        while True:
            # The character at `begin` never ends a component (it may be the
            # '|' or '#' that introduces it), so search from the next one.
            stop = next(
                (i for i in range(begin + 1, len(key)) if key[i] in separators),
                None,
            )
            if stop is None:
                pieces.append(key[begin:])
                return pieces
            pieces.append(key[begin:stop])
            begin = stop + 1 if key[stop] == "." else stop

    def preprocess_constraints(self, source_constraints):
        """
        Normalize a constraints dictionary into a fully nested form.

        Keys may be dotted paths of the form:

        method_name[.arg_name][.field_name ...].property_name

        Values may themselves be dictionaries. Their keys are then interpreted
        relative to the enclosing key: inner field names are subfields of the
        outer scope, and inner rules apply only within it.

        Every key is split with ``_split_key`` and expanded into nested
        dictionaries, so that after preprocessing each key at every level is a
        single method name, field name, or property name. For example, the
        constraint set:

        {'my_method.my_field.distribution': 'uniform(0,100)'}

        is preprocessed to:

        {'my_method':
          {'my_field':
             {'distribution': 'uniform(0,100)'}
          }
        }

        Non-dict values are stored unchanged at the leaf; a later rule for
        the same path replaces an earlier one.
        """
        constraints = {}

        def merge(rules, prefix):
            for raw_key, rule in six.iteritems(rules):
                path = prefix + self._split_key(raw_key)
                if isinstance(rule, dict):
                    merge(rule, path)
                    continue
                node = constraints
                for part in path[:-1]:
                    if part not in node:
                        node[part] = {}
                    node = node[part]
                node[path[-1]] = rule

        merge(source_constraints, [])
        return constraints

    def start_result_counters(self):
        """Create result counters. The counters object is a dict that maps
        a method name to a counter of FuzzTest.Results
        """
        self.result_counters = collections.defaultdict(collections.Counter)

    def record_result(self, method_name, result):
        self.result_counters[method_name][result] += 1

    def log_result_summary(self, method_name):
        if time.time() >= self.next_summary_time:
            results = []
            for name, val in six.iteritems(vars(FuzzTester.Result)):
                if name.startswith("_"):
                    continue
                count = self.result_counters[method_name][val]
                if count > 0:
                    results.append((name, count))
            results.sort()
            logging.info(
                "%s count: {%s}"
                % (method_name, ", ".join("%s: %d" % r for r in results))
            )

            interval = self.__class__.summary_interval
            # Determine how many full intervals have passed between
            # self.next_summary_time (the scheduled time for this summary) and
            # the time the summary is actually completed.
            intervals_passed = int((time.time() - self.next_summary_time) / interval)
            # Schedule the next summary for the first interval that has not yet
            # fully passed
            self.next_summary_time += interval * (intervals_passed + 1)

    def run(self):
        """Fuzz every selected service method, logging a summary per method.

        Sets up logging, timing and result counters. It then loads the
        service from ``self.config`` and builds the per-method randomizers
        from the preprocessed constraints. Inside a ``FuzzerClient`` context
        (bound to ``self.client``), it feeds each method the kwargs produced
        by ``self.fuzz_kwargs`` through ``self.run_test``. A falsy result
        from ``run_test`` is reported as a server crash and stops all
        further testing. Timing is summarized at the end.
        """
        self.start_logging()
        self.start_timing()
        self.start_result_counters()

        logging.info("Starting Fuzz Tester")
        logging.info(str(self.config))

        self.service = self.config.load_service()
        client_class = self.service.client_class

        methods = self.service.get_methods(self.config.functions)
        self.method_randomizers = self.get_method_randomizers(
            methods, self.preprocess_constraints(self.config.constraints)
        )

        logging.info("Fuzzing methods: %s" % methods.keys())

        with FuzzerClient(self.config, client_class) as self.client:
            for method_name, spec in six.iteritems(methods):
                # A method without a result spec is treated as oneway.
                is_oneway = spec.get("result_spec", None) is None
                thrift_exceptions = spec["thrift_exceptions"]

                logging.info("Fuzz testing method %s" % (method_name))
                self.n_tests = 0
                self.n_exceptions = 0

                crashed = False
                iterations = self.config.n_iterations
                for kwargs in self.fuzz_kwargs(method_name, iterations):
                    crashed = not self.run_test(
                        method_name, kwargs, None, is_oneway, thrift_exceptions
                    )
                    if crashed:
                        break
                    self.log_result_summary(method_name)
                    self.previous_kwargs = kwargs

                if crashed:
                    logging.error(
                        "Method %s caused the server to crash." % (method_name)
                    )
                    break

                logging.info(
                    "Method %s raised unexpected exceptions in %d/%d tests."
                    % (method_name, self.n_exceptions, self.n_tests)
                )

        self.timer.summarize()


def run_fuzzer(config) -> None:
    """Build a ``FuzzTester`` for ``config`` and run it to completion."""
    FuzzTester(config).run()


def fuzz_service(service, ttypes, constants) -> None:
    """Run the tester with required modules input programmatically.

    Wraps the supplied objects in a ``Service``, builds a
    ``FuzzerConfiguration`` from it, and runs the fuzzer.

    Args:
        service: the raw service object handed to ``Service(...)`` as its
            third argument. Presumably this is the generated Thrift service
            module (TODO confirm against ``Service.__init__``). It is the
            input to ``Service``, not a ``Service`` instance, so it is left
            unannotated rather than annotated as ``Service``.
        ttypes: passed to ``Service`` as its first argument (presumably the
            generated ``ttypes`` module).
        constants: passed to ``Service`` as its second argument (presumably
            the generated ``constants`` module).
    """
    # Use a distinct name so the raw input and its wrapper are not confused.
    wrapped_service = Service(ttypes, constants, service)
    config = FuzzerConfiguration(wrapped_service)
    run_fuzzer(config)


if __name__ == "__main__":
    # Command-line entry point. Unlike fuzz_service, no service object is
    # passed to FuzzerConfiguration here; it presumably gathers its settings
    # (including which service to load) from the command line -- confirm in
    # FuzzerConfiguration.
    config = FuzzerConfiguration()
    run_fuzzer(config)
