# coding=utf-8
# ----------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License file in the project root for license information.
# ----------------------------------------------------------------------------------------------

"""
common: Defines common utility functions and components.

"""

import base64
import json
import logging
import re
from typing import Dict, List, Optional

from knack.log import get_logger

logger = get_logger(__name__)


def parse_kvp_nargs(kvp_nargs: List[str]) -> dict:
    """
    Parse a list of key=value tokens into a dict.

    A token with no '=' maps its key to None; 'key=' maps to an empty string.
    A falsy input yields an empty dict.
    """
    parsed = {}
    for token in kvp_nargs or []:
        key, eq, val = token.partition("=")
        parsed[key] = val if eq else None
    return parsed


def assemble_nargs_to_dict(hash_list: List[str]) -> Dict[str, str]:
    """
    Assemble a list of 'key=value' strings into a dict.

    Tokens lacking '=' are skipped with a warning. After assembly, any key
    left with a falsy value also produces a warning but stays in the result.

    :param hash_list: list of 'key=value' tokens; falsy input yields {}.
    :return: dict of parsed key/value pairs.
    """
    result: Dict[str, str] = {}
    if not hash_list:
        return result
    # Renamed loop var from 'hash' to avoid shadowing the builtin.
    for item in hash_list:
        if "=" not in item:
            logger.warning(
                "Skipping processing of '%s', input format is key=value | key='value value'.",
                item,
            )
            continue
        # partition splits on the first '=' only, so values may contain '='.
        key, _, value = item.partition("=")
        result[key] = value
    # Warn after assembly so duplicate keys only warn about their final value.
    for key in result:
        if not result[key]:
            logger.warning(
                "No value assigned to key '%s', input format is key=value | key='value value'.",
                key,
            )
    return result


def build_query(cmd, subscription_id: Optional[str] = None, custom_query: Optional[str] = None, **kwargs):
    """
    Build an Azure Resource Graph query payload from the given filters and
    execute it via a raw ARM request, returning the accumulated results.
    """
    url = "/providers/Microsoft.ResourceGraph/resources?api-version=2022-10-01"
    payload = {
        "subscriptions": [subscription_id] if subscription_id else [],
        "query": "Resources ",
        "options": {},
    }

    # TODO: add more query options as they pop up
    # Map kwarg names to their Resource Graph field names for where clauses.
    where_fields = (
        ("name", "name"),
        ("resource_group", "resourceGroup"),
        ("location", "location"),
        ("type", "type"),
    )
    for kwarg_key, graph_field in where_fields:
        filter_value = kwargs.get(kwarg_key)
        if filter_value:
            payload["query"] += f'| where {graph_field} =~ "{filter_value}" '
    if custom_query:
        payload["query"] += custom_query
    payload["query"] += "| project id, location, name, resourceGroup, properties, tags, type, subscriptionId"
    additional_project = kwargs.get("additional_project")
    if additional_project:
        payload["query"] += f', {additional_project}'

    return _process_raw_request(cmd, url, "POST", payload)


def _process_raw_request(cmd, url: str, method: str, payload: Optional[dict] = None, keyword: str = "data"):
    """
    Issue a raw ARM request and follow $skipToken pagination until exhausted.

    :param cmd: CLI command context providing cli_ctx.
    :param url: ARM-relative request URL.
    :param method: HTTP method, e.g. "POST".
    :param payload: optional JSON body; mutated to carry the skip token.
    :param keyword: response key whose list contents are accumulated.
    :return: accumulated list of items, or None if a response body is empty.
    """
    # since I don't want to download the resourcegraph sdk - we are stuck with this
    # note that we are trying to limit dependencies
    from azure.cli.core.util import send_raw_request

    result = []
    skip_token = "sentinel"  # non-empty placeholder so the loop runs at least once
    while skip_token:
        # Removed a `try/except Exception as e: raise e` wrapper here: it was a
        # no-op that only truncated traceback context for callers.
        body = json.dumps(payload) if payload is not None else None
        res = send_raw_request(cli_ctx=cmd.cli_ctx, url=url, method=method, body=body)
        if not res.content:
            # NOTE(review): an empty page returns None and discards any results
            # accumulated from earlier pages — confirm this is intended.
            return
        json_response = res.json()
        result.extend(json_response[keyword])
        skip_token = json_response.get("$skipToken")
        if skip_token:
            # Fold the continuation token into the payload for the next page.
            if not payload:
                payload = {"options": {}}
            if "options" not in payload:
                payload["options"] = {}
            payload["options"]["$skipToken"] = skip_token

    return result


def get_timestamp_now_utc(format: str = "%Y-%m-%dT%H:%M:%S") -> str:
    """Return the current UTC time rendered with the given strftime pattern."""
    from datetime import datetime, timezone

    now_utc = datetime.now(timezone.utc)
    return now_utc.strftime(format)


def set_log_level(log_name: str, log_level: int = logging.DEBUG):
    """Set the level of the named stdlib logger (DEBUG by default)."""
    logging.getLogger(log_name).setLevel(log_level)


def generate_secret(byte_length=32):
    """
    Generate a cryptographically secure secret, returned base64-encoded.
    """
    from secrets import token_bytes

    raw = token_bytes(byte_length)
    return base64.b64encode(raw).decode("utf8")


def url_safe_hash_phrase(phrase: str) -> str:
    """Return the hex SHA-256 digest of phrase (hex is URL-safe by construction)."""
    import hashlib

    digest = hashlib.sha256(phrase.encode("utf8"))
    return digest.hexdigest()


def url_safe_random_chars(count: int) -> str:
    """
    Return `count` random URL-safe characters (alphanumeric only; the
    '-' and '_' produced by token_urlsafe are stripped out).
    """
    import secrets

    pool = ""
    while len(pool) < count:
        pool += secrets.token_urlsafe().replace("-", "").replace("_", "")

    return pool[:count]


def ensure_azure_namespace_path():
    """
    Run prior to importing azure namespace packages (azure.*) to ensure the
    extension root path is configured for package import.

    Prepends the extension's install root to sys.path and, when the extension
    vendors its own azure/ directory, adds that directory to the azure
    namespace package's __path__ so vendored subpackages are importable.
    """

    import os
    import sys

    from azure.cli.core.extension import get_extension_path

    from ...constants import EXTENSION_NAME

    # Locate this CLI extension's install directory; nothing to do if absent.
    ext_path = get_extension_path(EXTENSION_NAME)
    if not ext_path:
        return

    ext_azure_dir = os.path.join(ext_path, "azure")
    if os.path.isdir(ext_azure_dir):
        import azure

        # Only touch __path__ when it exists and doesn't already include our dir.
        if getattr(azure, "__path__", None) and ext_azure_dir not in azure.__path__:  # _NamespacePath /w PEP420
            if isinstance(azure.__path__, list):
                # Plain list path: prepend so the vendored dir wins lookups.
                azure.__path__.insert(0, ext_azure_dir)
            else:
                # Non-list (PEP 420 namespace) path: fall back to append.
                azure.__path__.append(ext_azure_dir)

    # Ensure the extension root takes precedence in module resolution.
    if sys.path and sys.path[0] != ext_path:
        sys.path.insert(0, ext_path)


def run_host_command(command: str, shell_mode: bool = True):
    """
    Run a command on the host via subprocess, capturing stdout/stderr.

    :param command: command line to run; required.
    :param shell_mode: forwarded as subprocess `shell=` (default True).
    :return: the CompletedProcess, or None if the executable was not found.
    :raises ValueError: when command is falsy.
    """
    from shlex import quote, split
    from subprocess import run

    if not command:
        raise ValueError("command value is required.")

    logger.debug("Running host command: %s", command)
    # NOTE(review): quote() wraps the whole line in a single shell-safe token,
    # so the following split() yields a one-element list. With shell=True the
    # list's first element is handed to the shell on POSIX, effectively running
    # the original command string; Windows semantics may differ — confirm intended.
    command = quote(command)
    split_command = split(command)

    try:
        return run(split_command, capture_output=True, check=False, shell=shell_mode)
    except FileNotFoundError:
        # Missing executable is treated as a soft failure; caller receives None.
        pass


def is_env_flag_enabled(env_flag_key: str) -> bool:
    """True when the env var holds a truthy flag: 'true', '1' or 'y' (case-insensitive)."""
    from os import getenv

    flag_value = getenv(env_flag_key, "false")
    return flag_value.lower() in ["true", "1", "y"]


def is_enabled_str(value: Optional[str]) -> bool:
    """
    Converts an intended property str value (such as from k8s enum) to bool.

    Only the string 'enabled' (any casing) is True; non-strings are False.
    """
    if not isinstance(value, str):
        return False
    return value.lower() == "enabled"


def should_continue_prompt(confirm_yes: Optional[bool] = None, context: str = "Deletion") -> bool:
    """
    Ask the user to confirm continuation unless pre-approved.

    :param confirm_yes: when truthy, skip the interactive prompt entirely.
    :param context: operation name used in the cancellation warning.
    :return: True to continue, False when the user declines.
    """
    from rich.prompt import Confirm

    if not confirm_yes and not Confirm.ask("Continue?"):
        # Lazy %-style args: formatting is deferred until the record is emitted.
        logger.warning("%s cancelled.", context)
        return False

    return True


def insert_newlines(s: str, every: int = 79) -> str:
    """Break s into newline-joined segments of at most `every` characters."""
    segments = (s[start : start + every] for start in range(0, len(s), every))
    return "\n".join(segments)


def parse_dot_notation(pairs: List[str]) -> dict:
    """
    ["a.b=value1", "a.c.d=value2"] -> {"a": {"b": "value1", "c": {"d": "value2"}}}

    Tokens without '=' are ignored; path segments and values are stripped.
    """
    tree: dict = {}
    for pair in pairs:
        path, eq, value = pair.partition("=")
        if not eq:
            continue
        *parents, leaf = path.strip().split(".")
        node = tree
        for segment in parents:
            # Overwrite any non-dict intermediate so the path can be descended.
            if not isinstance(node.get(segment), dict):
                node[segment] = {}
            node = node[segment]
        node[leaf] = value.strip()
    return tree


def upsert_by_discriminator(initial: List[dict], disc_key: str, config: dict) -> List[dict]:
    """
    Replace the first entry whose disc_key value matches config's, or append.

    Mutates `initial` in place and returns it.
    """
    target = config.get(disc_key)
    for index, entry in enumerate(initial):
        if entry.get(disc_key) == target:
            initial[index] = config
            break
    else:
        initial.append(config)
    return initial


def chunk_list(data: list, chunk_len: int, data_size: int = 1024, size_unit: str = "kb") -> List[list]:
    """
    Split data into chunks bounded by item count and serialized size.

    :param data: items to chunk.
    :param chunk_len: max items per chunk.
    :param data_size: max JSON-serialized chunk size, in size_unit.
    :param size_unit: 'kb' (default) or 'mb'.
    :return: list of chunks; a single item exceeding data_size still gets its
        own one-item chunk (previously an empty chunk could be emitted).
    """
    if size_unit.lower() == "mb":
        data_size *= 1024  # normalize the limit to kb

    result = []
    current_chunk = []

    for item in data:
        current_chunk.append(item)

        # Re-serialize the whole chunk so the measurement includes JSON overhead.
        serialized_size = len(json.dumps(current_chunk).encode("utf-8")) / 1024  # convert bytes to kb

        if len(current_chunk) > chunk_len or serialized_size > data_size:
            current_chunk.pop()
            # Guard: if the just-added item alone breaches the size limit,
            # the popped chunk is empty — don't emit an empty chunk.
            if current_chunk:
                result.append(current_chunk)
            current_chunk = [item]

    if current_chunk:
        result.append(current_chunk)

    return result


def to_safe_filename(name: str) -> str:
    """Replace chars other than word chars, '-' and '.' with '_', then trim edge dots."""
    sanitized = re.sub(r"[^\w\-.]", "_", name)
    return sanitized.strip(".")
