#!/usr/bin/env python3

# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0.

"""
This module drives protocol test generation: it generates the protocol test clients and the corresponding protocol test suites.
"""
import json
import os
import pathlib
import re
import sys
from concurrent.futures import ProcessPoolExecutor, wait, FIRST_COMPLETED, ALL_COMPLETED

from codegen.legacy_c2j_cpp_gen import LegacyC2jCppGen
from codegen.model_utils import SERVICE_MODEL_FILENAME_PATTERN, ServiceModel, ModelUtils

# Paths below are relative to the SDK root directory.
PROTOCOL_TESTS_BASE_DIR = "tools/code-generation/protocol-tests"
PROTOCOL_TESTS_CLIENT_MODELS = PROTOCOL_TESTS_BASE_DIR + "/api-descriptions"
PROTOCOL_TESTS_ENDPOINT_RULES = "endpoint-rule-set.json"  # Dummy endpoint ruleset
# Sub-directories of PROTOCOL_TESTS_BASE_DIR holding protocol test definitions.
PROTOCOL_TESTS_DEFINITION_SETS = ["input", "output"]

PROTOCOL_TESTS_GENERATED_CLIENTS_DIR = "generated/protocol-tests/test-clients"
PROTOCOL_GENERATED_TESTS_DIR = "generated/protocol-tests/tests"

# Service names of client models that are not generated.
UNSUPPORTED_CLIENTS = {
    "rpcv2protocol",  # RPC V2 CBOR support is not implemented on this SDK
}
# Test definition names (filename without ".json") that are not generated.
UNSUPPORTED_TESTS = {"smithy-rpc-v2-cbor"}

# Regexp to parse a protocol test definition filename (ex: "ec2.json") and extract the test suite name.
# The dot is escaped so that only files with a literal ".json" extension match.
TEST_DEFINITION_FILENAME_PATTERN = re.compile(r"^(?P<name>.+)\.json$")


class ProtocolTestsGen(object):
    """A wrapper around the C2J C++ code generator that produces protocol tests for the C++ SDK.

    Generation runs in two phases (see ``generate``):
      1. build the generator and generate one test client per supported C2J client model
         found under ``PROTOCOL_TESTS_CLIENT_MODELS``;
      2. for every protocol test definition under ``PROTOCOL_TESTS_BASE_DIR/<input|output>``,
         generate a test suite against each C2J client model whose metadata matches it.
    """

    # C2J metadata fields that must be equal for a client model to be paired with a test suite.
    _CLIENT_MATCH_FIELDS = ("apiVersion", "protocols", "jsonVersion", "targetPrefix")

    class ProtoTestC2jClientModelMetadata:
        """Location and ``metadata`` object of a single C2J protocol test client model."""

        def __init__(self, filename: str, model_path: str, md: dict):
            # Raises AttributeError if filename does not match SERVICE_MODEL_FILENAME_PATTERN.
            self.service_name = SERVICE_MODEL_FILENAME_PATTERN.match(filename).group("service")
            self.model_path = model_path  # absolute path to the C2J client model file
            self.md = md  # the model's "metadata" JSON object

    class ProtocolTestModel:
        """A protocol test suite paired with the C2J client model it is generated against."""

        def __init__(self, test_type: str, test_name: str, c2j_test_model: str, c2j_client_md):
            self.test_type = test_type  # ex: "input" or "output"
            self.service_name = c2j_client_md.service_name
            self.test_name = test_name
            #  File paths to model files
            self.c2j_test_model = c2j_test_model
            self.c2j_client_model = c2j_client_md.model_path

    def __init__(self,
                 args: dict):
        """Resolve the protocol test directories and configure the client and test generators.

        :param args: generator arguments forwarded to ``LegacyC2jCppGen``; ``args["debug"]`` enables verbose logs.
        """
        # The SDK root is assumed to be the 4th ancestor directory of this file.
        sdk_root_dir = pathlib.Path(__file__).parents[3]
        self.debug = args.get("debug", False)
        self.client_models_dir = str(pathlib.Path(f"{sdk_root_dir}/{PROTOCOL_TESTS_CLIENT_MODELS}").resolve())
        self.test_definitions_dir = str(pathlib.Path(f"{sdk_root_dir}/{PROTOCOL_TESTS_BASE_DIR}").resolve())

        self.generated_test_clients_dir = str(
            pathlib.Path(f"{sdk_root_dir}/{PROTOCOL_TESTS_GENERATED_CLIENTS_DIR}").resolve())
        self.generated_tests_dir = str(pathlib.Path(f"{sdk_root_dir}/{PROTOCOL_GENERATED_TESTS_DIR}").resolve())

        # Test client generator: C2J models come from the api-descriptions dir, and the dummy
        # endpoint ruleset from the protocol tests base dir.
        self.c2j_client_generator = LegacyC2jCppGen(args, dict())
        self.c2j_client_generator.path_to_api_definitions = self.client_models_dir
        self.c2j_client_generator.path_to_endpoint_rules = self.test_definitions_dir
        self.c2j_client_generator.output_location = PROTOCOL_TESTS_GENERATED_CLIENTS_DIR

        # Test suite generator: input files are passed explicitly per suite (see _generate_single_protocol_test).
        self.c2j_tests_generator = LegacyC2jCppGen(args, dict())
        self.c2j_tests_generator.path_to_api_definitions = ""
        self.c2j_tests_generator.path_to_endpoint_rules = ""
        self.c2j_tests_generator.output_location = PROTOCOL_GENERATED_TESTS_DIR

    def generate(self, executor: ProcessPoolExecutor, max_workers: int):
        """
        Generate protocol tests (test clients and a corresponding set of tests)
        :param executor: process pool used to run the generation tasks
        :param max_workers: maximum number of generation tasks in flight at once
        :return: 0 on success, -1 if any test client or test suite failed to generate
        """
        if self._generate_test_clients(executor, max_workers) == 0:
            return self._generate_tests(executor, max_workers)
        return -1

    @staticmethod
    def _run_bounded(executor: ProcessPoolExecutor, max_workers: int, jobs: list) -> set:
        """Submit ``(callable, *args)`` jobs to ``executor`` keeping at most ``max_workers`` in flight.

        :return: the set of all completed futures
        """
        # A limit below 1 would make the throttling loop spin forever on an empty pending set.
        limit = max_workers if max_workers and max_workers > 0 else 1
        pending = set()
        done = set()
        # Presumably flushed so buffered parent output is not duplicated/interleaved by worker processes.
        sys.stdout.flush()
        for fn, *fn_args in jobs:
            while len(pending) >= limit:
                new_done, pending = wait(pending, return_when=FIRST_COMPLETED)
                done.update(new_done)
            pending.add(executor.submit(fn, *fn_args))

        new_done, _ = wait(pending, return_when=ALL_COMPLETED)
        done.update(new_done)
        return done

    @staticmethod
    def _report_failures(done: set, what: str) -> int:
        """Print failures among completed futures that each return ``(name, status)``.

        :param done: completed futures
        :param what: human readable label of the generated artifact, used in messages
        :return: 0 if every future succeeded with status 0, -1 otherwise
        """
        failures = set()
        for result in done:
            try:
                name, status = result.result()  # will rethrow any exceptions
                if status != 0:
                    raise RuntimeError(f"{what} {name} (re)generation failed: {status}")
            except Exception as exc:
                failures.add(f"{what} (re)generation failed with error.\n    Exception: {exc}\n"
                             f"stderr: {getattr(exc, 'stderr', None)}")

        if failures:
            print(f"Code generation failed, processed {len(done)} packages. "
                  f"Encountered {len(failures)} failures:\n")
            for failure in failures:
                print(failure)
            return -1
        return 0

    def _generate_test_clients(self, executor: ProcessPoolExecutor, max_workers: int):
        """Build the generator and generate every supported protocol test client.

        :return: 0 on success, -1 if any client failed to generate
        """
        self.c2j_client_generator.build_generator(self.c2j_client_generator.path_to_generator)

        service_models = self._collect_test_client_models()
        os.makedirs(self.generated_test_clients_dir, exist_ok=True)

        jobs = [(self._generate_test_client, service, model_files, PROTOCOL_TESTS_GENERATED_CLIENTS_DIR)
                for service, model_files in service_models.items()]
        done = self._run_bounded(executor, max_workers, jobs)
        return self._report_failures(done, "Protocol test client")

    def _generate_test_client(self,
                              service_name: str,
                              model_files: ServiceModel,
                              output_dir: str):
        """Generate a single test client; returns ``(service_name, status)`` from the generator."""
        service_name, status = self.c2j_client_generator.generate_client(service_name, model_files, output_dir, None)
        return service_name, status

    def _collect_test_client_models(self) -> dict:
        """Map service name -> ServiceModel for every supported C2J client model in ``client_models_dir``."""
        service_models = dict()
        for filename in sorted(os.listdir(self.client_models_dir)):
            if not os.path.isfile(os.path.join(self.client_models_dir, filename)):
                continue
            match = SERVICE_MODEL_FILENAME_PATTERN.match(filename)
            if match is None:
                print(f"Skipping unexpected file in protocol tests clients dir: {filename}")
                continue

            service_model_name = match.group("service")
            if service_model_name in UNSUPPORTED_CLIENTS:
                print(f"Skipping protocol tests client generation: {filename}")
                continue

            use_smithy = ModelUtils.is_smithy_enabled(service_model_name, self.client_models_dir, filename)
            service_models[service_model_name] = ServiceModel(service_model_name, filename,
                                                              PROTOCOL_TESTS_ENDPOINT_RULES, None, use_smithy)
        return service_models

    def _get_client_models_metadata(self) -> list:
        """Load the ``metadata`` object of every C2J client model in ``client_models_dir``.

        Files that are not valid C2J models are reported and skipped.
        :return: list of ProtoTestC2jClientModelMetadata, in filename order
        """
        models = list()
        for filename in sorted(os.listdir(self.client_models_dir)):
            if not os.path.isfile(os.path.join(self.client_models_dir, filename)):
                continue
            model_abspath = str(pathlib.Path(f"{self.client_models_dir}/{filename}").resolve())
            with open(model_abspath, 'r', encoding='utf-8') as file_content:
                try:
                    c2j_model = json.load(file_content)
                    # Index (not .get) so that a model without "metadata" is reported here instead of
                    # producing an entry that later breaks test suite matching.
                    model_metadata = self.ProtoTestC2jClientModelMetadata(filename, model_abspath,
                                                                          c2j_model["metadata"])
                    models.append(model_metadata)
                except Exception as exc:
                    print(f"ERROR: unexpected file content in protocol tests clients dir {self.client_models_dir}. "
                          f"Expected c2j client model, but json metadata key is missing: {exc}")
        return models

    @classmethod
    def _get_corresponding_test_clients(cls, test_clients_md: list, test_path: str) -> list:
        """Return the C2J client models whose metadata matches the protocol test suite at ``test_path``.

        More than one client may match (ex: xml and xml with namespace clients for a single test suite).
        Problems reading the test suite are reported, not raised.
        """
        result = list()
        with open(test_path, 'r', encoding='utf-8') as file_content:
            try:
                proto_test_model = json.load(file_content)
                # Metadata is read from the first entry of the (list-shaped) test definition file.
                proto_test_md = proto_test_model[0]["metadata"]
                for c2j_md in test_clients_md:
                    if all(proto_test_md.get(field) == c2j_md.md.get(field)
                           for field in cls._CLIENT_MATCH_FIELDS):
                        result.append(c2j_md)
            except Exception as exc:
                print(f"ERROR: unexpected file content in protocol tests {test_path}. "
                      f"Expected c2j protocol test, but json metadata key is missing: {exc}")
        return result

    def _collect_test_definition_models(self) -> dict:
        """Pair every supported protocol test definition with its matching C2J client model(s).

        :return: ex: ``{"input": {"ec2": ProtocolTestModel}, "output": {"ec2": ProtocolTestModel}}``;
                 when several clients match one suite, extra entries are keyed ``<name>-<index>``.
        :raises RuntimeError: if a test suite has no matching client model
        """
        all_test_clients_md = self._get_client_models_metadata()

        test_models = dict()

        for test_def_group in PROTOCOL_TESTS_DEFINITION_SETS:
            group_dir = f"{self.test_definitions_dir}/{test_def_group}"
            for filename in sorted(os.listdir(group_dir)):
                if not os.path.isfile(f"{group_dir}/{filename}"):
                    continue
                match = TEST_DEFINITION_FILENAME_PATTERN.match(filename)
                if match is None:
                    print(f"Skipping unexpected file in protocol tests dir: {test_def_group}/{filename}")
                    continue

                test_def_name = match.group("name")
                if test_def_name in UNSUPPORTED_TESTS:
                    print(f"Skipping protocol tests generation: {test_def_group}/{filename}")
                    continue

                test_def_path = str(pathlib.Path(f"{group_dir}/{filename}").resolve())
                test_clients_for_suite = self._get_corresponding_test_clients(all_test_clients_md, test_def_path)
                if not test_clients_for_suite:
                    raise RuntimeError(f"ERROR: Unable to find C2J client model for the test suite: {test_def_path}")

                group_models = test_models.setdefault(test_def_group, dict())
                for index, client_md in enumerate(test_clients_for_suite):
                    test_def_key = test_def_name if index == 0 else f"{test_def_name}-{index}"
                    assert test_def_key not in group_models, \
                        f"This test suite {test_def_group}/{test_def_key} already exists: {test_models}"
                    if self.debug:
                        print("Protocol test generation task:\t"
                              f"{os.path.basename(test_def_path)} with {os.path.basename(client_md.model_path)}")
                    group_models[test_def_key] = self.ProtocolTestModel(test_type=test_def_group,
                                                                        test_name=test_def_key,
                                                                        c2j_test_model=test_def_path,
                                                                        c2j_client_md=client_md)
        return test_models

    def _generate_tests(self, executor: ProcessPoolExecutor, max_workers: int):
        """Generate every collected protocol test suite.

        :return: 0 on success, -1 if any test suite failed to generate
        """
        test_models = self._collect_test_definition_models()

        jobs = list()
        for test_def_group, test_suites in test_models.items():
            # Suites of this group are extracted into <generated_tests_dir>/<group>.
            os.makedirs(f"{self.generated_tests_dir}/{test_def_group}", exist_ok=True)
            for suite_model in test_suites.values():
                jobs.append((self._generate_single_protocol_test, suite_model))

        done = self._run_bounded(executor, max_workers, jobs)
        return self._report_failures(done, "Protocol test suite")

    def _generate_single_protocol_test(self, models: ProtocolTestModel):
        """Call java generator to generate a single protocol test suite

        :param models: ProtocolTestModel with the test type ("input"/"output"), suite name,
                       and paths to the C2J client model and protocol test definition
        :return: ``(name_for_logging, status)`` as returned by the generator's ``extract_zip``
        """
        generator_jar = self.c2j_tests_generator.path_to_generator + "/" + self.c2j_tests_generator.GENERATOR_JAR
        run_command = ["java", "-jar", generator_jar,
                       "--inputfile", models.c2j_client_model,
                       "--protocol-tests", models.c2j_test_model,
                       "--protocol-tests-type", models.test_type,
                       "--protocol-tests-name", models.test_name,
                       "--service", models.service_name,
                       "--outputfile", "STDOUT",
                       "--generate-tests"]

        name_for_logging = f"protocol test {models.test_type}/{models.test_name}"
        output_zip_file = self.c2j_tests_generator.run_generator_once(name_for_logging,
                                                                      run_command, "STDOUT")
        # Replace any previously generated copy of this suite, then extract into the group dir.
        dir_to_delete = f"{self.generated_tests_dir}/{models.test_type}/{models.test_name}"
        dir_to_extract = f"{self.generated_tests_dir}/{models.test_type}"
        name_for_logging, status = self.c2j_tests_generator.extract_zip(output_zip_file, name_for_logging,
                                                                        dir_to_extract, dir_to_delete)

        return name_for_logging, status