import argparse
import copy
import glob
import logging
import os
import re
import shutil
import sys

import rl_coach.core_types
from rl_coach import logger
from rl_coach.agents.clipped_ppo_agent import ClippedPPOAgentParameters
from rl_coach.agents.policy_gradients_agent import PolicyGradientsAgentParameters
from rl_coach.base_parameters import Frameworks, TaskParameters, VisualizationParameters
from rl_coach.coach import CoachLauncher
from rl_coach.core_types import MaxDumpFilter, RunPhase, SelectedPhaseOnlyDumpFilter
from rl_coach.graph_managers.basic_rl_graph_manager import BasicRLGraphManager
from rl_coach.graph_managers.graph_manager import ScheduleParameters
from rl_coach.logger import screen
from rl_coach.utils import short_dynamic_import

from .configuration_list import ConfigurationList

screen.set_use_colors(False)  # Simple text logging so it looks good in CloudWatch


class CoachConfigurationList(ConfigurationList):
    """Helper Object for converting CLI arguments (or SageMaker hyperparameters)
    into Coach configuration.
    """

    # Being security-paranoid and not instantiating any arbitrary string the customer passes in
    ALLOWED_TYPES = {
        "Frames": rl_coach.core_types.Frames,
        "EnvironmentSteps": rl_coach.core_types.EnvironmentSteps,
        "EnvironmentEpisodes": rl_coach.core_types.EnvironmentEpisodes,
        "TrainingSteps": rl_coach.core_types.TrainingSteps,
        "Time": rl_coach.core_types.Time,
    }


class SageMakerCoachPresetLauncher(CoachLauncher):
    """Base class for training RL tasks using RL-Coach.
    Customers subclass this to define specific kinds of workloads, overriding these methods as needed.
    """

    def __init__(self):
        super().__init__()
        self.hyperparams = None

    def get_config_args(self, parser: argparse.ArgumentParser) -> argparse.Namespace:
        """Overrides the default CLI parsing.
        Sets the configuration parameters for what a SageMaker run should do.
        Note, this does not support the "play" mode.
        """
        # first, convert the parser to a Namespace object with all default values.
        empty_arg_list = []
        args, _ = parser.parse_known_args(args=empty_arg_list)
        parser = self.sagemaker_argparser()
        sage_args, unknown = parser.parse_known_args()

        # Now fill in the args that we care about.
        sagemaker_job_name = os.environ.get("sagemaker_job_name", "sagemaker-experiment")
        args.experiment_name = logger.get_experiment_name(sagemaker_job_name)

        # Override experiment_path used for outputs
        args.experiment_path = "/opt/ml/output/intermediate"
        rl_coach.logger.experiment_path = "/opt/ml/output/intermediate"  # for gifs

        args.checkpoint_save_dir = "/opt/ml/output/data/checkpoint"
        args.checkpoint_save_secs = 10  # should avoid hardcoding
        # onnx for deployment for mxnet (not tensorflow)
        save_model = sage_args.save_model == 1
        backend = os.getenv("COACH_BACKEND", "tensorflow")
        if save_model and backend == "mxnet":
            args.export_onnx_graph = True

        args.no_summary = True

        args.num_workers = sage_args.num_workers
        args.framework = Frameworks[backend]
        args.preset = sage_args.RLCOACH_PRESET
        # args.apply_stop_condition = True # uncomment for old coach behaviour

        self.hyperparameters = CoachConfigurationList()
        if len(unknown) % 2 == 1:
            raise ValueError("Odd number of command-line arguments specified. Key without value.")

        for i in range(0, len(unknown), 2):
            name = unknown[i]
            if name.startswith("--"):
                name = name[2:]
            else:
                raise ValueError("Unknown command-line argument %s" % name)
            val = unknown[i + 1]
            self.map_hyperparameter(name, val)

        return args

    def map_hyperparameter(self, name, value):
        """This is a good method to override where customers can specify custom shortcuts
        for hyperparameters.  Default takes everything starting with "rl." and sends it
        straight to the graph manager.
        """
        if name.startswith("rl."):
            self.apply_hyperparameter(name, value)
        else:
            raise ValueError("Unknown hyperparameter %s" % name)

    def apply_hyperparameter(self, name, value):
        """Save this hyperparameter to be applied to the graph_manager object when
        it's ready.
        """
        print("Applying RL hyperparameter %s=%s" % (name, value))
        self.hyperparameters.store(name, value)

    def default_preset_name(self):
        """
        Sub-classes will typically return a single hard-coded string.
        """
        try:
            # TODO: remove this after converting all samples.
            default_preset = self.DEFAULT_PRESET
            screen.warning(
                "Deprecated configuration of default preset.  Please implement default_preset_name()"
            )
            return default_preset
        except:
            pass
        raise NotImplementedError(
            "Sub-classes must specify the name of the default preset "
            + "for this RL problem.  This will be the name of a python "
            + "file (without .py) that defines a graph_manager variable"
        )

    def sagemaker_argparser(self) -> argparse.ArgumentParser:
        """
        Expose only the CLI arguments that make sense in the SageMaker context.
        """
        parser = argparse.ArgumentParser()

        # Arguably this would be cleaner if we copied the config from the base class argparser.
        parser.add_argument(
            "-n",
            "--num_workers",
            help="(int) Number of workers for multi-process based agents, e.g. A3C",
            default=1,
            type=int,
        )
        parser.add_argument(
            "-p",
            "--RLCOACH_PRESET",
            help="(string) Name of the file with the RLCoach preset",
            default=self.default_preset_name(),
            type=str,
        )
        parser.add_argument(
            "--save_model",
            help="(int) Flag to save model artifact after training finish",
            default=0,
            type=int,
        )
        return parser

    def path_of_main_launcher(self):
        """
        A bit of python magic to find the path of the file that launched the current process.
        """
        main_mod = sys.modules["__main__"]
        try:
            launcher_file = os.path.abspath(sys.modules["__main__"].__file__)
            return os.path.dirname(launcher_file)
        except AttributeError:
            # If __main__.__file__ is missing, then we're probably in an interactive python shell
            return os.getcwd()

    def preset_from_name(self, preset_name):
        preset_path = self.path_of_main_launcher()
        print("Loading preset %s from %s" % (preset_name, preset_path))
        preset_path = os.path.join(self.path_of_main_launcher(), preset_name) + ".py:graph_manager"
        graph_manager = short_dynamic_import(preset_path, ignore_module_case=True)
        return graph_manager

    def get_graph_manager_from_args(self, args):
        # First get the graph manager for the customer-specified (or default) preset
        graph_manager = self.preset_from_name(args.preset)
        # Now override whatever config is specified in hyperparameters.
        self.hyperparameters.apply_subset(graph_manager, "rl.")
        # Set framework
        # Note: Some graph managers (e.g. HAC preset) create multiple agents and the attribute is called agents_params
        if hasattr(graph_manager, "agent_params"):
            for network_parameters in graph_manager.agent_params.network_wrappers.values():
                network_parameters.framework = args.framework
        elif hasattr(graph_manager, "agents_params"):
            for ap in graph_manager.agents_params:
                for network_parameters in ap.network_wrappers.values():
                    network_parameters.framework = args.framework
        return graph_manager

    def _save_tf_model(self):
        ckpt_dir = "/opt/ml/output/data/checkpoint"
        model_dir = "/opt/ml/model"

        import tensorflow as tf  # importing tensorflow here so that MXNet docker image is compatible with this file.

        # Re-Initialize from the checkpoint so that you will have the latest models up.
        tf.train.init_from_checkpoint(
            ckpt_dir, {"main_level/agent/online/network_0/": "main_level/agent/online/network_0"}
        )
        tf.train.init_from_checkpoint(
            ckpt_dir, {"main_level/agent/online/network_1/": "main_level/agent/online/network_1"}
        )

        # Create a new session with a new tf graph.
        sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
        sess.run(tf.global_variables_initializer())  # initialize the checkpoint.

        # This is the node that will accept the input.
        input_nodes = tf.get_default_graph().get_tensor_by_name(
            "main_level/agent/main/online/" + "network_0/observation/observation:0"
        )
        # This is the node that will produce the output.
        output_nodes = tf.get_default_graph().get_operation_by_name(
            "main_level/agent/main/online/" + "network_1/ppo_head_0/policy"
        )
        # Save the model as a servable model.
        tf.saved_model.simple_save(
            session=sess,
            export_dir="model",
            inputs={"observation": input_nodes},
            outputs={"policy": output_nodes.outputs[0]},
        )
        # Move to the appropriate folder. Don't mind the directory, this just works.
        # rl-cart-pole is the name of the model. Remember it.
        shutil.move("model/", model_dir + "/model/tf-model/00000001/")
        # EASE will pick it up and upload to the right path.
        print("Success")

    def _save_onnx_model(self):
        from .onnx_utils import fix_onnx_model

        ckpt_dir = "/opt/ml/output/data/checkpoint"
        model_dir = "/opt/ml/model"
        # find latest onnx file
        # currently done by name, expected to be changed in future release of coach.
        glob_pattern = os.path.join(ckpt_dir, "*.onnx")
        onnx_files = [file for file in glob.iglob(glob_pattern, recursive=True)]
        if len(onnx_files) > 0:
            extract_step = lambda string: int(
                re.search("/(\d*)_Step.*", string, re.IGNORECASE).group(1)
            )
            onnx_files.sort(key=extract_step)
            latest_onnx_file = onnx_files[-1]
            # move to model directory
            filepath_from = os.path.abspath(latest_onnx_file)
            filepath_to = os.path.join(model_dir, "model.onnx")
            shutil.move(filepath_from, filepath_to)
            fix_onnx_model(filepath_to)
        else:
            screen.warning("No ONNX files found in {}".format(ckpt_dir))

    @classmethod
    def train_main(cls):
        """Entrypoint for training.
        Parses command-line arguments and starts training.
        """
        trainer = cls()
        trainer.launch()

        # Create model artifact for model.tar.gz
        parser = trainer.sagemaker_argparser()
        sage_args, unknown = parser.parse_known_args()
        if sage_args.save_model == 1:
            backend = os.getenv("COACH_BACKEND", "tensorflow")
            if backend == "tensorflow":
                trainer._save_tf_model()
            if backend == "mxnet":
                trainer._save_onnx_model()


class SageMakerCoachLauncher(SageMakerCoachPresetLauncher):
    """
    Older version of the launcher that doesn't use preset, but instead effectively has a single preset built in.
    """

    def __init__(self):
        super().__init__()
        screen.warning("DEPRECATION WARNING: Please switch to SageMakerCoachPresetLauncher")
        # TODO: Remove this whole class when nobody's using it any more.

    def define_environment(self):
        return NotImplementedEror(
            "Sub-class must define environment e.g. GymVectorEnvironment(level='your_module:YourClass')"
        )

    def get_graph_manager_from_args(self, args):
        """Returns the GraphManager object for coach to use to train by calling improve()"""
        # NOTE: TaskParameters are not configurable at this time.

        # Visualization
        vis_params = VisualizationParameters()
        self.config_visualization(vis_params)
        self.hyperparameters.apply_subset(vis_params, "vis_params.")

        # Schedule
        schedule_params = ScheduleParameters()
        self.config_schedule(schedule_params)
        self.hyperparameters.apply_subset(schedule_params, "schedule_params.")

        # Agent
        agent_params = self.define_agent()
        self.hyperparameters.apply_subset(agent_params, "agent_params.")

        # Environment
        env_params = self.define_environment()
        self.hyperparameters.apply_subset(env_params, "env_params.")

        graph_manager = BasicRLGraphManager(
            agent_params=agent_params,
            env_params=env_params,
            schedule_params=schedule_params,
            vis_params=vis_params,
        )

        return graph_manager

    def config_schedule(self, schedule_params):
        pass

    def define_agent(self):
        raise NotImplementedError(
            "Subclass must create define_agent() method which returns an AgentParameters object. e.g.\n"
            "   return rl_coach.agents.dqn_agent.DQNAgentParameters()"
        )

    def config_visualization(self, vis_params):
        vis_params.dump_gifs = True
        vis_params.video_dump_methods = [
            SelectedPhaseOnlyDumpFilter(RunPhase.TEST),
            MaxDumpFilter(),
        ]
        vis_params.print_networks_summary = True
        return vis_params
