# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
import logging
import os
import re
import signal
import subprocess
import time
from contextlib import contextmanager

import boto3

import tfs_utils

logging.basicConfig(
    level=logging.INFO,
    format="%(process)d %(asctime)s %(levelname)-8s %(message)s",
    force=True,
)
log = logging.getLogger(__name__)

# nginx directives for the /ping and /invocations locations: either the nginx
# JavaScript handlers or a proxy to the gunicorn upstream (the choice is made in
# ServiceManager._create_nginx_config).
JS_PING = "js_content tensorflowServing.ping"
JS_INVOCATIONS = "js_content tensorflowServing.invocations"
GUNICORN_PING = "proxy_pass http://gunicorn_upstream/ping"
GUNICORN_INVOCATIONS = "proxy_pass http://gunicorn_upstream/invocations"

# Directory holding user-supplied inference code; it lives outside the model
# directory when the container runs as a multi-model endpoint.
if os.environ.get("SAGEMAKER_MULTI_MODEL", "False").lower() == "true":
    CODE_DIR = "/opt/ml/code"
else:
    CODE_DIR = "/opt/ml/model/code"
PYTHON_LIB_PATH = os.path.join(CODE_DIR, "lib")
REQUIREMENTS_PATH = os.path.join(CODE_DIR, "requirements.txt")
INFERENCE_PATH = os.path.join(CODE_DIR, "inference.py")


class ServiceManager(object):
    """Launch and supervise the serving processes of a SageMaker TFS container.

    Configuration comes from environment variables. ``start`` launches one or
    more TensorFlow Serving (TFS) processes (deferred in multi-model mode), an
    optional gunicorn-hosted python service and an nginx front end, then
    restarts any of them that exit while the manager is in the "started" state.
    """

    def __init__(self):
        """Read configuration from the environment and choose TFS ports.

        Side effects: sets ``OMP_NUM_THREADS`` to "1" if it is unset, and sets
        ``TFS_GRPC_PORTS`` / ``TFS_REST_PORTS`` in ``os.environ``.

        Raises:
            ValueError: if SAGEMAKER_MULTI_MODEL or SAGEMAKER_TFS_ENABLE_BATCHING
                is not 'true'/'false' (case-insensitive), if
                SAGEMAKER_TFS_INSTANCE_COUNT is below 1, or if
                SAGEMAKER_SAFE_PORT_RANGE is too small for the instance count.
        """
        self._state = "initializing"
        self._nginx = None
        self._tfs = []
        self._gunicorn = None
        self._gunicorn_command = None
        self._gunicorn_env = None
        self._enable_python_service = False
        self._tfs_version = os.environ.get("SAGEMAKER_TFS_VERSION", "1.13")
        self._nginx_http_port = os.environ.get("SAGEMAKER_BIND_TO_PORT", "8080")
        self._nginx_loglevel = os.environ.get("SAGEMAKER_TFS_NGINX_LOGLEVEL", "error")
        # The string "None" is a sentinel: _create_tfs_config then picks the first model found.
        self._tfs_default_model_name = os.environ.get("SAGEMAKER_TFS_DEFAULT_MODEL_NAME", "None")
        self._sagemaker_port_range = os.environ.get("SAGEMAKER_SAFE_PORT_RANGE", None)
        self._gunicorn_workers = os.environ.get("SAGEMAKER_GUNICORN_WORKERS", 1)
        self._gunicorn_threads = os.environ.get("SAGEMAKER_GUNICORN_THREADS", 1)
        self._gunicorn_loglevel = os.environ.get("SAGEMAKER_GUNICORN_LOGLEVEL", "info")
        self._tfs_config_path = "/sagemaker/model-config.cfg"
        self._tfs_batching_config_path = "/sagemaker/batching-config.cfg"

        _enable_batching = os.environ.get("SAGEMAKER_TFS_ENABLE_BATCHING", "false").lower()
        _enable_multi_model_endpoint = os.environ.get("SAGEMAKER_MULTI_MODEL", "false").lower()
        # Use this to specify memory that is needed to initialize CUDA/cuDNN and other GPU libraries
        self._tfs_gpu_margin = float(os.environ.get("SAGEMAKER_TFS_FRACTIONAL_GPU_MEM_MARGIN", 0.2))
        self._tfs_instance_count = int(os.environ.get("SAGEMAKER_TFS_INSTANCE_COUNT", 1))
        # Validate before the count is used as a divisor (the wait-time default
        # below is evaluated eagerly, and so is the per-process GPU fraction).
        if self._tfs_instance_count < 1:
            raise ValueError(
                "SAGEMAKER_TFS_INSTANCE_COUNT must be at least 1 (got {})".format(
                    self._tfs_instance_count
                )
            )
        self._tfs_wait_time_seconds = int(
            os.environ.get("SAGEMAKER_TFS_WAIT_TIME_SECONDS", 55 // self._tfs_instance_count)
        )
        self._tfs_inter_op_parallelism = os.environ.get("SAGEMAKER_TFS_INTER_OP_PARALLELISM", 0)
        self._tfs_intra_op_parallelism = os.environ.get("SAGEMAKER_TFS_INTRA_OP_PARALLELISM", 0)
        self._gunicorn_worker_class = os.environ.get("SAGEMAKER_GUNICORN_WORKER_CLASS", "gevent")
        self._gunicorn_timeout_seconds = int(
            os.environ.get("SAGEMAKER_GUNICORN_TIMEOUT_SECONDS", 30)
        )
        self._nginx_proxy_read_timeout_seconds = int(
            os.environ.get("SAGEMAKER_NGINX_PROXY_READ_TIMEOUT_SECONDS", 60)
        )

        # Nginx proxy read timeout should not be less than the GUnicorn timeout. If it is, this
        # can result in upstream time out errors.
        if self._gunicorn_timeout_seconds > self._nginx_proxy_read_timeout_seconds:
            log.info(
                "GUnicorn timeout was higher than Nginx proxy read timeout."
                " Setting Nginx proxy read timeout from {} seconds to {} seconds"
                " to match GUnicorn timeout.".format(
                    self._nginx_proxy_read_timeout_seconds, self._gunicorn_timeout_seconds
                )
            )
            self._nginx_proxy_read_timeout_seconds = self._gunicorn_timeout_seconds

        os.environ.setdefault("OMP_NUM_THREADS", "1")

        if _enable_multi_model_endpoint not in ["true", "false"]:
            raise ValueError("SAGEMAKER_MULTI_MODEL must be 'true' or 'false'")
        self._tfs_enable_multi_model_endpoint = _enable_multi_model_endpoint == "true"

        self._need_python_service()
        log.info("PYTHON SERVICE: {}".format(str(self._enable_python_service)))

        if _enable_batching not in ["true", "false"]:
            raise ValueError("SAGEMAKER_TFS_ENABLE_BATCHING must be 'true' or 'false'")
        self._tfs_enable_batching = _enable_batching == "true"

        self._use_gunicorn = self._enable_python_service or self._tfs_enable_multi_model_endpoint

        if self._sagemaker_port_range is not None:
            parts = self._sagemaker_port_range.split("-")
            low = int(parts[0])
            hi = int(parts[1])
            if low + 2 * self._tfs_instance_count > hi:
                raise ValueError(
                    "not enough ports available in SAGEMAKER_SAFE_PORT_RANGE ({})".format(
                        self._sagemaker_port_range
                    )
                )
            # select non-overlapping grpc and rest ports based on tfs instance count:
            # instance i gets grpc port low + 2*i and rest port low + 2*i + 1
            instances = range(self._tfs_instance_count)
            self._tfs_grpc_ports = [str(low + 2 * i) for i in instances]
            self._tfs_rest_ports = [str(low + 2 * i + 1) for i in instances]
            # concat selected ports respectively in order to pass them to python service
            self._tfs_grpc_concat_ports = self._concat_ports(self._tfs_grpc_ports)
            self._tfs_rest_concat_ports = self._concat_ports(self._tfs_rest_ports)
        else:
            # just use the standard default ports
            self._tfs_grpc_ports = ["9000"]
            self._tfs_rest_ports = ["8501"]
            # provide single concat port here for default case
            self._tfs_grpc_concat_ports = "9000"
            self._tfs_rest_concat_ports = "8501"

        # set environment variable for python service
        os.environ["TFS_GRPC_PORTS"] = self._tfs_grpc_concat_ports
        os.environ["TFS_REST_PORTS"] = self._tfs_rest_concat_ports

    def _need_python_service(self):
        """Turn on the python service if user code exists or universal scripts are configured.

        Only ever sets ``self._enable_python_service`` to True, never back to False.
        """
        has_user_code = (
            os.path.exists(INFERENCE_PATH)
            or os.path.exists(REQUIREMENTS_PATH)
            or os.path.exists(PYTHON_LIB_PATH)
        )
        has_universal_scripts = os.environ.get(
            "SAGEMAKER_MULTI_MODEL_UNIVERSAL_BUCKET"
        ) and os.environ.get("SAGEMAKER_MULTI_MODEL_UNIVERSAL_PREFIX")
        if has_user_code or has_universal_scripts:
            self._enable_python_service = True

    def _concat_ports(self, ports):
        """Return ``ports`` stringified and joined with commas, e.g. "9000,9002"."""
        return ",".join(str(port) for port in ports)

    def _create_tfs_config(self):
        """Write a TFS model config listing every discovered model and its versions.

        Also picks a default model name from the first model when none was
        configured.

        Raises:
            ValueError: if ``tfs_utils.find_models`` returns no models.
        """
        models = tfs_utils.find_models()

        if not models:
            raise ValueError("no SavedModel bundles found!")

        if self._tfs_default_model_name == "None":
            default_model = os.path.basename(models[0])
            if default_model:
                self._tfs_default_model_name = default_model
                log.info("using default model name: {}".format(self._tfs_default_model_name))
            else:
                log.info("no default model detected")

        # config (may) include duplicate 'config' keys, so we can't just dump a dict
        lines = ["model_config_list: {"]
        for m in models:
            lines.append("  config: {")
            lines.append("    name: '{}'".format(os.path.basename(m)))
            lines.append("    base_path: '{}'".format(m))
            lines.append("    model_platform: 'tensorflow'")
            lines.append("    model_version_policy: {")
            lines.append("      specific: {")
            for version in tfs_utils.find_model_versions(m):
                lines.append("        versions: {}".format(version))
            lines.append("      }")
            lines.append("    }")
            lines.append("  }")
        lines.append("}")
        config = "\n".join(lines) + "\n"

        log.info("tensorflow serving model config: \n%s\n", config)

        with open(self._tfs_config_path, "w", encoding="utf8") as f:
            f.write(config)

    def _setup_gunicorn(self):
        """Build the gunicorn command line and environment for the python service.

        If CODE_DIR does not exist and both SAGEMAKER_MULTI_MODEL_UNIVERSAL_BUCKET
        and SAGEMAKER_MULTI_MODEL_UNIVERSAL_PREFIX are set, the scripts are first
        downloaded from S3. If requirements.txt exists and no ``lib`` directory
        does, the requirements are installed with pip3.

        Raises:
            ChildProcessError: if ``pip3 install`` fails; services are stopped first.
        """
        python_path_content = []

        bucket = os.environ.get("SAGEMAKER_MULTI_MODEL_UNIVERSAL_BUCKET", None)
        prefix = os.environ.get("SAGEMAKER_MULTI_MODEL_UNIVERSAL_PREFIX", None)

        if not os.path.exists(CODE_DIR) and bucket and prefix:
            self._download_scripts(bucket, prefix)

        if self._enable_python_service:
            lib_path_exists = os.path.exists(PYTHON_LIB_PATH)
            requirements_exists = os.path.exists(REQUIREMENTS_PATH)
            # The directory holding inference.py goes on PYTHONPATH. Use CODE_DIR
            # rather than a hard-coded path: it is /opt/ml/code in multi-model mode.
            python_path_content = [CODE_DIR]

            if lib_path_exists:
                python_path_content.append(PYTHON_LIB_PATH)

            if requirements_exists:
                if lib_path_exists:
                    log.warning(
                        "loading modules in '{}', ignoring requirements.txt".format(PYTHON_LIB_PATH)
                    )
                else:
                    log.info("installing packages from requirements.txt...")
                    try:
                        subprocess.check_call(["pip3", "install", "-r", REQUIREMENTS_PATH])
                    except subprocess.CalledProcessError:
                        log.error("failed to install required packages, exiting.")
                        self._stop()
                        raise ChildProcessError("failed to install required packages.")

        gunicorn_command = (
            "python3 /sagemaker/python_service.py -b unix:/tmp/gunicorn.sock -k {} --chdir /sagemaker "
            "--workers {} --threads {} --log-level {} --timeout {} "
        ).format(
            self._gunicorn_worker_class,
            self._gunicorn_workers,
            self._gunicorn_threads,
            self._gunicorn_loglevel,
            self._gunicorn_timeout_seconds,
        )

        log.info("gunicorn command: {}".format(gunicorn_command))
        self._gunicorn_command = gunicorn_command
        gunicorn_env = {
            "TFS_GRPC_PORTS": self._tfs_grpc_concat_ports,
            "TFS_REST_PORTS": self._tfs_rest_concat_ports,
            "SAGEMAKER_MULTI_MODEL": str(self._tfs_enable_multi_model_endpoint),
            "SAGEMAKER_TFS_WAIT_TIME_SECONDS": str(self._tfs_wait_time_seconds),
            "SAGEMAKER_TFS_INTER_OP_PARALLELISM": str(self._tfs_inter_op_parallelism),
            "SAGEMAKER_TFS_INTRA_OP_PARALLELISM": str(self._tfs_intra_op_parallelism),
            "SAGEMAKER_TFS_INSTANCE_COUNT": str(self._tfs_instance_count),
            "PYTHONPATH": ":".join(python_path_content),
            "SAGEMAKER_GUNICORN_WORKERS": str(self._gunicorn_workers),
        }
        if self._sagemaker_port_range is not None:
            gunicorn_env["SAGEMAKER_SAFE_PORT_RANGE"] = self._sagemaker_port_range
        log.info(f"gunicorn env: {gunicorn_env}")
        self._gunicorn_env = gunicorn_env

    def _download_scripts(self, bucket, prefix):
        """Download the objects directly under ``bucket``/``prefix`` into CODE_DIR.

        Only objects at the first level below ``prefix`` are listed (Delimiter="/");
        each is saved under its last key component.

        Raises:
            ValueError: in the us-iso-east-1 and us-gov-west-1 regions.
        """
        log.info("checking boto session region ...")
        boto_session = boto3.session.Session()
        boto_region = boto_session.region_name
        if boto_region in ("us-iso-east-1", "us-gov-west-1"):
            raise ValueError("Universal scripts is not supported in us-iso-east-1 or us-gov-west-1")

        log.info("downloading universal scripts ...")
        client = boto3.client("s3")
        resource = boto3.resource("s3")
        # download files
        paginator = client.get_paginator("list_objects")
        for result in paginator.paginate(Bucket=bucket, Delimiter="/", Prefix=prefix):
            for obj in result.get("Contents", []):
                key = obj.get("Key")
                file_name = key.split("/")[-1]
                if not file_name:
                    # Keys ending in "/" (e.g. folder markers) have no file name;
                    # joining "" would make the destination CODE_DIR itself.
                    continue
                destination = os.path.join(CODE_DIR, file_name)
                os.makedirs(os.path.dirname(destination), exist_ok=True)
                resource.meta.client.download_file(bucket, key, destination)

    def _create_nginx_tfs_upstream(self):
        """Return the nginx upstream body listing one ``server`` per TFS REST port.

        The trailing ";" of the last entry is omitted, so the template is
        presumably expected to supply it.
        """
        separator = ";\n    "
        return separator.join("server localhost:{}".format(port) for port in self._tfs_rest_ports)

    def _create_nginx_config(self):
        """Render /sagemaker/nginx.conf from the template by replacing %NAME% placeholders."""
        template = self._read_nginx_template()
        pattern = re.compile(r"%(\w+)%")

        template_values = {
            "TFS_VERSION": self._tfs_version,
            "TFS_UPSTREAM": self._create_nginx_tfs_upstream(),
            "TFS_DEFAULT_MODEL_NAME": self._tfs_default_model_name,
            "NGINX_HTTP_PORT": self._nginx_http_port,
            "NGINX_LOG_LEVEL": self._nginx_loglevel,
            "FORWARD_PING_REQUESTS": GUNICORN_PING if self._use_gunicorn else JS_PING,
            "FORWARD_INVOCATION_REQUESTS": GUNICORN_INVOCATIONS
            if self._use_gunicorn
            else JS_INVOCATIONS,
            "PROXY_READ_TIMEOUT": str(self._nginx_proxy_read_timeout_seconds),
        }

        # An unknown placeholder in the template raises KeyError here.
        config = pattern.sub(lambda x: template_values[x.group(1)], template)
        log.info("nginx config: \n%s\n", config)

        with open("/sagemaker/nginx.conf", "w", encoding="utf8") as f:
            f.write(config)

    def _read_nginx_template(self):
        """Return the contents of /sagemaker/nginx.conf.template.

        Raises:
            ValueError: if the template file is empty.
        """
        with open("/sagemaker/nginx.conf.template", "r", encoding="utf8") as f:
            template = f.read()
        if not template:
            raise ValueError("failed to read nginx.conf.template")
        return template

    def _enable_per_process_gpu_memory_fraction(self):
        """True when several TFS instances run on a host with nvidia-smi installed."""
        return self._tfs_instance_count > 1 and os.path.exists("/usr/bin/nvidia-smi")

    def _get_number_of_gpu_on_host(self):
        """Return the number of lines printed by ``nvidia-smi -L``, or 0 without nvidia-smi.

        Presumably one line per GPU -- TODO confirm for all driver versions.
        """
        if os.path.exists("/usr/bin/nvidia-smi"):
            return len(
                subprocess.check_output(["nvidia-smi", "-L"]).decode("utf-8").strip().split("\n")
            )
        return 0

    def _calculate_per_process_gpu_memory_fraction(self):
        """Split (1 - margin) of GPU memory evenly across TFS instances, rounded to 4 places."""
        return round((1 - self._tfs_gpu_margin) / float(self._tfs_instance_count), 4)

    def _start_tfs(self):
        """Start ``_tfs_instance_count`` TFS processes and record them in ``self._tfs``."""
        self._log_version("tensorflow_model_server --version", "tensorflow version info:")

        for i in range(self._tfs_instance_count):
            p = self._start_single_tfs(i)
            self._tfs.append(p)

    def _start_gunicorn(self):
        """Start gunicorn with the environment prepared by ``_setup_gunicorn``."""
        self._log_version("gunicorn --version", "gunicorn version info:")
        env = os.environ.copy()
        env["TFS_DEFAULT_MODEL_NAME"] = self._tfs_default_model_name
        env.update(self._gunicorn_env)
        p = subprocess.Popen(self._gunicorn_command.split(), env=env)
        log.info("started gunicorn (pid: %d)", p.pid)
        self._gunicorn = p

    def _start_nginx(self):
        """Start nginx with the rendered /sagemaker/nginx.conf."""
        self._log_version("/usr/sbin/nginx -V", "nginx version info:")
        p = subprocess.Popen("/usr/sbin/nginx -c /sagemaker/nginx.conf".split())
        log.info("started nginx (pid: %d)", p.pid)
        self._nginx = p

    def _log_version(self, command, message):
        """Log the combined stdout/stderr of ``command`` under ``message``.

        Best effort: a failing or missing executable only produces a warning.
        """
        try:
            output = (
                subprocess.check_output(command.split(), stderr=subprocess.STDOUT)
                .decode("utf-8", "backslashreplace")
                .strip()
            )
            log.info("{}\n{}".format(message, output))
        except (subprocess.CalledProcessError, OSError):
            log.warning("failed to run command: %s", command)

    @staticmethod
    def _kill_quietly(pid, sig):
        """Send ``sig`` to ``pid``, ignoring OSError (e.g. the process is already gone)."""
        try:
            os.kill(pid, sig)
        except OSError:
            pass

    def _stop(self, *args):  # pylint: disable=W0613
        """Signal every started service process to shut down.

        Also installed as the SIGTERM handler, hence the ignored ``*args``.
        Safe to call before nginx or gunicorn have been started, and each kill
        is attempted independently so one failure does not skip the rest.
        """
        self._state = "stopping"
        log.info("stopping services")
        if self._nginx is not None:
            self._kill_quietly(self._nginx.pid, signal.SIGQUIT)
        if self._gunicorn is not None:
            self._kill_quietly(self._gunicorn.pid, signal.SIGTERM)
        for tfs in self._tfs:
            self._kill_quietly(tfs.pid, signal.SIGTERM)

        self._state = "stopped"
        log.info("stopped")

    def _wait_for_gunicorn(self):
        """Block until gunicorn has created its unix socket.

        Has no timeout of its own; ``start`` bounds it with ``_timeout``. Sleeps
        briefly between checks instead of spinning.
        """
        while not os.path.exists("/tmp/gunicorn.sock"):
            time.sleep(0.1)
        log.info("gunicorn server is ready!")

    def _wait_for_tfs(self):
        """Wait, via ``tfs_utils.wait_for_model``, for the default model on every TFS instance."""
        for i in range(self._tfs_instance_count):
            tfs_utils.wait_for_model(
                self._tfs_rest_ports[i], self._tfs_default_model_name, self._tfs_wait_time_seconds
            )

    @contextmanager
    def _timeout(self, seconds):
        """Raise TimeoutError inside the ``with`` body after ``seconds`` seconds.

        Implemented with SIGALRM; the alarm is cancelled and the previously
        installed SIGALRM handler restored when the body exits.
        """

        def _raise_timeout_error(signum, frame):
            raise TimeoutError("time out after {} seconds".format(seconds))

        previous_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
        try:
            signal.alarm(seconds)
            yield
        finally:
            signal.alarm(0)
            # getsignal/signal return None for handlers not installed from Python,
            # and None cannot be passed back to signal.signal.
            if previous_handler is not None:
                signal.signal(signal.SIGALRM, previous_handler)

    def _is_tfs_process(self, pid):
        """Return True if ``pid`` belongs to one of the managed TFS processes."""
        return self._find_tfs_process(pid) is not None

    def _find_tfs_process(self, pid):
        """Return the instance index of the TFS process with ``pid``, or None."""
        for index, p in enumerate(self._tfs):
            if p.pid == pid:
                return index
        return None

    def _restart_single_tfs(self, pid):
        """Restart the TFS instance whose process had ``pid``, keeping its ports.

        Raises:
            ValueError: if ``pid`` is not a managed TFS process.
        """
        instance_id = self._find_tfs_process(pid)
        if instance_id is None:
            raise ValueError("Cannot find tfs with pid: {};".format(pid))
        p = self._start_single_tfs(instance_id)
        self._tfs[instance_id] = p

    def _start_single_tfs(self, instance_id):
        """Start TFS instance ``instance_id`` on its grpc/rest ports and return the Popen.

        With more than one GPU, instances are assigned GPUs round-robin via
        CUDA_VISIBLE_DEVICES.
        """
        cmd = tfs_utils.tfs_command(
            self._tfs_grpc_ports[instance_id],
            self._tfs_rest_ports[instance_id],
            self._tfs_config_path,
            self._tfs_enable_batching,
            self._tfs_batching_config_path,
            tfs_intra_op_parallelism=self._tfs_intra_op_parallelism,
            tfs_inter_op_parallelism=self._tfs_inter_op_parallelism,
            tfs_enable_gpu_memory_fraction=self._enable_per_process_gpu_memory_fraction(),
            tfs_gpu_memory_fraction=self._calculate_per_process_gpu_memory_fraction(),
        )
        log.info("tensorflow serving command: {}".format(cmd))

        num_gpus = self._get_number_of_gpu_on_host()
        if num_gpus > 1:
            # utilizing multi-gpu
            gpu_index = instance_id % num_gpus
            worker_env = os.environ.copy()
            worker_env["CUDA_VISIBLE_DEVICES"] = str(gpu_index)
            p = subprocess.Popen(cmd.split(), env=worker_env)
            log.info(
                "started tensorflow serving (pid: {}) on GPU: {}".format(p.pid, gpu_index)
            )
        else:
            # cpu and single gpu
            p = subprocess.Popen(cmd.split())
            log.info("started tensorflow serving (pid: {})".format(p.pid))

        return p

    def _monitor(self):
        """Reap child processes and restart nginx, TFS or gunicorn when they exit.

        Returns once a child exits while the state is no longer "started"
        (e.g. after ``_stop`` ran from the SIGTERM handler).
        """
        while True:
            pid, status = os.wait()

            if self._state != "started":
                break

            if pid == self._nginx.pid:
                log.warning("unexpected nginx exit (status: {}). restarting.".format(status))
                self._start_nginx()

            elif self._is_tfs_process(pid):
                log.warning(
                    "unexpected tensorflow serving exit (status: {}). restarting.".format(status)
                )
                try:
                    self._restart_single_tfs(pid)
                except (ValueError, OSError) as error:
                    log.error("Failed to restart tensorflow serving. {}".format(error))

            elif self._gunicorn and pid == self._gunicorn.pid:
                log.warning("unexpected gunicorn exit (status: {}). restarting.".format(status))
                self._start_gunicorn()

    def start(self):
        """Start all services, supervise them, and stop them when supervision ends.

        Blocks in ``_monitor`` until shutdown.
        """
        log.info("starting services")
        self._state = "starting"
        signal.signal(signal.SIGTERM, self._stop)

        if self._tfs_enable_batching:
            log.info("batching is enabled")
            tfs_utils.create_batching_config(self._tfs_batching_config_path)

        if self._tfs_enable_multi_model_endpoint:
            log.info("multi-model endpoint is enabled, TFS model servers will be started later")
        else:
            self._create_tfs_config()
            self._start_tfs()
            self._wait_for_tfs()

        self._create_nginx_config()

        if self._use_gunicorn:
            self._setup_gunicorn()
            self._start_gunicorn()
            # make sure gunicorn is up
            with self._timeout(seconds=self._gunicorn_timeout_seconds):
                self._wait_for_gunicorn()

        self._start_nginx()
        # A SIGTERM during startup already ran _stop and moved the state off
        # "starting"; entering _monitor would then restart the killed processes.
        if self._state == "starting":
            self._state = "started"
            self._monitor()
        self._stop()


if __name__ == "__main__":
    # Entry point: build the manager from the environment and block until shutdown.
    manager = ServiceManager()
    manager.start()
