# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import argparse
import json
import logging
import os
import subprocess
import sys
from pathlib import Path
from typing import Any

from azure.ai.ml import MLClient, load_component, load_data, load_environment
from azure.ai.ml.entities import Data, Environment
from azure.identity import DefaultAzureCredential

# Module logger named after the module (``__main__`` when run as a script),
# per the standard logging convention; naming it after ``__file__`` would put
# an absolute filesystem path into every log record.
_logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


# Per-directory JSON file listing the assets to register from that directory.
REG_CONFIG_FILENAME = "registration_config.json"
# Keys recognised inside a registration config file:
ENV_KEY = "environments"  # list of environment YAML paths
COMP_KEY = "components"  # list of component YAML paths
DATA_KEY = "data"  # list of {"script": ..., "data_yamls": [...]} entries
SUBDIR_KEY = "nested_directories"  # list of sub-directories to recurse into


def parse_args(argv=None):
    """Parse the command-line options for the registration script.

    All three options are required: ``main`` reads every one of them
    unconditionally, so a missing option is now reported up front by argparse
    (usage message, exit status 2) instead of failing later with an opaque
    ``TypeError`` when ``None`` reaches ``open``/``Path``.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before.

    Returns:
        An ``argparse.Namespace`` with string attributes ``workspace_config``,
        ``component_config`` and ``base_directory``.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--workspace_config",
        type=str,
        required=True,
        help="Path to workspace config.json",
    )
    parser.add_argument(
        "--component_config",
        type=str,
        required=True,
        help="Path to component_config.json",
    )
    parser.add_argument(
        "--base_directory",
        type=str,
        required=True,
        help="Path to base directory",
    )

    return parser.parse_args(argv)


def read_json_path(path: str) -> Any:
    """Log the path, then parse and return the JSON document stored there."""
    _logger.info("Reading JSON file %s", path)
    with open(path, "r") as json_file:
        return json.load(json_file)


def process_file(input_file, output_file, replacements) -> None:
    """Copy ``input_file`` to ``output_file``, substituting text on each line.

    For every line, each ``replacements`` key is replaced by its value using
    ``str.replace``. The substitutions run in the mapping's iteration order,
    so later ones see the output of earlier ones.
    """

    def _rewrite(text):
        # Apply every substitution to a single line, in order.
        for needle, substitute in replacements.items():
            text = text.replace(needle, substitute)
        return text

    with open(input_file, "r") as source, open(output_file, "w") as target:
        target.writelines(_rewrite(line) for line in source)


def _register_from_template(template, replacements, loader, register):
    """Render one YAML template and register the asset it describes.

    Writes ``template + ".processed"`` next to the template (the file is left
    on disk) with ``replacements`` applied via ``process_file``, loads that
    file with ``loader`` and passes the resulting entity to ``register``.

    Returns:
        The loaded entity, so the caller can log its ``name``.
    """
    processed_file = template + ".processed"
    process_file(template, processed_file, replacements)
    asset = loader(processed_file)
    register(asset)
    return asset


def process_directory(directory: Path, ml_client: MLClient, version: int) -> None:
    """Register every asset listed in ``directory``'s registration config.

    Reads ``REG_CONFIG_FILENAME`` from ``directory``, changes the process
    working directory to ``directory`` (so relative paths in the config
    resolve against it) and then, in order:

    * registers each environment YAML listed under ``ENV_KEY``;
    * registers each component YAML listed under ``COMP_KEY``;
    * for each entry under ``DATA_KEY``, runs its ``script`` with the current
      Python interpreter and then registers each of its ``data_yamls``;
    * recurses into each sub-directory listed under ``SUBDIR_KEY``.

    Each YAML is first rendered with ``VERSION_REPLACEMENT_STRING`` replaced
    by ``str(version)``. The working directory is left at ``directory`` when
    this function returns.

    Args:
        directory: Absolute path of the directory to process.
        ml_client: Client used to create or update the assets.
        version: Value substituted for ``VERSION_REPLACEMENT_STRING``.

    Raises:
        ValueError: If ``directory`` is not an absolute path.
        subprocess.CalledProcessError: If a data script exits non-zero.
    """
    _logger.info("Processing: {0}".format(directory))
    # The cwd is changed below and restored with ``os.chdir(directory)`` after
    # each recursion, which only returns to the right place for an absolute
    # path. Raise instead of ``assert`` so the check also runs under ``-O``.
    if not directory.is_absolute():
        raise ValueError(
            "process_directory requires an absolute path, got {0}".format(directory)
        )

    registration_file = directory / REG_CONFIG_FILENAME
    reg_config = read_json_path(registration_file.resolve())

    replacements = {"VERSION_REPLACEMENT_STRING": str(version)}

    _logger.info("Changing directory")
    os.chdir(directory)

    if ENV_KEY in reg_config:
        for e in reg_config[ENV_KEY]:
            _logger.info("Registering environment: {0}".format(e))
            curr_env: Environment = _register_from_template(
                e,
                replacements,
                load_environment,
                ml_client.environments.create_or_update,
            )
            _logger.info("Registered {0}".format(curr_env.name))
    else:
        _logger.info("No key for environments")

    if COMP_KEY in reg_config:
        for c in reg_config[COMP_KEY]:
            _logger.info("Registering component: {0}".format(c))
            curr_component = _register_from_template(
                c,
                replacements,
                lambda path: load_component(source=path),
                ml_client.components.create_or_update,
            )
            _logger.info("Registered {0}".format(curr_component.name))
    else:
        _logger.info("No key for components")

    if DATA_KEY in reg_config:
        _logger.info("Working through data entries")
        for data_info in reg_config[DATA_KEY]:
            script_file = data_info["script"]
            _logger.info("Running script {0}".format(script_file))
            # Run with the interpreter executing this script rather than
            # whatever ``python`` resolves to on PATH, which may be missing
            # or belong to a different environment.
            subprocess.run([sys.executable, script_file], check=True)
            for d in data_info["data_yamls"]:
                _logger.info("Processing {0}".format(d))
                curr_dataset: Data = _register_from_template(
                    d,
                    replacements,
                    load_data,
                    ml_client.data.create_or_update,
                )
                _logger.info("Registered {0}".format(curr_dataset.name))
    else:
        _logger.info("No key for datasets")

    if SUBDIR_KEY in reg_config:
        _logger.info("Working through nested directories")
        for d in reg_config[SUBDIR_KEY]:
            process_directory((directory / d).resolve(), ml_client, version)
            # The recursive call chdir'd into the child; come back so the
            # next relative entry resolves against this directory.
            os.chdir(directory)
    else:
        _logger.info("No subdirectories found for {0}".format(directory))


def main(args):
    """Connect to the configured workspace and register the directory tree.

    ``args.workspace_config`` supplies the subscription, resource group and
    workspace name for the ``MLClient``. ``args.component_config`` supplies
    the ``version`` substituted into every YAML. Registration starts at
    ``args.base_directory``.
    """
    workspace = read_json_path(args.workspace_config)
    registration_settings = read_json_path(args.component_config)

    credential = DefaultAzureCredential(exclude_shared_token_cache_credential=True)
    client = MLClient(
        credential=credential,
        subscription_id=workspace["subscription_id"],
        resource_group_name=workspace["resource_group"],
        workspace_name=workspace["workspace_name"],
        logging_enable=True,
    )

    asset_version: int = registration_settings["version"]
    root = Path(args.base_directory).resolve()
    process_directory(root, client, asset_version)


if __name__ == "__main__":
    # Script entry point: parse the CLI and hand off to main().
    main(parse_args())
