python/pyspark/sql/connect/shell/progress.py (136 lines of code) (raw):
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Implementation of a progress bar that is displayed while a query is running."""
import abc
from dataclasses import dataclass
import time
import sys
import typing
from types import TracebackType
from typing import Iterable, Any
from pyspark.sql.connect.proto import ExecutePlanResponse
try:
    from IPython.utils.terminal import get_terminal_size
except ImportError:

    def get_terminal_size(defaultx: Any = 80, defaulty: Any = 25) -> Any:
        """Return the terminal size as a ``(columns, lines)`` pair without IPython.

        Delegates to :func:`shutil.get_terminal_size`, which honors the
        ``COLUMNS``/``LINES`` environment variables, then queries the terminal
        attached to stdout, and falls back to ``(defaultx, defaulty)`` when the
        size cannot be determined.
        """
        # Local import keeps the module's import block unchanged.
        import shutil

        return shutil.get_terminal_size((defaultx, defaulty))
from pyspark.sql.connect.shell import progress_bar_enabled
@dataclass()
class StageInfo:
    """Progress snapshot of a single stage, as reported by the server.

    Instances are built from the stage entries of an execution progress
    message by ``from_proto``.
    """

    # Identifier of the stage.
    stage_id: int
    # Total number of tasks in the stage.
    num_tasks: int
    # Number of tasks of the stage that have completed so far.
    num_completed_tasks: int
    # Input bytes read by the stage (taken from the proto's input_bytes_read).
    num_bytes_read: int
    # Whether the server reports the stage as finished.
    done: bool
class ProgressHandler(abc.ABC):
    """Abstract callback that receives progress updates for a running query.

    The progress bar calls handlers with keyword arguments only, both while
    the operation is running (``done=False``) and once it has finished
    (``done=True``).
    """

    @abc.abstractmethod
    def __call__(
        self,
        stages: typing.Optional[Iterable[StageInfo]],
        inflight_tasks: int,
        operation_id: typing.Optional[str],
        done: bool,
    ) -> None:
        """Handle one progress update.

        Parameters
        ----------
        stages : iterable of StageInfo, optional
            Per-stage progress information.
        inflight_tasks : int
            Number of tasks currently running.
        operation_id : str, optional
            Identifier of the operation the update belongs to.
        done : bool
            True when this is the final notification for the operation.
        """
def from_proto(
    proto: ExecutePlanResponse.ExecutionProgress,
) -> typing.Tuple[Iterable[StageInfo], int]:
    """Translate an ``ExecutionProgress`` message into Python values.

    Parameters
    ----------
    proto : ExecutePlanResponse.ExecutionProgress
        The progress message received from the server.

    Returns
    -------
    tuple
        ``(stages, inflight_tasks)`` where ``stages`` is a list holding one
        :class:`StageInfo` per entry of ``proto.stages`` (in the same order)
        and ``inflight_tasks`` is ``proto.num_inflight_tasks``.
    """
    stage_infos = [
        StageInfo(
            stage_id=entry.stage_id,
            num_tasks=entry.num_tasks,
            num_completed_tasks=entry.num_completed_tasks,
            num_bytes_read=entry.input_bytes_read,
            done=entry.done,
        )
        for entry in proto.stages
    ]
    return stage_infos, proto.num_inflight_tasks
class Progress:
    """This is a small helper class to visualize a textual progress bar.

    The interface is very simple and assumes that nothing else prints to the
    standard output."""

    # Binary (powers of 1024) unit sizes, largest first, used by
    # _bytes_to_string; SI_BYTE_SUFFIXES holds the matching unit labels.
    SI_BYTE_SIZES = (1 << 60, 1 << 50, 1 << 40, 1 << 30, 1 << 20, 1 << 10, 1)
    SI_BYTE_SUFFIXES = ("EiB", "PiB", "TiB", "GiB", "MiB", "KiB", "B")

    def __init__(
        self,
        char: str = "*",
        min_width: int = 80,
        output: typing.Optional[typing.IO] = None,
        enabled: bool = False,
        handlers: typing.Optional[Iterable[ProgressHandler]] = None,
        operation_id: typing.Optional[str] = None,
    ) -> None:
        """
        Constructs a new Progress bar. The progress bar is typically used in
        the blocking query execution path to process the execution progress
        methods from the server.

        Parameters
        ----------
        char : str
            The default character to be used for printing the bar.
        min_width : int
            The width of the bar in characters.
        output : file, optional
            The output device to write the progress bar to. When omitted, the
            bar is written to whatever ``sys.stdout`` is at the time of printing.
        enabled : bool
            Whether the progress bar printing should be enabled or not. Printing
            is also enabled when ``progress_bar_enabled()`` returns a truthy value.
        handlers : iterable of ProgressHandler, optional
            Handlers that will be called on every progress update and once more
            when the progress finishes. The iterable is copied into a list.
        operation_id : str, optional
            Identifier of the operation, forwarded to the handlers.
        """
        self._ticks: typing.Optional[int] = None
        self._tick: typing.Optional[int] = None
        x, _ = get_terminal_size()
        self._min_width = min_width
        self._char = char
        # NOTE(review): min(min_width, x) is clamped back up to min_width, so the
        # bar is always exactly ``min_width`` wide and the terminal width has no
        # effect. Kept as-is to preserve the rendered output; confirm intent.
        self._width = max(min(min_width, x), self._min_width)
        self._max_printed = 0
        self._started = time.time()
        self._enabled = enabled or progress_bar_enabled()
        self._bytes_read = 0
        # ``print(file=None)`` writes to the *current* sys.stdout, so stdout
        # redirection after import is honored (a def-time default is not).
        self._out = output
        self._running = 0
        # Handlers are iterated on every notification; materialize them once so
        # a one-shot iterable (e.g. a generator) is not exhausted after the
        # first call, and so no mutable default is shared between instances.
        self._handlers: typing.List[ProgressHandler] = (
            list(handlers) if handlers is not None else []
        )
        self._stages: Iterable[StageInfo] = []
        self._operation_id = operation_id

    def _notify(self, done: bool = False) -> None:
        """Call every registered handler with the current progress state."""
        for handler in self._handlers:
            handler(
                stages=self._stages,
                inflight_tasks=self._running,
                operation_id=self._operation_id,
                done=done,
            )

    def __enter__(self) -> "Progress":
        return self

    def __exit__(
        self,
        exc_type: typing.Optional[typing.Type[BaseException]],
        exception: typing.Optional[BaseException],
        exc_tb: typing.Optional[TracebackType],
    ) -> typing.Any:
        """Finish the progress bar; never suppresses an in-flight exception."""
        self.finish()
        return False

    def update_ticks(
        self,
        stages: Iterable[StageInfo],
        inflight_tasks: int,
        operation_id: typing.Optional[str] = None,
    ) -> None:
        """This method is called from the execution to update the progress bar with a new total
        tick counter and the current position. This is necessary in case new stages get added with
        new tasks and so the total task number will be updated as well.

        Parameters
        ----------
        stages : list
            A list of StageInfo objects reporting progress in each stage.
        inflight_tasks : int
            The number of tasks that are currently running.
        operation_id : str, optional
            Operation identifier; only adopted if none has been recorded yet.
        """
        # The stages are summed several times and handed to handlers, so a
        # one-shot iterable must be materialized first.
        stages = list(stages)
        if not self._operation_id:
            self._operation_id = operation_id

        # Record the current state before rendering so output() shows this
        # update's running-task count instead of the previous one.
        self._running = inflight_tasks
        self._stages = stages

        total_tasks = sum(stage.num_tasks for stage in stages)
        if total_tasks > 0:
            self._ticks = total_tasks
            self._tick = sum(stage.num_completed_tasks for stage in stages)
            self._bytes_read = sum(stage.num_bytes_read for stage in stages)
        if self._tick is not None and self._tick >= 0:
            self.output()
        self._notify(False)

    def finish(self) -> None:
        """Clear the last line. Called when the processing is done."""
        self._notify(True)
        if self._enabled:
            # Overwrite the longest line printed so far with spaces, then return
            # the cursor to the start of the line.
            print("\r" + " " * self._max_printed, end="", flush=True, file=self._out)
            print("\r", end="", flush=True, file=self._out)

    def output(self) -> None:
        """Writes the progress bar out (only when enabled and the totals are known)."""
        if not self._enabled or self._tick is None or self._ticks is None:
            return
        fraction = self._tick / float(self._ticks)
        # Clamp so an over-reported completion count cannot overflow the bar.
        filled = max(0, min(self._width, int(fraction * self._width)))
        bar = self._char * filled + "-" * (self._width - filled)
        percent_complete = fraction * 100
        elapsed = int(time.time() - self._started)
        scanned = self._bytes_to_string(self._bytes_read)
        running = self._running
        buffer = (
            f"\r[{bar}] {percent_complete:.2f}% Complete "
            f"({running} Tasks running, {elapsed}s, Scanned {scanned})"
        )
        # Remember the widest line so finish() can blank it out completely.
        self._max_printed = max(len(buffer), self._max_printed)
        print(buffer, end="", flush=True, file=self._out)

    @staticmethod
    def _bytes_to_string(size: int) -> str:
        """Helper method to convert a numeric bytes value into a human-readable representation.

        A unit is only used once the value reaches twice that unit, e.g.
        1500 -> "1500.0 B" but 2048 -> "2.0 KiB".
        """
        i = 0
        while i < len(Progress.SI_BYTE_SIZES) - 1 and size < 2 * Progress.SI_BYTE_SIZES[i]:
            i += 1
        result = float(size) / Progress.SI_BYTE_SIZES[i]
        return f"{result:.1f} {Progress.SI_BYTE_SUFFIXES[i]}"