"""Apache Spark on Amazon Athena Module."""

from __future__ import annotations

import logging
import time
from typing import TYPE_CHECKING, Any, Dict, cast

import boto3

from awswrangler import _utils, exceptions

_logger: logging.Logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from mypy_boto3_athena.type_defs import (
        EngineConfigurationTypeDef,
        GetCalculationExecutionResponseTypeDef,
        GetCalculationExecutionStatusResponseTypeDef,
        GetSessionStatusResponseTypeDef,
    )

_SESSION_FINAL_STATES: list[str] = ["IDLE", "TERMINATED", "DEGRADED", "FAILED"]
_CALCULATION_EXECUTION_FINAL_STATES: list[str] = ["COMPLETED", "FAILED", "CANCELED"]
_SESSION_WAIT_POLLING_DELAY: float = 5.0  # SECONDS
_CALCULATION_EXECUTION_WAIT_POLLING_DELAY: float = 5.0  # SECONDS


def _wait_session(
    session_id: str,
    boto3_session: boto3.Session | None = None,
    athena_session_wait_polling_delay: float = _SESSION_WAIT_POLLING_DELAY,
) -> "GetSessionStatusResponseTypeDef":
    client_athena = _utils.client(service_name="athena", session=boto3_session)

    response: "GetSessionStatusResponseTypeDef" = client_athena.get_session_status(SessionId=session_id)
    state: str = response["Status"]["State"]

    while state not in _SESSION_FINAL_STATES:
        time.sleep(athena_session_wait_polling_delay)
        response = client_athena.get_session_status(SessionId=session_id)
        state = response["Status"]["State"]
    _logger.debug("Session state: %s", state)
    _logger.debug("Session state change reason: %s", response["Status"].get("StateChangeReason"))
    if state in ["FAILED", "DEGRADED", "TERMINATED"]:
        raise exceptions.SessionFailed(response["Status"].get("StateChangeReason"))
    return response


def _wait_calculation_execution(
    calculation_execution_id: str,
    boto3_session: boto3.Session | None = None,
    athena_calculation_execution_wait_polling_delay: float = _CALCULATION_EXECUTION_WAIT_POLLING_DELAY,
) -> "GetCalculationExecutionStatusResponseTypeDef":
    client_athena = _utils.client(service_name="athena", session=boto3_session)

    response: "GetCalculationExecutionStatusResponseTypeDef" = client_athena.get_calculation_execution_status(
        CalculationExecutionId=calculation_execution_id
    )
    state: str = response["Status"]["State"]

    while state not in _CALCULATION_EXECUTION_FINAL_STATES:
        time.sleep(athena_calculation_execution_wait_polling_delay)
        response = client_athena.get_calculation_execution_status(CalculationExecutionId=calculation_execution_id)
        state = response["Status"]["State"]
    _logger.debug("Calculation execution state: %s", state)
    _logger.debug("Calculation execution state change reason: %s", response["Status"].get("StateChangeReason"))
    if state in ["CANCELED", "FAILED"]:
        raise exceptions.CalculationFailed(response["Status"].get("StateChangeReason"))
    return response


def _get_calculation_execution_results(
    calculation_execution_id: str,
    boto3_session: boto3.Session | None = None,
) -> dict[str, Any]:
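    """Block until the calculation finishes, then return the full ``get_calculation_execution`` response."""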
    client_athena = _utils.client(service_name="athena", session=boto3_session)

    _wait_calculation_execution(
        calculation_execution_id=calculation_execution_id,
        boto3_session=boto3_session,
    )

    response: "GetCalculationExecutionResponseTypeDef" = client_athena.get_calculation_execution(
        CalculationExecutionId=calculation_execution_id,
    )
    return cast(Dict[str, Any], response)


def create_spark_session(
    workgroup: str,
    coordinator_dpu_size: int = 1,
    max_concurrent_dpus: int = 5,
    default_executor_dpu_size: int = 1,
    additional_configs: dict[str, Any] | None = None,
    spark_properties: dict[str, Any] | None = None,
    notebook_version: str | None = None,
    idle_timeout: int = 15,
    boto3_session: boto3.Session | None = None,
) -> str:
    """
    Create a Spark session and wait until it is ready to accept calculations.

    Parameters
    ----------
    workgroup
        Athena workgroup name. Must be Spark-enabled.
    coordinator_dpu_size
        The number of DPUs to use for the coordinator. A coordinator is a special executor that orchestrates
        processing work and manages other executors in a notebook session. The default is 1.
    max_concurrent_dpus
        The maximum number of DPUs that can run concurrently. The default is 5.
    default_executor_dpu_size
        The default number of DPUs to use for executors. The default is 1.
    additional_configs
        Contains additional engine parameter mappings in the form of key-value pairs.
    spark_properties
        Contains SparkProperties in the form of key-value pairs. Specifies custom JAR files and Spark properties
        for use cases like cluster encryption, table formats, and general Spark tuning.
    notebook_version
        The notebook version. This value is supplied automatically for notebook sessions in the Athena console
        and is not required for programmatic session access. The only valid notebook version is
        Athena notebook version 1. If you specify a value for ``NotebookVersion``, you must also specify
        a value for ``NotebookId``.
    idle_timeout
        The idle timeout in minutes for the session. The default is 15.
    boto3_session
        The default boto3 session is used if **boto3_session** receives ``None``.

    Returns
    -------
        Session ID

    Examples
    --------
    >>> import awswrangler as wr
    >>> session_id = wr.athena.create_spark_session(workgroup="...", max_concurrent_dpus=10)
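
    A sketch passing custom Spark properties; the property name and value shown are illustrative,
    not defaults:

    >>> session_id = wr.athena.create_spark_session(
    ...     workgroup="...",
    ...     spark_properties={"spark.sql.shuffle.partitions": "200"},
    ... )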

    """
    client_athena = _utils.client(service_name="athena", session=boto3_session)
    engine_configuration: "EngineConfigurationTypeDef" = {
        "CoordinatorDpuSize": coordinator_dpu_size,
        "MaxConcurrentDpus": max_concurrent_dpus,
        "DefaultExecutorDpuSize": default_executor_dpu_size,
    }
    if additional_configs:
        engine_configuration["AdditionalConfigs"] = additional_configs
    if spark_properties:
        engine_configuration["SparkProperties"] = spark_properties
    kwargs: dict[str, Any] = {"SessionIdleTimeoutInMinutes": idle_timeout}
    if notebook_version:
        kwargs["NotebookVersion"] = notebook_version
    response = client_athena.start_session(
        WorkGroup=workgroup,
        EngineConfiguration=engine_configuration,
        **kwargs,
    )
    _logger.info("Session info:\n%s", response)
    session_id: str = response["SessionId"]
    # Wait for the session to reach IDLE state to be able to accept calculations
    _wait_session(
        session_id=session_id,
        boto3_session=boto3_session,
    )
    return session_id


def run_spark_calculation(
    code: str,
    workgroup: str,
    session_id: str | None = None,
    coordinator_dpu_size: int = 1,
    max_concurrent_dpus: int = 5,
    default_executor_dpu_size: int = 1,
    additional_configs: dict[str, Any] | None = None,
    spark_properties: dict[str, Any] | None = None,
    notebook_version: str | None = None,
    idle_timeout: int = 15,
    boto3_session: boto3.Session | None = None,
) -> dict[str, Any]:
    """
    Execute a Spark calculation and wait for completion.

    Parameters
    ----------
    code
        A string that contains the code for the calculation.
    workgroup
        Athena workgroup name. Must be Spark-enabled.
    session_id
        The session id. If not passed, a session will be started.
    coordinator_dpu_size
        The number of DPUs to use for the coordinator. A coordinator is a special executor that orchestrates
        processing work and manages other executors in a notebook session. The default is 1.
    max_concurrent_dpus
        The maximum number of DPUs that can run concurrently. The default is 5.
    default_executor_dpu_size
        The default number of DPUs to use for executors. The default is 1.
    additional_configs
        Contains additional engine parameter mappings in the form of key-value pairs.
    spark_properties
        Contains SparkProperties in the form of key-value pairs. Specifies custom JAR files and Spark properties
        for use cases like cluster encryption, table formats, and general Spark tuning.
    notebook_version
        The notebook version. This value is supplied automatically for notebook sessions in the Athena console
        and is not required for programmatic session access. The only valid notebook version is
        Athena notebook version 1. If you specify a value for ``NotebookVersion``, you must also specify
        a value for ``NotebookId``.
    idle_timeout
        The idle timeout in minutes for the session. The default is 15.
    boto3_session
        The default boto3 session is used if **boto3_session** receives ``None``.

    Returns
    -------
        Calculation response

    Examples
    --------
    >>> import awswrangler as wr
    >>> result = wr.athena.run_spark_calculation(
    ...     code="print(spark)",
    ...     workgroup="...",
    ... )
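
    Reusing an existing session instead of starting a new one (the ``session_id`` value is assumed
    to come from a prior ``create_spark_session`` call):

    >>> result = wr.athena.run_spark_calculation(
    ...     code="spark.sql('SHOW DATABASES').show()",
    ...     workgroup="...",
    ...     session_id=session_id,
    ... )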

    """
    client_athena = _utils.client(service_name="athena", session=boto3_session)

    if not session_id:
        session_id = create_spark_session(
            workgroup=workgroup,
            coordinator_dpu_size=coordinator_dpu_size,
            max_concurrent_dpus=max_concurrent_dpus,
            default_executor_dpu_size=default_executor_dpu_size,
            additional_configs=additional_configs,
            spark_properties=spark_properties,
            notebook_version=notebook_version,
            idle_timeout=idle_timeout,
            boto3_session=boto3_session,
        )

    response = client_athena.start_calculation_execution(
        SessionId=session_id,
        CodeBlock=code,
    )
    _logger.info("Calculation execution info:\n%s", response)

    return _get_calculation_execution_results(
        calculation_execution_id=response["CalculationExecutionId"],
        boto3_session=boto3_session,
    )
