def __iter__()

in dataflux_pytorch/benchmark/standalone_dataloader/standalone_dataloader.py


    def __iter__(self):
        """
        Overriding the __iter__ function allows batch size to be used to sub-divide the contents of
        individual parquet files.
        """
        worker_info = data.get_worker_info()
        files_per_node = math.ceil(len(self.objects) / self.world_size)
        start_point = self.rank * files_per_node
        end_point = start_point + files_per_node
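        # Example (illustrative numbers): with 12 files and world_size=2,
        # files_per_node=6, so rank 0 covers objects[0:6] and rank 1 objects[6:12].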
        if worker_info is None:
            print("Single-process data loading detected", flush=True)
            for bytes_content in dataflux_iterable_dataset.dataflux_core.download.dataflux_download_lazy(
                    project_name=self.project_name,
                    bucket_name=self.bucket_name,
                    objects=self.objects[start_point:end_point],
                    storage_client=self.storage_client,
                    dataflux_download_optimization_params=self.
                    dataflux_download_optimization_params,
                    retry_config=self.config.download_retry_config,
            ):
                table = self.data_format_fn(bytes_content)
                for batch in table.iter_batches(batch_size=self.batch_size,
                                                columns=self.columns):
                    yield from batch.to_pylist()
        else:
            # Multi-process data loading. Split the workload among workers.
            # Ref: https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset.
            # For the purpose of this example, this split is only performed at the file level.
            # This means that (num_nodes * num_workers) should be less than or equal to filecount.
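            # Continuing the example above (illustrative numbers): with 12 files,
            # world_size=2 and num_workers=3, files_per_node=6 and per_worker=2,
            # so worker 1 on rank 0 downloads objects[2:4].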
            per_worker = int(
                math.ceil(files_per_node / float(worker_info.num_workers)))
            worker_id = worker_info.id
            start = worker_id * per_worker + start_point
            end = min(start + per_worker, end_point)
            max_logging.log(
                f"-----> Worker {self.rank}.{worker_id} downloading {self.objects[start:end]}\n"
            )
            for bytes_content in dataflux_iterable_dataset.dataflux_core.download.dataflux_download_lazy(
                    project_name=self.project_name,
                    bucket_name=self.bucket_name,
                    objects=self.objects[start:end],
                    storage_client=self.storage_client,
                    dataflux_download_optimization_params=self.
                    dataflux_download_optimization_params,
                    retry_config=self.config.download_retry_config,
            ):
                table = self.data_format_fn(bytes_content)
                for batch in table.iter_batches(batch_size=self.batch_size,
                                                columns=self.columns):
                    yield from batch.to_pylist()
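
Below is a minimal, self-contained sketch of the same pattern: a file-level split across DataLoader workers via get_worker_info, followed by pyarrow batch iteration that sub-divides each file. It reads local Parquet files rather than downloading them with Dataflux, and the class name, constructor arguments, and paths are illustrative, not part of the benchmark code.

import math

import pyarrow.parquet as pq
import torch.utils.data as data


class LocalParquetIterableDataset(data.IterableDataset):
    """Illustrative only: splits a list of local Parquet files across workers."""

    def __init__(self, paths, batch_size, columns=None):
        self.paths = paths
        self.batch_size = batch_size
        self.columns = columns

    def __iter__(self):
        worker_info = data.get_worker_info()
        if worker_info is None:
            # Single-process loading: this process reads every file.
            start, end = 0, len(self.paths)
        else:
            # File-level split, as above: each worker gets a contiguous slice.
            per_worker = int(
                math.ceil(len(self.paths) / float(worker_info.num_workers)))
            start = worker_info.id * per_worker
            end = min(start + per_worker, len(self.paths))
        for path in self.paths[start:end]:
            parquet_file = pq.ParquetFile(path)
            # batch_size sub-divides each file; rows are yielded one at a time.
            for batch in parquet_file.iter_batches(batch_size=self.batch_size,
                                                   columns=self.columns):
                yield from batch.to_pylist()

A DataLoader then re-batches the yielded rows with its own collation, e.g. data.DataLoader(LocalParquetIterableDataset(paths, batch_size=1024), batch_size=32, num_workers=2).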