#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Implementation of a progress bar that is displayed while a query is running."""
import abc
from dataclasses import dataclass
import time
import sys
import typing
from types import TracebackType
from typing import Iterable, Any

from pyspark.sql.connect.proto import ExecutePlanResponse

try:
    from IPython.utils.terminal import get_terminal_size
except ImportError:

    def get_terminal_size(defaultx: Any = None, defaulty: Any = None) -> Any:
        """Fallback used when IPython cannot be imported.

        Both arguments are accepted but ignored; a fixed 80x25 size is always
        reported as a ``(columns, lines)`` pair.
        """
        fallback_columns, fallback_lines = 80, 25
        return fallback_columns, fallback_lines


from pyspark.sql.connect.shell import progress_bar_enabled


@dataclass(init=True, repr=True, eq=True)
class StageInfo:
    """Snapshot of the execution progress of a single stage.

    Instances are built from the server's execution-progress messages by ``from_proto``.
    """

    # Identifier of the stage.
    stage_id: int
    # Total number of tasks in the stage.
    num_tasks: int
    # Number of tasks of the stage reported as completed.
    num_completed_tasks: int
    # Input bytes read by the stage (``input_bytes_read`` in the protobuf message).
    num_bytes_read: int
    # Whether the stage is reported as done.
    done: bool


class ProgressHandler(abc.ABC):
    """Callback interface for observing query execution progress.

    ``Progress`` invokes every registered handler on each progress update, and once
    more with ``done=True`` from ``Progress.finish``.
    """

    @abc.abstractmethod
    def __call__(
        self,
        stages: typing.Optional[Iterable[StageInfo]],
        inflight_tasks: int,
        operation_id: typing.Optional[str],
        done: bool,
    ) -> None:
        """Receive one progress notification.

        Parameters
        ----------
        stages : iterable of StageInfo, optional
          The most recent per-stage progress snapshot.
        inflight_tasks : int
          The number of tasks currently reported as running.
        operation_id : str, optional
          Identifier of the tracked operation, if known.
        done : bool
          True for the final notification of the operation.
        """


def from_proto(
    proto: ExecutePlanResponse.ExecutionProgress,
) -> typing.Tuple[Iterable[StageInfo], int]:
    """Convert an ``ExecutionProgress`` message into progress-bar inputs.

    Returns
    -------
    tuple
        A list holding one ``StageInfo`` per entry of ``proto.stages`` (same order),
        and the ``num_inflight_tasks`` value of the message.
    """
    stage_infos = [
        StageInfo(
            stage_id=s.stage_id,
            num_tasks=s.num_tasks,
            num_completed_tasks=s.num_completed_tasks,
            num_bytes_read=s.input_bytes_read,
            done=s.done,
        )
        for s in proto.stages
    ]
    return stage_infos, proto.num_inflight_tasks


class Progress:
    """This is a small helper class to visualize a textual progress bar.
    The interface is very simple and assumes that nothing else prints to the
    standard output."""

    # Unit sizes used by ``_bytes_to_string``, largest first; entry ``i`` pairs with
    # ``SI_BYTE_SUFFIXES[i]``. Despite the ``SI`` name, these are binary (powers of 1024) units.
    SI_BYTE_SIZES = (1 << 60, 1 << 50, 1 << 40, 1 << 30, 1 << 20, 1 << 10, 1)
    SI_BYTE_SUFFIXES = ("EiB", "PiB", "TiB", "GiB", "MiB", "KiB", "B")

    def __init__(
        self,
        char: str = "*",
        min_width: int = 80,
        output: typing.Optional[typing.IO] = None,
        enabled: bool = False,
        handlers: typing.Optional[Iterable[ProgressHandler]] = None,
        operation_id: typing.Optional[str] = None,
    ) -> None:
        """
        Constructs a new Progress bar. The progress bar is typically used in
        the blocking query execution path to process the execution progress
        messages from the server.

        Parameters
        ----------
        char : str
          The character used to draw the completed part of the bar.
        min_width : int
          The minimum width of the progress bar.
        output : file, optional
          The output device to write the progress bar to. Defaults to ``sys.stdout``
          as it is when the progress bar is constructed.
        enabled : bool
          Whether the progress bar should be printed. Printing is also enabled when
          ``progress_bar_enabled()`` returns true.
        handlers : iterable of ProgressHandler, optional
          Handlers called on every progress update and once more from ``finish``.
          The collection is iterated on each notification, so pass a re-iterable
          container such as a list. Defaults to no handlers.
        operation_id : str, optional
          Identifier of the tracked operation, forwarded to the handlers. If empty, it
          is taken from the first ``update_ticks`` call that provides one.
        """
        self._ticks: typing.Optional[int] = None
        self._tick: typing.Optional[int] = None
        self._min_width = min_width
        self._char = char
        # NOTE(review): max(min(min_width, x), min_width) always evaluates to min_width, so the
        # terminal width does not currently influence the bar width. Kept unchanged to preserve
        # the rendered output; confirm the intended clamping before changing it.
        x, _ = get_terminal_size()
        self._width = max(min(min_width, x), self._min_width)
        self._max_printed = 0
        self._started = time.time()
        self._enabled = enabled or progress_bar_enabled()
        self._bytes_read = 0
        # Resolve the default stream here instead of in the signature, so a ``sys.stdout``
        # replaced after import (e.g. by output capturing) is honored.
        self._out = sys.stdout if output is None else output
        self._running = 0
        # A ``None`` default avoids one ``[]`` being shared by every instance. The caller's
        # container is kept by reference, not copied.
        self._handlers: Iterable[ProgressHandler] = [] if handlers is None else handlers
        self._stages: Iterable[StageInfo] = []
        self._operation_id = operation_id

    def _notify(self, done: bool = False) -> None:
        """Forward the current progress snapshot to every registered handler."""
        for handler in self._handlers:
            handler(
                stages=self._stages,
                inflight_tasks=self._running,
                operation_id=self._operation_id,
                done=done,
            )

    def __enter__(self) -> "Progress":
        return self

    def __exit__(
        self,
        exc_type: typing.Optional[typing.Type[BaseException]],
        exception: typing.Optional[BaseException],
        exc_tb: typing.Optional[TracebackType],
    ) -> typing.Any:
        """Finish the progress bar; never suppresses an exception raised in the block."""
        self.finish()
        return False

    def update_ticks(
        self,
        stages: Iterable[StageInfo],
        inflight_tasks: int,
        operation_id: typing.Optional[str] = None,
    ) -> None:
        """This method is called from the execution to update the progress bar with a new total
        tick counter and the current position. This is necessary in case new stages get added with
        new tasks and so the total task number will be updated as well.

        Updates with a total task count of zero are ignored.

        Parameters
        ----------
        stages : iterable of StageInfo
          The StageInfo objects reporting progress in each stage.
        inflight_tasks : int
          The number of tasks that are currently running.
        operation_id : str, optional
          Operation identifier; only adopted if this progress bar does not have one yet.
        """
        if not self._operation_id:
            self._operation_id = operation_id

        # Materialize once: the stages are traversed several times below and handed to the
        # handlers, which would otherwise see an exhausted iterator if a generator was passed.
        stages = list(stages)
        total_tasks = sum(stage.num_tasks for stage in stages)
        if total_tasks <= 0:
            return

        completed_tasks = sum(stage.num_completed_tasks for stage in stages)
        self._ticks = total_tasks
        self._tick = completed_tasks
        self._bytes_read = sum(stage.num_bytes_read for stage in stages)
        # Record the running count and stages *before* rendering, so the printed line reflects
        # this update rather than the previous one.
        self._running = inflight_tasks
        self._stages = stages
        if completed_tasks >= 0:
            self.output()
        self._notify(False)

    def finish(self) -> None:
        """Notify the handlers that processing is done and, if printing is enabled, clear the
        last printed line. Called when the processing is done."""
        self._notify(True)
        if self._enabled:
            # Overwrite the widest line printed so far with spaces, then return the cursor to
            # the start of the line.
            print("\r" + " " * self._max_printed, end="", flush=True, file=self._out)
            print("\r", end="", flush=True, file=self._out)

    def output(self) -> None:
        """Writes the progress bar out, overwriting the current line.

        Does nothing unless printing is enabled and a non-zero total task count has been
        recorded by ``update_ticks``.
        """
        if not self._enabled or self._tick is None or not self._ticks:
            return
        fraction = self._tick / float(self._ticks)
        filled = int(fraction * self._width)
        bar = self._char * filled + "-" * (self._width - filled)
        percent_complete = fraction * 100
        elapsed = int(time.time() - self._started)
        scanned = self._bytes_to_string(self._bytes_read)
        running = self._running
        buffer = (
            f"\r[{bar}] {percent_complete:.2f}% Complete "
            f"({running} Tasks running, {elapsed}s, Scanned {scanned})"
        )
        # Remember the widest line so ``finish`` can blank it out completely.
        self._max_printed = max(len(buffer), self._max_printed)
        print(buffer, end="", flush=True, file=self._out)

    @staticmethod
    def _bytes_to_string(size: int) -> str:
        """Helper method to convert a numeric bytes value into a human-readable representation.

        A unit is used only once ``size`` reaches twice that unit, so 1.5 GiB is rendered as
        ``"1536.0 MiB"``. Values below 2 KiB are reported in plain bytes.
        """
        i = 0
        while i < len(Progress.SI_BYTE_SIZES) - 1 and size < 2 * Progress.SI_BYTE_SIZES[i]:
            i += 1
        result = float(size) / Progress.SI_BYTE_SIZES[i]
        return f"{result:.1f} {Progress.SI_BYTE_SUFFIXES[i]}"
