dataflux_pytorch/dataflux_mapstyle_dataset.py
"""
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
import multiprocessing
import os
import warnings

import dataflux_core
from google.cloud import storage
from google.cloud.storage.retry import DEFAULT_RETRY
from torch.utils import data

from dataflux_pytorch._helper import _get_missing_permissions

MODIFIED_RETRY = DEFAULT_RETRY.with_deadline(100000.0).with_delay(
    initial=1.0, multiplier=1.5, maximum=30.0)

FORK = "fork"
CREATE = "storage.objects.create"
DELETE = "storage.objects.delete"


class Config:
"""Customizable configuration to the DataFluxMapStyleDataset.
Attributes:
sort_listing_results: A boolean flag indicating if data listing results
will be alphabetically sorted. Default to False.
max_composite_object_size: An integer indicating a cap for the maximum
size of the composite object in bytes. Default to 100000000 = 100 MiB.
num_processes: The number of processes to be used in the Dataflux algorithms.
Default to the number of CPUs from the running environment.
prefix: The prefix that is used to list the objects in the bucket with.
The default is None which means it will list all the objects in the bucket.
max_listing_retries: An integer indicating the maximum number of retries
to attempt in case of any Python multiprocessing errors during
GCS objects listing. Default to 3.
disable_compose: A boolean flag indicating if compose download should be active.
Compose should be disabled for highly scaled implementations.
list_retry_config: A google API retry for Dataflux fast list operations. This allows
for retry backoff configuration.
download_retry_config: A google API retry for Dataflux download operations. This allows
for retry backoff configuration.
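
    Example (an illustrative sketch; the parameter values below are placeholders,
    not recommendations from this module):

        custom_retry = DEFAULT_RETRY.with_deadline(600.0).with_delay(
            initial=1.0, multiplier=2.0, maximum=60.0)
        config = Config(
            num_processes=8,
            prefix="training-data/",
            disable_compose=True,
            download_retry_config=custom_retry,
        )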
"""

    def __init__(
self,
sort_listing_results: bool = False,
max_composite_object_size: int = 100000000,
num_processes: int = os.cpu_count(),
prefix: str = None,
max_listing_retries: int = 3,
threads_per_process: int = 1,
disable_compose: bool = False,
list_retry_config:
"google.api_core.retry.retry_unary.Retry" = MODIFIED_RETRY,
download_retry_config:
"google.api_core.retry.retry_unary.Retry" = MODIFIED_RETRY,
):
self.sort_listing_results = sort_listing_results
self.max_composite_object_size = max_composite_object_size
self.num_processes = num_processes
self.prefix = prefix
self.max_listing_retries = max_listing_retries
self.threads_per_process = threads_per_process
if disable_compose:
self.max_composite_object_size = 0
self.list_retry_config = list_retry_config
self.download_retry_config = download_retry_config


def data_format_default(data):
    return data


class DataFluxMapStyleDataset(data.Dataset):
def __init__(
self,
project_name,
bucket_name,
config=Config(),
data_format_fn=data_format_default,
storage_client=None,
):
"""Initializes the DataFluxMapStyleDataset.
The initialization sets up the needed configuration and runs data
listing using the Dataflux algorithm.
Args:
project_name: The name of the GCP project.
            bucket_name: The name of the GCS bucket that holds the objects to compose.
                The Dataflux download algorithm uploads the composed object to this bucket too.
            config: A dataflux_mapstyle_dataset.Config object that includes configuration
                customizations. If not specified, a default config with default parameters is created.
            data_format_fn: A function that formats the downloaded bytes to the desired format.
                If not specified, the default formatting function leaves the data as-is.
            storage_client: The google.cloud.storage.Client object initialized with sufficient
                permissions to access the project and the bucket. If not specified, one will be
                created when needed.
Returns:
None.
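
        Example (an illustrative sketch; "my-project", "my-bucket", and the
        DataLoader settings are placeholders, not values from this module):

            dataset = DataFluxMapStyleDataset(
                project_name="my-project",
                bucket_name="my-bucket",
                config=Config(prefix="training-data/"),
            )
            loader = torch.utils.data.DataLoader(dataset, batch_size=32, num_workers=2)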
"""
super().__init__()
multiprocessing_start = multiprocessing.get_start_method(
allow_none=False)
if storage_client is not None and multiprocessing_start != FORK:
warnings.warn(
"Setting the storage client is not fully supported when multiprocessing starts with spawn or forkserver.",
UserWarning,
)
self.storage_client = storage_client
self.project_name = project_name
self.bucket_name = bucket_name
self.data_format_fn = data_format_fn
self.config = config
# If composed download is enabled and a storage_client was provided,
# check if the client has permissions to create and delete the
# composed object.
if storage_client is not None and self.config.max_composite_object_size:
missing_perm = _get_missing_permissions(
storage_client=self.storage_client,
bucket_name=self.bucket_name,
project_name=self.project_name,
required_perm=[CREATE, DELETE])
            if missing_perm:
                raise PermissionError(
                    f"Missing permissions {', '.join(missing_perm)} for composed download. "
                    "Either grant the missing permissions or disable composed download by "
                    "setting disable_compose=True in the Config.")
self.dataflux_download_optimization_params = (
dataflux_core.download.DataFluxDownloadOptimizationParams(
max_composite_object_size=self.config.max_composite_object_size
))
self.objects = self._list_GCS_blobs_with_retry()

    def __len__(self):
return len(self.objects)

    def __getitem__(self, idx):
if self.storage_client is None:
self.storage_client = storage.Client(project=self.project_name)
return self.data_format_fn(
dataflux_core.download.download_single(
storage_client=self.storage_client,
bucket_name=self.bucket_name,
object_name=self.objects[idx][0],
retry_config=self.config.download_retry_config,
))

    def __getitems__(self, indices):
if self.storage_client is None:
self.storage_client = storage.Client(project=self.project_name)
return [
self.data_format_fn(bytes_content) for bytes_content in
dataflux_core.download.dataflux_download_threaded(
project_name=self.project_name,
bucket_name=self.bucket_name,
objects=[self.objects[idx] for idx in indices],
storage_client=self.storage_client,
dataflux_download_optimization_params=self.
dataflux_download_optimization_params,
threads=self.config.threads_per_process,
retry_config=self.config.download_retry_config,
)
]

    def _list_GCS_blobs_with_retry(self):
"""Retries Dataflux Listing upon exceptions, up to the retries defined in self.config."""
error = None
listed_objects = []
for _ in range(self.config.max_listing_retries):
try:
lister = dataflux_core.fast_list.ListingController(
max_parallelism=self.config.num_processes,
project=self.project_name,
bucket=self.bucket_name,
sort_results=self.config.sort_listing_results,
prefix=self.config.prefix,
retry_config=self.config.list_retry_config,
)
                # If the dataset was not initialized with a storage_client, this
                # assigns None, leaving the lister without a client and avoiding
                # pickling errors (#58).
                lister.client = self.storage_client
listed_objects = lister.run()
except Exception as e:
                logging.error(
                    f"Exception {e} caught while running Dataflux fast listing.")
error = e
continue
# No exception -- we can immediately return the listed objects.
else:
return listed_objects
        # The for loop completed without returning, so every listing
        # attempt raised an exception; re-raise the last one.
        else:
            raise error

    def __getstate__(self):
# Copy the object's state from self.__dict__ which contains
# all our instance attributes. Use the dict.copy()
# method to avoid modifying the original state.
state = self.__dict__.copy()
# Remove the unpicklable entries.
del state['storage_client']
return state

    def __setstate__(self, state):
# Restore instance attributes.
self.__dict__.update(state)
# Create the storage client.
self.storage_client = storage.Client(project=self.project_name)
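

# Illustrative usage sketch (not part of the original module): wiring the dataset
# into a PyTorch DataLoader. "my-project" and "my-bucket" are placeholder names.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    example_dataset = DataFluxMapStyleDataset(
        project_name="my-project",
        bucket_name="my-bucket",
        config=Config(prefix="training-data/"),
    )
    # With the default data_format_fn and default collate, each batch is a list
    # of raw bytes objects downloaded from GCS.
    loader = DataLoader(example_dataset, batch_size=32, num_workers=2)
    for batch in loader:
        print(f"Fetched a batch of {len(batch)} objects.")
        break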