import os
import shlex
import signal
import socket
import stat
import subprocess
import sys
import textwrap
import time
from contextlib import contextmanager

import sagemaker_containers
from retrying import retry
from sagemaker_containers import _logging
from sagemaker_containers.beta import framework

# Module-level logger; handed to framework.logging.log_script_invocation by MPIMaster.
logger = _logging.get_logger()

# MPI files.
# Shell script written by _create_mpi_script and passed to mpirun as the program to execute.
_MPI_SCRIPT = "/mpi_script.sh"
# Marker file the MPI script touches before it runs the training command; workers poll for it.
_MPI_IS_RUNNING = "/mpi_is_running"
# Marker file the MPI script touches after the training command exits; workers poll for it.
_MPI_IS_FINISHED = "/mpi_is_finished"
# Shared library exported to the MPI processes through LD_PRELOAD in the mpirun command line.
_CHANGE_HOSTNAME_LIBRARY = "/libchangehostname.so"


def _change_hostname(current_host):
    """Run the ``/change-hostname.sh`` helper script for the given host name.

    The script is expected to compile a shared library that corrects the behavior of the
    gethostname system call, which OpenMPI depends on. The script's exit status is ignored.

    Args:
        current_host (str): name of the current host, such as algo-1, algo-2, etc.
    """
    command = "/change-hostname.sh {}".format(current_host)
    os.system(command)


def _start_ssh_daemon():
    """Launch the ssh daemon (``/usr/sbin/sshd -D``) as a child process without waiting for it."""
    sshd_command = ["/usr/sbin/sshd", "-D"]
    subprocess.Popen(sshd_command)


def _setup_mpi_environment(env):
    """Prepare this host for MPI: run the change-hostname script, then start the ssh daemon.

    Args:
        env (TrainingEnv): an instance of the training environment; only ``current_host`` is read.
    """
    host_name = env.current_host
    _change_hostname(host_name)
    _start_ssh_daemon()


def _can_connect(host, port, s):
    """Check whether a connection to the provided ``host`` and ``port`` can be established.

    The socket is closed before returning in every case, so a failed probe does not leak
    a file descriptor.

    Args:
        host (str): host name or address to connect to.
        port (int): port to connect to.
        s (socket.socket): unconnected socket used for the attempt; closed by this function.

    Returns:
        bool: True if the connection succeeded, False if ``socket.error`` was raised.
    """
    print("Testing connection to host {}".format(host))
    try:
        s.connect((host, port))
    except socket.error:
        print("Can't connect to host {}".format(host))
        return False
    finally:
        # Close on the failure path as well; callers create a fresh socket for every probe.
        s.close()

    print("Can connect to host {}".format(host))
    return True


def _create_mpi_script(env, train_script, train_script_args):
    """Creates the executable MPI script '/mpi_script.sh' from user provided information.

    For distributed training the 'master node' runs mpirun with this script.

    The script touches '/mpi_is_running' before it runs the training command and
    '/mpi_is_finished' after the command exits; worker nodes poll for these two files to
    learn when training (started by MPI from the master node) has begun and when to exit.
    The script exits with the exit code of the training command.

    Args:
        env (TrainingEnv): an instance of the training environment. Its ``hyperparameters``
            and ``channel_input_dirs`` are appended to the command line as arguments.
        train_script (str): path of the training script, executed with ``sys.executable``.
        train_script_args (list): extra arguments placed right after the script path.
            ``None`` is treated as "no extra arguments".
    """
    hyperparameters = framework.mapping.to_cmd_args(env.hyperparameters)
    channels = framework.mapping.to_cmd_args(env.channel_input_dirs)

    python_cmd = [sys.executable, train_script]
    # MPILauncher defaults train_script_args to None; list.extend(None) would raise TypeError.
    python_cmd.extend(train_script_args or [])
    python_cmd.extend(hyperparameters)
    python_cmd.extend(channels)

    # NOTE(review): the arguments are joined without shell quoting, so values containing
    # whitespace or shell metacharacters are re-interpreted by bash - confirm this is intended.
    script_lines = [
        "#!/usr/bin/env bash",
        "touch {}".format(_MPI_IS_RUNNING),
        " ".join(python_cmd),
        # Capture the training exit code before `touch` overwrites $?.
        "EXIT_CODE=$?",
        "touch {}".format(_MPI_IS_FINISHED),
        "exit ${EXIT_CODE}",
        "",
    ]

    with open(_MPI_SCRIPT, "w") as w:
        w.write("\n".join(script_lines))

    # Add the owner-execute bit on top of the existing mode so mpirun can run the script.
    st = os.stat(_MPI_SCRIPT)
    os.chmod(_MPI_SCRIPT, st.st_mode | stat.S_IEXEC)


class MPIMaster(object):
    """MPI Master: waits for every host to accept ssh connections, then launches mpirun.

    Args:
        env (TrainingEnv): an instance of the training environment.
        process_per_host (int): Number of processes per host to be executed by MPI
        instance_type (str): Type of instance used for this job. It will be "local" for local mode. Its used to
                        perform different setup for local mode or sagemaker mode.
    """

    def __init__(self, env, process_per_host, instance_type):
        self.env = env
        self.process_per_host = process_per_host
        self.instance_type = instance_type

    def _wait_for_worker_nodes_to_start_sshd(self, hosts, interval=1, timeout_in_seconds=180):
        """Wait for worker nodes to start their ssh daemon to allow MPI communication.

        Args:
            hosts (list): host names still to be checked. The list is mutated in place:
                each host is removed once a connection to its port 22 succeeds.
            interval (int or float): seconds to sleep between two rounds of probing.
            timeout_in_seconds (int or float): overall time limit handed to ``timeout``,
                which raises ``TimeoutError`` when it expires.
        """
        with timeout(seconds=timeout_in_seconds):
            while hosts:
                print("hosts that aren't SSHable yet: {}".format(str(hosts)))
                # Iterate over a snapshot: removing from the list being iterated skips hosts.
                for host in list(hosts):
                    ssh_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                    try:
                        reachable = _can_connect(host, 22, ssh_socket)
                    finally:
                        # Closing an already closed socket is a no-op; this releases the
                        # descriptor when the probe failed.
                        ssh_socket.close()
                    if reachable:
                        print("Host: {} is sshable now.".format(host))
                        hosts.remove(host)
                if hosts:
                    time.sleep(interval)

    def _run_mpi_on_all_nodes(self):
        """Run MPI command to execute MPI_SCRIPT on all hosts.

        Raises:
            subprocess.CalledProcessError: if mpirun exits with a non-zero status.
        """
        mpi_command = self._build_mpi_command()
        cmd = shlex.split(mpi_command)

        framework.logging.log_script_invocation(cmd, self.env.to_env_vars(), logger)

        print("MPI Command: {}".format(mpi_command))
        with open(_MPI_SCRIPT) as f:
            print("Running user script:\n\n{}".format(f.read()))

        subprocess.check_call(cmd)

    def _build_mpi_command(self):
        """Build the mpirun command line.

        Returns:
            str: a shell-style command string, meant to be tokenized with ``shlex.split``.
        """
        num_hosts = len(self.env.hosts)
        num_processes = self.process_per_host * num_hosts

        # By default, use one process per GPU, or one process per node (if training with CPU).
        host_list = (
            self.env.hosts
            if self.process_per_host == 1
            else [host + ":{}".format(self.process_per_host) for host in self.env.hosts]
        )

        print(
            "Env Hosts: {} Hosts: {} process_per_hosts: {} num_processes: {}".format(
                self.env.hosts, host_list, self.process_per_host, num_processes
            )
        )
        credential_vars = ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY", "AWS_SESSION_TOKEN"]

        # Local mode always uses eth0 instead of the interface reported by the environment.
        interface_name = self.env.network_interface_name
        if self.instance_type == "local":
            interface_name = "eth0"

        print("network interface name:" + interface_name + " " + str(self.instance_type))

        mpi_command = (
            "mpirun --host {}".format(",".join(host_list))
            + " -np {} ".format(num_processes)
            + " --allow-run-as-root"
            + " --display-map"
            + " --tag-output"
            + " -mca btl_tcp_if_include {}".format(interface_name)
            + " -mca oob_tcp_if_include {}".format(interface_name)
            + " -x NCCL_SOCKET_IFNAME={}".format(interface_name)
            + " --mca plm_rsh_no_tree_spawn 1"
            + " -mca orte_abort_on_non_zero_status 1"
            + " -x NCCL_MIN_NRINGS=8 -x NCCL_DEBUG=INFO"
            + " -x LD_LIBRARY_PATH -x PATH"
            + " -x LD_PRELOAD={}".format(_CHANGE_HOSTNAME_LIBRARY)
        )

        # Credentials are forwarded by name only, and only when set in this process.
        for v in credential_vars:
            if v in os.environ:
                mpi_command += " -x {}".format(v)

        for name, value in self.env.to_env_vars().items():
            # Quote the whole NAME=value token so that values containing quotes, spaces or
            # backslashes (e.g. JSON documents) survive the later shlex.split unchanged.
            mpi_command += " -x {}".format(shlex.quote("{}={}".format(name, value)))

        mpi_command += " {}".format(_MPI_SCRIPT)

        return mpi_command

    def __call__(self):
        """Wait until all hosts accept ssh connections, then run the MPI command."""
        # Pass a copy: the wait loop removes hosts from the list it is given.
        self._wait_for_worker_nodes_to_start_sshd(list(self.env.hosts))
        self._run_mpi_on_all_nodes()

    def is_master(self, hosts, current_host):
        """Checks if the current host is master or worker.

        The master is the host whose name sorts first among ``hosts``.

        Args:
            hosts (list): names of all hosts taking part in the job.
            current_host (str): name of the host this code runs on.

        Returns:
            bool: True if ``current_host`` is the master.
        """
        print("Hosts: " + str(hosts) + " current host: " + str(current_host))
        return current_host == sorted(hosts)[0]


class MPIWorker(object):
    """MPI Worker: blocks until the MPI script has started and then finished on this host.

    Progress is observed through the marker files ``_MPI_IS_RUNNING`` and ``_MPI_IS_FINISHED``.
    """

    # Poll with a fixed wait of 1000 between attempts, giving up once stop_max_delay is
    # exceeded (the `retrying` package expresses both values in milliseconds).
    @retry(
        stop_max_delay=30000 * 1000, wait_fixed=1000, retry_on_result=lambda result: result is False
    )
    def _wait_for_mpi_to_start_running(self):
        """Wait and retry loop until the MPI training starts on this worker.

        Returns:
            bool: True once the ``_MPI_IS_RUNNING`` file exists; a False result triggers a retry.
        """
        return os.path.isfile(_MPI_IS_RUNNING)

    # No stop condition is configured here: polling continues until the file shows up.
    @retry(wait_fixed=5000, retry_on_result=lambda result: result is False)
    def _wait_until_mpi_stops_running(self):
        """Wait and retry loop until the MPI training is finished on this worker.

        Returns:
            bool: True once the ``_MPI_IS_FINISHED`` file exists; a False result triggers a retry.
        """
        return os.path.isfile(_MPI_IS_FINISHED)

    def __call__(self, env):
        """Block until MPI training has started and finished on this worker.

        Args:
            env (TrainingEnv): an instance of the training environment; only ``current_host``
                is read, for progress messages.
        """
        current_host = env.current_host

        print("Worker node {} is waiting for MPI to start training process".format(current_host))
        self._wait_for_mpi_to_start_running()

        print("MPI started training process on worker node {}".format(current_host))

        self._wait_until_mpi_stops_running()
        # The host name is interpolated into the message; print() does no %-formatting itself.
        print("Training process started by MPI on worker node {} stopped".format(current_host))


class TimeoutError(Exception):
    """Raised by the ``timeout`` context manager when the guarded block exceeds its time limit.

    NOTE: inside this module the name shadows the builtin ``TimeoutError``; this class derives
    directly from ``Exception`` and is therefore not an ``OSError``.
    """


@contextmanager
def timeout(seconds=0, minutes=0, hours=0):
    """
    Add a signal-based timeout to any block of code.
    If multiple time units are specified, they will be added together to determine time limit.
    A total limit of 0 arms no timer, so the block then runs without a time limit.
    The SIGALRM handler that was installed before entering the block is restored on exit.
    Relies on SIGALRM, so it can only be used from the main thread.
    Usage:
    with timeout(seconds=5):
        my_slow_function(...)
    Args:
        - seconds: The time limit, in seconds.
        - minutes: The time limit, in minutes.
        - hours: The time limit, in hours.
    Raises:
        TimeoutError: inside the guarded block, once the time limit has elapsed.
    """

    limit = seconds + 60 * minutes + 3600 * hours

    def handler(signum, frame):  # pylint: disable=W0613
        raise TimeoutError("timed out after {} seconds".format(limit))

    # Remember the handler being replaced so the process-wide SIGALRM state can be restored.
    previous_handler = signal.signal(signal.SIGALRM, handler)
    try:
        signal.setitimer(signal.ITIMER_REAL, limit)
        yield
    finally:
        # Cancel the timer with the same API that armed it, before swapping handlers back.
        signal.setitimer(signal.ITIMER_REAL, 0)
        # signal.signal() returns None when the previous handler was not installed from
        # Python; such a handler cannot be re-installed from here.
        if previous_handler is not None:
            signal.signal(signal.SIGALRM, previous_handler)


class MPILauncher(object):
    """
    MPI launcher, it can be used by algorithms supporting the MPI based distributed training.

    Args:
        train_script (str): Train script to executed by the ``MPILauncher``
        train_script_args (list): List of args that are passed to the ``train_script`` to be executed by
                            ``MPILauncher``. ``None`` (the default) means no extra args.
        num_of_processes_per_host (int): Number of processes per host to be executed by MPI
        instance_type (str): Type of instance used for this job. It will be "local" for local mode. Its used to perform
                            different setup for local mode or sagemaker mode. The default ``False`` is kept
                            for backward compatibility and is handled like any non-"local" value.

    """

    def __init__(
        self, train_script, train_script_args=None, num_of_processes_per_host=1, instance_type=False
    ):

        self._train_script = train_script
        # Normalise the None default to an empty list: the args are later passed to
        # list.extend(), which raises TypeError for None.
        self._train_script_args = train_script_args or []
        self._num_of_processes_per_host = num_of_processes_per_host
        self._instance_type = instance_type

    def mpi_run(self):
        """Set up MPI on this host and run it as master or worker.

        Every host prepares its MPI environment and writes the MPI script. The master host
        then waits for the other hosts and launches mpirun, while every other host waits
        for the MPI training process to start and finish.
        """
        env = sagemaker_containers.training_env()
        print("MPI requested with process per hosts: {}".format(self._num_of_processes_per_host))

        _setup_mpi_environment(env)
        _create_mpi_script(env, self._train_script, self._train_script_args)

        mpi_master = MPIMaster(env, self._num_of_processes_per_host, self._instance_type)
        if mpi_master.is_master(env.hosts, env.current_host):
            print("Inside Master")
            mpi_master()
        else:
            print("Inside Worker")
            MPIWorker()(env)
