pipeline/common/downloads.py (366 lines of code) (raw):

import gzip import io import json import os import time from contextlib import ExitStack, contextmanager from io import BufferedReader from pathlib import Path from typing import Any, Callable, Generator, Literal, Optional, Union from zipfile import ZipFile import requests from zstandard import ZstdCompressor, ZstdDecompressor from pipeline.common import format_bytes from pipeline.common.logging import get_logger logger = get_logger(__file__) class DownloadException(Exception): def __init__(self, msg: str): super().__init__(msg) def stream_download_to_file(url: str, destination: Union[str, Path]) -> None: """ Streams a download to a file, and retries several times if there are any failures. The destination file must not already exist. """ if os.path.exists(destination): raise DownloadException(f"That file already exists: {destination}") logger.info(f"Destination: {destination}") try: with open(destination, "wb") as file, DownloadChunkStreamer(url) as chunk_streamer: for chunk in chunk_streamer.download_chunks(): file.write(chunk) except DownloadException: Path(destination).unlink() raise def get_mocked_downloads_file_path(url: str) -> Optional[str]: """If there is a mocked download, get the path to the file, otherwise return None""" mocked_downloads_str = os.environ.get("MOCKED_DOWNLOADS") if not mocked_downloads_str: return None mocked_downloads = json.loads(mocked_downloads_str) if not isinstance(mocked_downloads, dict): raise DownloadException( "Expected the mocked downloads to be a json object mapping the URL to file path" ) source_file = mocked_downloads.get(url) if not source_file: print("MOCKED_DOWNLOADS:", mocked_downloads) raise DownloadException(f"Received a URL that was not in MOCKED_DOWNLOADS {url}") if not os.path.exists(source_file): raise DownloadException(f"The source file specified did not exist {source_file}") logger.info("Mocking a download.") logger.info(f" url: {url}") logger.info(f" file: {source_file}") return source_file def location_exists(location: str): """ Checks if a location (url or file path) exists. """ if location.startswith("http://") or location.startswith("https://"): response = requests.head(location, allow_redirects=True) return response.ok return os.path.exists(location) def attempt_mocked_request(url: str) -> Optional[BufferedReader]: """ If there are mocked download, use that. """ file_path = get_mocked_downloads_file_path(url) if file_path: return open(file_path, "rb") return None def get_download_size(url: str) -> int: """Get the total bytes of a file to download.""" mocked_file_path = get_mocked_downloads_file_path(url) if mocked_file_path: return os.path.getsize(mocked_file_path) response = requests.head(url, allow_redirects=True) size = response.headers.get("content-length", 0) return int(size) class RemoteDecodingLineStreamer: """ Base class to stream lines directly from a remote file. """ def __init__(self, url: str) -> None: self.url = url self.decoding_stream = None self.byte_chunk_stream = None self.line_stream = None def __enter__(self): mocked_request = attempt_mocked_request(self.url) if mocked_request: # We are in a test. logger.info(f"Using a mocked download: {self.url}") self.byte_chunk_stream = mocked_request self.decoding_stream = self.decode(self.byte_chunk_stream) else: self.byte_chunk_stream = DownloadChunkStreamer(self.url).__enter__() self.decoding_stream = self.decode(self.byte_chunk_stream) self.line_stream = io.TextIOWrapper(self.decoding_stream, encoding="utf-8") return self.line_stream def __exit__(self, _exc_type, _exc_val, _exc_tb): if self.line_stream: self.line_stream.close() if self.decoding_stream: self.decoding_stream.close() if self.byte_chunk_stream: self.byte_chunk_stream.close() def decode(self, byte_stream: Any) -> Any: # This byte stream requires no decoding, so just pass it on through. return byte_stream class RemoteGzipLineStreamer(RemoteDecodingLineStreamer): """ Stream lines directly from a remote gzip file. The line includes the newlines separator. Usage: with RemoteGzipLineStreamer(url) as lines: for line in lines: print(line) """ def decode(self, byte_stream): return gzip.GzipFile(fileobj=byte_stream) class RemoteZstdLineStreamer(RemoteDecodingLineStreamer): """ Stream lines directly from a remote zstd file. The line includes the newlines separator. Usage: with RemoteZstdLineStreamer(url) as lines: for line in lines: print(line) """ def decode(self, byte_stream): return ZstdDecompressor().stream_reader(byte_stream) class DownloadChunkStreamer(io.IOBase): """ Streams a download as chunks, and retries several times if there are any failures. This clas implements io.IOBase so it can be used as a file reader. Iterator over chunks directly: with DownloadChunkStreamer(url) as chunk_streamer: for chunk in chunk_streamer.download_chunks(): f.write(chunk) Or pass it as a file handle: with DownloadChunkStreamer(url) as f: gzip.GzipFile(fileobj=f) """ def __init__(self, url: str, total_retries=3, timeout_sec=10.0, wait_before_retry_sec=60.0): self.url = url self.response = None # How many retry attempts should there be, and how long to wait between retries. self.total_retries = total_retries self.wait_before_retry_sec = wait_before_retry_sec # How long to wait for a response to timeout? This is the time that no new data is received. self.timeout_sec = timeout_sec self.report_every = 0.05 # What percentage of the download to report updates? self.next_report_percent = self.report_every # The next report percentage. self.downloaded_bytes = 0 self.chunk_bytes = 8 * 1024 # The buffered `read` data. self.buffer = b"" # The Generator result of _download_chunks. self.chunk_iter: Optional[Generator[bytes, None, None]] = None def __enter__(self): """ On enter, kick off the download, and store the chunk iterator. This iterator handles the restarts for Requests. """ self.chunk_iter = self.download_chunks() return self def __exit__(self, exc_type, exc_val, exc_tb): self.close() def close(self): """ Close out the response, and cancel any iterators. """ if self.response: self.response.close() self.response = None self.chunk_iter = None def read(self, size=-1) -> bytes: """ This method implements the io.IOBase read method. It buffers the chunks until the `size` requirement is fulfilled. It is backed by the chunks_iter created by the download_chunks method. """ if not self.chunk_iter: # The chunk iterator was consumed. Return an empty byte object to indicate the download # is complete. return b"" if size < 0: # Load everything into the buffer, and return it. for chunk in self.chunk_iter: self.buffer += chunk result = self.buffer self.buffer = b"" return result # Load the buffer with requested amount of data to read. -1 indicates load everything. while len(self.buffer) < size: chunk = next(self.chunk_iter, None) if chunk: self.buffer += chunk else: # The stream ended. break # Return the requested read amount, and divide up the remaining buffer. result = self.buffer[:size] self.buffer = self.buffer[size:] return result def readable(self): return True def download_chunks(self) -> Generator[bytes, None, None]: """ This method is the generator that is responsible for running the request, and retrying when there is a failure. It yields the fixed size byte chunks, and exposes a generator to be consumed. This generator can be used directly in a for loop, or the entire class can be passed in as a file handle. """ next_report_percent = self.report_every total_bytes = 0 exception = None for retry in range(self.total_retries): if retry > 0: logger.error(f"Remaining retries: {self.total_retries - retry}") try: headers = {} if self.downloaded_bytes > 0: # Pick up the download from where it was before. headers = {"Range": f"bytes={self.downloaded_bytes}-"} self.response = requests.get( self.url, headers=headers, stream=True, timeout=self.timeout_sec ) self.response.raise_for_status() # Report the download size. if not total_bytes and "content-length" in self.response.headers: total_bytes = int(self.response.headers["content-length"]) logger.info(f"Download size: {total_bytes:,} bytes") for chunk in self.response.iter_content(chunk_size=self.chunk_bytes): if not chunk: continue self.downloaded_bytes += len(chunk) # Report the percentage downloaded every `report_every` percentage. if total_bytes and self.downloaded_bytes >= next_report_percent * total_bytes: logger.info( f"{self.downloaded_bytes / total_bytes * 100.0:.0f}% downloaded " f"({self.downloaded_bytes}/{total_bytes} bytes)" ) next_report_percent += self.report_every yield chunk # The download is complete. self.close() logger.info("100% downloaded - Download finished.") return except requests.exceptions.Timeout as error: logger.error(f"The connection timed out: {error}.") exception = error except requests.exceptions.RequestException as error: # The RequestException is the generic error that catches all classes of "requests" # errors. Don't attempt to be be smart about this, just attempt again until # the retries are done. logger.error(f"A download error occurred: {error}") exception = error # Close out the response on an error. It will be recreated when retrying. if self.response: self.response.close() self.response = None logger.info(f"Retrying in {self.wait_before_retry_sec} sec") time.sleep(self.wait_before_retry_sec) self.close() raise DownloadException("The download failed.") from exception def decode(self, byte_stream) -> Generator[bytes, None, None]: """Pass through the byte stream. This method can be specialized by child classes.""" return byte_stream @contextmanager def _read_lines_multiple_files( files: list[Union[str, Path]], encoding: str, path_in_archive: Optional[str], on_enter_location: Optional[Callable[[str], None]] = None, ) -> Generator[Generator[str, None, None], None, None]: """ Iterates through each line in multiple files, combining it into a single stream. """ stack = None def iter(stack: ExitStack): for file_path in files: logger.info(f"Reading lines from: {file_path}") lines = stack.enter_context( read_lines(file_path, path_in_archive, on_enter_location, encoding=encoding) ) yield from lines stack.close() try: stack = ExitStack() yield iter(stack) finally: if stack: stack.close() @contextmanager def _read_lines_single_file( location: Path | str, encoding: str, path_in_archive: Optional[str] = None, on_enter_location: Optional[Callable[[str], None]] = None, ) -> Generator[Generator[str, None, None], None, None]: """ A smart function to efficiently stream lines from a local or remote file. The location can either be a URL or a local file system path. It handles gzip, zst, and plain text files. Args: location - URL or file path path_in_archive - The path to a file in a zip archive on_enter_location - A lambda for when a new location is entered """ location = str(location) if on_enter_location: on_enter_location(location) if location.startswith("http://") or location.startswith("https://"): # If this is mocked for a test, use the locally mocked path. mocked_location = get_mocked_downloads_file_path(location) if mocked_location: location = mocked_location stack = ExitStack() try: if location.startswith("http://") or location.startswith("https://"): # This is a remote file. response = requests.head(location, allow_redirects=True) content_type = response.headers.get("Content-Type") if content_type == "application/gzip": yield stack.enter_context(RemoteGzipLineStreamer(location)) # type: ignore[reportReturnType] elif content_type == "application/zstd": yield stack.enter_context(RemoteZstdLineStreamer(location)) # type: ignore[reportReturnType] elif content_type == "application/zip": raise DownloadException("Streaming a zip from a remote location is supported.") elif content_type == "text/plain": yield stack.enter_context(RemoteDecodingLineStreamer(location)) # type: ignore[reportReturnType] elif location.endswith(".gz") or location.endswith(".gzip"): yield stack.enter_context(RemoteGzipLineStreamer(location)) # type: ignore[reportReturnType] elif location.endswith(".zst"): yield stack.enter_context(RemoteZstdLineStreamer(location)) # type: ignore[reportReturnType] else: # Treat as plain text. yield stack.enter_context(RemoteDecodingLineStreamer(location)) # type: ignore[reportReturnType] else: # noqa: PLR5501 # This is a local file. if location.endswith(".gz") or location.endswith(".gzip"): yield stack.enter_context(gzip.open(location, "rt", encoding=encoding)) # type: ignore[reportReturnType] elif location.endswith(".zst"): input_file = stack.enter_context(open(location, "rb")) zst_reader = stack.enter_context(ZstdDecompressor().stream_reader(input_file)) yield stack.enter_context(io.TextIOWrapper(zst_reader, encoding=encoding)) # type: ignore[reportReturnType] elif location.endswith(".zip"): if not path_in_archive: raise DownloadException("Expected a path into the zip file.") zip = stack.enter_context(ZipFile(location, "r")) if path_in_archive not in zip.namelist(): raise DownloadException( f"Path did not exist in the zip file: {path_in_archive}" ) file = stack.enter_context(zip.open(path_in_archive, "r")) yield stack.enter_context(io.TextIOWrapper(file, encoding=encoding)) # type: ignore[reportReturnType] else: # Treat as plain text. yield stack.enter_context(open(location, "rt", encoding=encoding)) # type: ignore[reportReturnType] finally: stack.close() def read_lines( location_or_locations: Union[Path, str, list[Union[str, Path]]], path_in_archive: Optional[str] = None, on_enter_location: Optional[Callable[[str], None]] = None, encoding="utf-8", ): """ A smart function to efficiently stream lines from a local or remote file. The location can either be a URL or a local file system path. It handles gzip, zst, and plain text files. It can also handle a list of files. Args: location_or_locations - A single URL or file path, or a list path_in_archive - The path to a file in a zip archive Usage: with read_lines("output.txt.gz") as lines: for line in lines: print(line) paths = [ "http://example.com/file.txt.gz", "path/to/file.txt.zst", ] with read_lines(paths) as lines: for line in lines: print(line) """ if isinstance(location_or_locations, list): return _read_lines_multiple_files( location_or_locations, encoding, path_in_archive, on_enter_location ) return _read_lines_single_file( location_or_locations, encoding, path_in_archive, on_enter_location ) @contextmanager def write_lines(path: Path | str, encoding="utf-8"): """ A smart function to create a context to write lines to a file. It works on .zst, .gz, and raw text files. It reads the extension to determine the file type. If writing out a raw text file, for instance a sample of a dataset that is just used for viewing, include a "byte order mark" so that the browser can properly detect the encoding. with write_lines("output.txt.gz") as output: output.write("writing a line\n") output.write("writing a second lines\n") """ stack = None try: path = str(path) stack = ExitStack() if path.endswith(".zst"): file = stack.enter_context(open(path, "wb")) compressor = stack.enter_context(ZstdCompressor().stream_writer(file)) yield stack.enter_context(io.TextIOWrapper(compressor, encoding=encoding)) elif path.endswith(".gz"): yield stack.enter_context(gzip.open(path, "wt", encoding=encoding)) else: yield stack.enter_context(open(path, "wt", encoding=encoding)) finally: if stack: stack.close() def count_lines(path: Path | str) -> int: """ Similar to wc -l, this counts the lines in a file. However, this command does so regardless of the compression strategy used on the file. """ with read_lines(path) as lines: return sum(1 for _ in lines) def is_file_empty(path: Path | str) -> bool: """ Attempts to read a line to determine if a file is empty or not. Works on local or remote files as well as compressed or uncompressed files. """ with read_lines(path) as lines: try: next(lines) return False except StopIteration: return True def get_file_size(location: Path | str) -> int: """Get the size of a file, whether it is remote or local.""" if isinstance(location, str) and ( location.startswith("http://") or location.startswith("https://") ): return get_download_size(location) return os.path.getsize(location) def get_human_readable_file_size(location: Path | str) -> tuple[str, int]: """Get the size of a file in a human-readable string, and the numeric bytes.""" bytes = get_file_size(location) return format_bytes(bytes), bytes def compress_file( path: Path | str, keep_original: bool = True, compressed_path: Optional[Path | str] = None, compression: Literal["zst", "gz"] = "zst", ) -> Path: """ Compresses a file to .zst or .gz format. It returns the path of the compressed file. "zst" is the preferred compression scheme. """ path = Path(path) if compression == "zst": if not compressed_path: compressed_path = Path(str(path) + ".zst") cctx = ZstdCompressor() with open(path, "rb") as infile: with open(compressed_path, "wb") as outfile: outfile.write(cctx.compress(infile.read())) elif compression == "gz": if not compressed_path: compressed_path = Path(str(path) + ".gz") with open(path, "rb") as infile: with gzip.open(compressed_path, "wb") as outfile: outfile.write(infile.read()) else: raise ValueError(f"Unsupported compression format: {compression}") if not keep_original: # Delete the original file path.unlink() return Path(compressed_path) def decompress_file( path: Union[str, Path], keep_original: bool = True, decompressed_path: Optional[Union[str, Path]] = None, ) -> Path: """ Decompresses a .gz or .zst file. It returns the path of the decompressed file. """ path = Path(path) if decompressed_path: decompressed_path = Path(decompressed_path) else: # Remove the original suffix decompressed_path = path.with_suffix("") with ExitStack() as stack: decompressed_file = stack.enter_context(decompressed_path.open("wb")) if path.suffix == ".gz": compressed_file = stack.enter_context(gzip.open(str(path), "rb")) decompressed_file.write(compressed_file.read()) while True: # Write the data out in chunks so that all of the it doesn't need to be # into memory. chunk = compressed_file.read(10_240) if not chunk: break decompressed_file.write(chunk) elif path.suffix == ".zst": compressed_file = stack.enter_context(open(path, "rb")) for chunk in ZstdDecompressor().read_to_iter(compressed_file): # Write the data out in chunks so that all of the it doesn't need to be # into memory. decompressed_file.write(chunk) else: raise ValueError(f"Unsupported file extension: {path.suffix}") if not keep_original: # Delete the original file path.unlink() return decompressed_path