python/pyspark/sql/connect/shell/progress.py (136 lines of code) (raw):

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Implementation of a progress bar that is displayed while a query is running."""
import abc
from dataclasses import dataclass
import time
import sys
import typing
from types import TracebackType
from typing import Iterable, Any

from pyspark.sql.connect.proto import ExecutePlanResponse

try:
    from IPython.utils.terminal import get_terminal_size
except ImportError:

    def get_terminal_size(defaultx: Any = None, defaulty: Any = None) -> Any:
        """Fallback used when IPython is not installed: report a fixed 80x25 terminal."""
        return (80, 25)


from pyspark.sql.connect.shell import progress_bar_enabled


@dataclass
class StageInfo:
    """Progress snapshot of a single stage, as built by :func:`from_proto`."""

    # Identifier of the stage.
    stage_id: int
    # Total number of tasks in the stage.
    num_tasks: int
    # Number of tasks of the stage that have completed.
    num_completed_tasks: int
    # Number of input bytes read by the stage.
    num_bytes_read: int
    # Whether the stage is reported as done.
    done: bool


class ProgressHandler(abc.ABC):
    """Interface of the callbacks that :class:`Progress` invokes on every update."""

    @abc.abstractmethod
    def __call__(
        self,
        stages: typing.Optional[Iterable[StageInfo]],
        inflight_tasks: int,
        operation_id: typing.Optional[str],
        done: bool,
    ) -> None:
        """
        Receive a progress notification.

        Parameters
        ----------
        stages : iterable of StageInfo, optional
            The stages known to the progress bar at the time of the call.
        inflight_tasks : int
            The number of tasks that are currently running.
        operation_id : str, optional
            The operation id the progress bar is tracking, if any.
        done : bool
            True when the notification is sent from ``Progress.finish``.
        """
        pass


def from_proto(
    proto: ExecutePlanResponse.ExecutionProgress,
) -> typing.Tuple[Iterable[StageInfo], int]:
    """
    Convert an execution progress message into plain Python values.

    Parameters
    ----------
    proto : ExecutePlanResponse.ExecutionProgress
        The progress message to convert.

    Returns
    -------
    tuple
        A list with one ``StageInfo`` per entry of ``proto.stages`` and the
        value of ``proto.num_inflight_tasks``.
    """
    stages = [
        StageInfo(
            stage_id=stage.stage_id,
            num_tasks=stage.num_tasks,
            num_completed_tasks=stage.num_completed_tasks,
            num_bytes_read=stage.input_bytes_read,
            done=stage.done,
        )
        for stage in proto.stages
    ]
    return (stages, proto.num_inflight_tasks)


class Progress:
    """This is a small helper class to visualize a textual progress bar.

    The interface is very simple and assumes that nothing else prints to the
    standard output."""

    # Unit sizes, largest first, paired index-by-index with SI_BYTE_SUFFIXES.
    SI_BYTE_SIZES = (1 << 60, 1 << 50, 1 << 40, 1 << 30, 1 << 20, 1 << 10, 1)
    SI_BYTE_SUFFIXES = ("EiB", "PiB", "TiB", "GiB", "MiB", "KiB", "B")

    def __init__(
        self,
        char: str = "*",
        min_width: int = 80,
        output: typing.IO = sys.stdout,
        enabled: bool = False,
        handlers: Iterable[ProgressHandler] = (),
        operation_id: typing.Optional[str] = None,
    ) -> None:
        """
        Constructs a new Progress bar. The progress bar is typically used in
        the blocking query execution path to process the execution progress
        messages from the server.

        Parameters
        ----------
        char : str
            The default character to be used for printing the bar.
        min_width : numeric
            The minimum width of the progress bar.
        output : file
            The output device to write the progress bar to.
        enabled : bool
            Whether the progress bar printing should be enabled or not. Printing
            is also enabled when ``progress_bar_enabled()`` returns a truthy value.
        handlers : iterable of ProgressHandler
            Handlers that will be called when the progress bar is updated. The
            iterable is stored as-is (not copied) and is iterated on every
            notification, so it must be re-iterable.
        operation_id : str, optional
            The operation id to report to the handlers. When it is missing or
            empty, the first id passed to ``update_ticks`` is used instead.
        """
        self._ticks: typing.Optional[int] = None
        self._tick: typing.Optional[int] = None
        # Only the column count is used; the row count is discarded.
        x, _ = get_terminal_size()
        self._min_width = min_width
        self._char = char
        # NOTE(review): min(min_width, x) can never exceed min_width, so this
        # expression always evaluates to min_width and the terminal width has
        # no effect. Kept as-is to preserve the rendered output.
        self._width = max(min(min_width, x), self._min_width)
        self._max_printed = 0
        self._started = time.time()
        self._enabled = enabled or progress_bar_enabled()
        self._bytes_read = 0
        self._out = output
        self._running = 0
        self._handlers = handlers
        self._stages: Iterable[StageInfo] = []
        self._operation_id = operation_id

    def _notify(self, done: bool = False) -> None:
        """Call every registered handler with the current state of the progress bar."""
        for handler in self._handlers:
            handler(
                stages=self._stages,
                inflight_tasks=self._running,
                operation_id=self._operation_id,
                done=done,
            )

    def __enter__(self) -> "Progress":
        return self

    def __exit__(
        self,
        exc_type: typing.Optional[typing.Type[BaseException]],
        exception: typing.Optional[BaseException],
        exc_tb: typing.Optional[TracebackType],
    ) -> typing.Any:
        """Finish the progress bar; returns False so exceptions are not suppressed."""
        self.finish()
        return False

    def update_ticks(
        self,
        stages: Iterable[StageInfo],
        inflight_tasks: int,
        operation_id: typing.Optional[str] = None,
    ) -> None:
        """This method is called from the execution to update the progress bar with a new total
        tick counter and the current position. This is necessary in case new stages get added with
        new tasks and so the total task number will be updated as well.

        Nothing is recorded, printed or notified while the total number of
        tasks over all stages is zero.

        Parameters
        ----------
        stages : list
            A list of StageInfo objects reporting progress in each stage.
        inflight_tasks : int
            The number of tasks that are currently running.
        operation_id : str, optional
            Adopted as the operation id of this progress bar if none was set yet.
        """
        if not self._operation_id:
            self._operation_id = operation_id

        # Materialize once: the stages are traversed several times below and
        # again by every handler, which a one-shot iterable would not survive.
        stages = list(stages)
        total_tasks = sum(stage.num_tasks for stage in stages)
        completed_tasks = sum(stage.num_completed_tasks for stage in stages)
        if total_tasks > 0:
            self._ticks = total_tasks
            self._tick = completed_tasks
            self._bytes_read = sum(stage.num_bytes_read for stage in stages)
            # Record the in-flight count before printing so that the line
            # shows the current value rather than the one of the last update.
            self._running = inflight_tasks
            self._stages = stages
            if self._tick >= 0:
                self.output()
            self._notify(False)

    def finish(self) -> None:
        """Clear the last line. Called when the processing is done."""
        self._notify(True)
        if self._enabled:
            # Overwrite the longest line printed so far with blanks, then move
            # the cursor back to the start of the line.
            print("\r" + " " * self._max_printed, end="", flush=True, file=self._out)
            print("\r", end="", flush=True, file=self._out)

    def output(self) -> None:
        """Writes the progress bar out.

        Does nothing when printing is disabled or when no positive total tick
        count has been recorded yet.
        """
        if not self._enabled or self._tick is None or self._ticks is None:
            return
        if self._ticks <= 0:
            # Avoid a division by zero when called without any known task.
            return

        fraction = self._tick / float(self._ticks)
        # Clamp so that the bar always spans exactly self._width characters,
        # even if more completed tasks than total tasks are reported.
        val = max(0, min(self._width, int(fraction * self._width)))
        bar = self._char * val + "-" * (self._width - val)
        percent_complete = fraction * 100
        elapsed = int(time.time() - self._started)
        scanned = self._bytes_to_string(self._bytes_read)
        running = self._running
        buffer = (
            f"\r[{bar}] {percent_complete:.2f}% Complete "
            f"({running} Tasks running, {elapsed}s, Scanned {scanned})"
        )
        # Remember the longest line so that finish() can blank it out fully.
        self._max_printed = max(len(buffer), self._max_printed)
        print(buffer, end="", flush=True, file=self._out)

    @staticmethod
    def _bytes_to_string(size: int) -> str:
        """Helper method to convert a numeric bytes value into a human-readable representation.

        The largest unit of which ``size`` is at least twice as large is used,
        falling back to plain bytes; e.g. 2047 is rendered as ``2047.0 B`` and
        2048 as ``2.0 KiB``.
        """
        i = 0
        while i < len(Progress.SI_BYTE_SIZES) - 1 and size < 2 * Progress.SI_BYTE_SIZES[i]:
            i += 1
        result = float(size) / Progress.SI_BYTE_SIZES[i]
        return f"{result:.1f} {Progress.SI_BYTE_SUFFIXES[i]}"