dataflux_pytorch/dataflux_mapstyle_dataset.py

""" Copyright 2023 Google LLC Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at https://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ import logging import multiprocessing import os import warnings import dataflux_core from google.cloud import storage from google.cloud.storage.retry import DEFAULT_RETRY from torch.utils import data from dataflux_pytorch._helper import _get_missing_permissions MODIFIED_RETRY = DEFAULT_RETRY.with_deadline(100000.0).with_delay( initial=1.0, multiplier=1.5, maximum=30.0) FORK = "fork" CREATE = "storage.objects.create" DELETE = "storage.objects.delete" class Config: """Customizable configuration to the DataFluxMapStyleDataset. Attributes: sort_listing_results: A boolean flag indicating if data listing results will be alphabetically sorted. Default to False. max_composite_object_size: An integer indicating a cap for the maximum size of the composite object in bytes. Default to 100000000 = 100 MiB. num_processes: The number of processes to be used in the Dataflux algorithms. Default to the number of CPUs from the running environment. prefix: The prefix that is used to list the objects in the bucket with. The default is None which means it will list all the objects in the bucket. max_listing_retries: An integer indicating the maximum number of retries to attempt in case of any Python multiprocessing errors during GCS objects listing. Default to 3. disable_compose: A boolean flag indicating if compose download should be active. Compose should be disabled for highly scaled implementations. list_retry_config: A google API retry for Dataflux fast list operations. This allows for retry backoff configuration. download_retry_config: A google API retry for Dataflux download operations. This allows for retry backoff configuration. """ def __init__( self, sort_listing_results: bool = False, max_composite_object_size: int = 100000000, num_processes: int = os.cpu_count(), prefix: str = None, max_listing_retries: int = 3, threads_per_process: int = 1, disable_compose: bool = False, list_retry_config: "google.api_core.retry.retry_unary.Retry" = MODIFIED_RETRY, download_retry_config: "google.api_core.retry.retry_unary.Retry" = MODIFIED_RETRY, ): self.sort_listing_results = sort_listing_results self.max_composite_object_size = max_composite_object_size self.num_processes = num_processes self.prefix = prefix self.max_listing_retries = max_listing_retries self.threads_per_process = threads_per_process if disable_compose: self.max_composite_object_size = 0 self.list_retry_config = list_retry_config self.download_retry_config = download_retry_config def data_format_default(data): return data class DataFluxMapStyleDataset(data.Dataset): def __init__( self, project_name, bucket_name, config=Config(), data_format_fn=data_format_default, storage_client=None, ): """Initializes the DataFluxMapStyleDataset. The initialization sets up the needed configuration and runs data listing using the Dataflux algorithm. Args: project_name: The name of the GCP project. bucket_name: The name of the GCS bucket that holds the objects to compose. 


class DataFluxMapStyleDataset(data.Dataset):

    def __init__(
        self,
        project_name,
        bucket_name,
        config=Config(),
        data_format_fn=data_format_default,
        storage_client=None,
    ):
        """Initializes the DataFluxMapStyleDataset.

        The initialization sets up the needed configuration and runs data
        listing using the Dataflux algorithm.

        Args:
            project_name: The name of the GCP project.
            bucket_name: The name of the GCS bucket that holds the objects
                to compose. The Dataflux download algorithm uploads the
                composed object to this bucket too.
            config: A dataflux_mapstyle_dataset.Config object that includes
                configuration customizations. If not specified, a default
                config with default parameters is created.
            data_format_fn: A function that formats the downloaded bytes to
                the desired format. If not specified, the default formatting
                function leaves the data as-is.
            storage_client: The google.cloud.storage.Client object initiated
                with sufficient permission to access the project and the
                bucket. If not specified, one will be created when needed.

        Returns:
            None.
        """
        super().__init__()
        multiprocessing_start = multiprocessing.get_start_method(
            allow_none=False)
        if storage_client is not None and multiprocessing_start != FORK:
            warnings.warn(
                "Setting the storage client is not fully supported when multiprocessing starts with spawn or forkserver.",
                UserWarning,
            )
        self.storage_client = storage_client
        self.project_name = project_name
        self.bucket_name = bucket_name
        self.data_format_fn = data_format_fn
        self.config = config
        # If composed download is enabled and a storage_client was provided,
        # check if the client has permissions to create and delete the
        # composed object.
        if storage_client is not None and self.config.max_composite_object_size:
            missing_perm = _get_missing_permissions(
                storage_client=self.storage_client,
                bucket_name=self.bucket_name,
                project_name=self.project_name,
                required_perm=[CREATE, DELETE])
            if missing_perm:
                raise PermissionError(
                    f"Missing permissions {', '.join(missing_perm)} for composed download. To disable composed download, set config.disable_compose=True; to enable it, grant the missing permissions."
                )
        self.dataflux_download_optimization_params = (
            dataflux_core.download.DataFluxDownloadOptimizationParams(
                max_composite_object_size=self.config.max_composite_object_size
            ))

        self.objects = self._list_GCS_blobs_with_retry()

    def __len__(self):
        return len(self.objects)

    def __getitem__(self, idx):
        # Lazily create the storage client so the dataset can be pickled
        # into worker processes before any client exists.
        if self.storage_client is None:
            self.storage_client = storage.Client(project=self.project_name)
        return self.data_format_fn(
            dataflux_core.download.download_single(
                storage_client=self.storage_client,
                bucket_name=self.bucket_name,
                object_name=self.objects[idx][0],
                retry_config=self.config.download_retry_config,
            ))
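    # __getitems__ (note the plural) is PyTorch's optional batched fetch
    # hook: on recent PyTorch versions the DataLoader calls it once per
    # batch with a list of indices, which lets Dataflux download all of the
    # batch's objects concurrently instead of one at a time.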
    def __getitems__(self, indices):
        if self.storage_client is None:
            self.storage_client = storage.Client(project=self.project_name)
        return [
            self.data_format_fn(bytes_content)
            for bytes_content in dataflux_core.download.dataflux_download_threaded(
                project_name=self.project_name,
                bucket_name=self.bucket_name,
                objects=[self.objects[idx] for idx in indices],
                storage_client=self.storage_client,
                dataflux_download_optimization_params=(
                    self.dataflux_download_optimization_params),
                threads=self.config.threads_per_process,
                retry_config=self.config.download_retry_config,
            )
        ]

    def _list_GCS_blobs_with_retry(self):
        """Runs Dataflux listing, retrying upon exceptions up to the retry
        count defined in self.config."""
        error = None
        listed_objects = []
        for _ in range(self.config.max_listing_retries):
            try:
                lister = dataflux_core.fast_list.ListingController(
                    max_parallelism=self.config.num_processes,
                    project=self.project_name,
                    bucket=self.bucket_name,
                    sort_results=self.config.sort_listing_results,
                    prefix=self.config.prefix,
                    retry_config=self.config.list_retry_config,
                )
                # If the dataset was not initialized with a storage_client,
                # ensure that we do not attach a client to the lister to
                # avoid pickling errors (#58).
                lister.client = self.storage_client
                listed_objects = lister.run()
            except Exception as e:
                logging.error(
                    f"exception {str(e)} caught running Dataflux fast listing."
                )
                error = e
                continue
            # No exception -- we can immediately return the listed objects.
            else:
                return listed_objects
        # The loop never returned, therefore every attempt raised an
        # exception; re-raise the last one.
        else:
            raise error

    def __getstate__(self):
        # Copy the object's state from self.__dict__, which contains all our
        # instance attributes. Use dict.copy() to avoid modifying the
        # original state.
        state = self.__dict__.copy()
        # Remove the unpicklable entries.
        del state['storage_client']
        return state

    def __setstate__(self, state):
        # Restore instance attributes.
        self.__dict__.update(state)
        # Recreate the storage client, which was dropped during pickling.
        self.storage_client = storage.Client(project=self.project_name)
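
# A minimal usage sketch, assuming a hypothetical GCP project and bucket
# that the caller can access; the names and prefix below are placeholders,
# not part of the library.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = DataFluxMapStyleDataset(
        project_name="my-gcp-project",  # hypothetical project
        bucket_name="my-training-bucket",  # hypothetical bucket
        config=Config(prefix="images/", threads_per_process=4),
    )

    # With a batch_size set, recent PyTorch versions route each batch
    # through the threaded __getitems__ path defined above.
    loader = DataLoader(dataset, batch_size=32, num_workers=2)
    for batch in loader:
        # Each element is raw bytes, since data_format_default is a no-op.
        print(len(batch))
        break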