def __iter__()

in src/datasets/utils/tf_utils.py


    def __iter__(self):
        # Make sure we only spawn workers if they have work to do
        num_workers = min(self.num_workers, int(ceil(len(self.dataset) / self.batch_size)))
        # Do the shuffling in iter so that it's done at the start of each epoch
        per_worker_batches, final_batch, final_batch_worker = self.distribute_batches(
            self.dataset, self.batch_size, self.drop_remainder, num_workers, self.shuffle
        )
        ctx = get_context("spawn")
        names = []
        shape_arrays = []
        workers = []
        array_ready_events = [ctx.Event() for _ in range(num_workers)]
        array_loaded_events = [ctx.Event() for _ in range(num_workers)]

        base_args = {
            "dataset": self.dataset,
            "cols_to_retain": self.cols_to_retain,
            "collate_fn": self.collate_fn,
            "collate_fn_args": self.collate_fn_args,
            "columns_to_np_types": self.columns_to_np_types,
            "columns_to_ranks": self.columns_to_ranks,
            "string_columns": self.string_columns,
        }
        with SharedMemoryContext() as shm_ctx:
            for i in range(num_workers):
                worker_random_id = str(uuid4())
                worker_name = f"dw_{i}_{worker_random_id}"[:10]
                names.append(worker_name)

                worker_shape_arrays = {
                    col: shm_ctx.get_array(f"{worker_name}_{col}_shape", shape=(rank,), dtype=np.int64, create=True)
                    for col, rank in self.columns_to_ranks.items()
                }
                shape_arrays.append(worker_shape_arrays)

                worker_indices = per_worker_batches[i]
                if i == final_batch_worker and final_batch is not None:
                    final_batch_arg = final_batch
                else:
                    final_batch_arg = None
                worker_kwargs = {
                    "worker_name": worker_name,
                    "indices": worker_indices,
                    "extra_batch": final_batch_arg,
                    "array_ready_event": array_ready_events[i],
                    "array_loaded_event": array_loaded_events[i],
                    **base_args,
                }
                worker = ctx.Process(target=self.worker_loop, kwargs=worker_kwargs, daemon=True)
                worker.start()
                workers.append(worker)

            end_signal_received = False
            while not end_signal_received:
                for i in range(num_workers):
                    if not array_ready_events[i].wait(timeout=60):
                        raise TimeoutError("Data loading worker timed out!")
                    array_ready_events[i].clear()
                    array_shapes = shape_arrays[i]
                    if any(np.any(shape < 0) for shape in array_shapes.values()):
                        # Child processes send negative array shapes to indicate
                        # that no more data is going to be sent
                        end_signal_received = True
                        break
                    # Matt: Because array shapes are variable we recreate the shared memory each iteration.
                    #       I suspect repeatedly opening lots of shared memory is the bottleneck for the parent process.
                    #       A future optimization, at the cost of some code complexity, could be to reuse shared memory
                    #       between iterations, but this would require knowing in advance the maximum size, or having
                    #       a system to only create a new memory block when a new maximum size is seen.
                    #       Another potential optimization would be to figure out which memory copies are necessary,
                    #       or whether we can yield objects straight out of shared memory.
                    with SharedMemoryContext() as batch_shm_ctx:
                        # This memory context only lasts long enough to copy everything out of the batch
                        arrays = {
                            col: batch_shm_ctx.get_array(
                                f"{names[i]}_{col}",
                                shape=shape,
                                dtype=self.columns_to_np_types[col],
                                create=False,
                            )
                            for col, shape in array_shapes.items()
                        }
                        # Copy everything out of shm because the memory
                        # will be unlinked by the child process at some point
                        arrays = {col: np.copy(arr) for col, arr in arrays.items()}
                        # Now we convert any unicode char arrays to strings
                        for string_col in self.string_columns:
                            arrays[string_col] = (
                                arrays[string_col].view(f"U{arrays[string_col].shape[-1]}").squeeze(-1)
                            )
                    yield arrays
                    array_loaded_events[i].set()
            # Now we just do some cleanup
            # Shared memory is cleaned up by the context manager, so we just make sure workers finish
            for worker in workers:
                worker.join()
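
The per-worker indices consumed above come from distribute_batches, which is defined elsewhere in tf_utils.py. As a rough illustration of the three values it returns (and nothing more; the real implementation may differ), a toy stand-in could look like this: a list of full-batch index arrays per worker, plus an optional short final batch that exactly one worker handles unless drop_remainder discards it.

import numpy as np


def toy_distribute_batches(num_samples, batch_size, drop_remainder, num_workers, shuffle):
    # Hypothetical stand-in, NOT the real distribute_batches: shuffle the sample
    # indices, cut them into full batches, deal those batches out to the workers,
    # and keep any short final batch separately unless drop_remainder discards it.
    indices = np.arange(num_samples)
    if shuffle:
        np.random.shuffle(indices)
    num_full_batches = num_samples // batch_size
    full_batches = indices[: num_full_batches * batch_size].reshape(-1, batch_size)
    leftover = indices[num_full_batches * batch_size :]
    per_worker_batches = [full_batches[i::num_workers] for i in range(num_workers)]
    final_batch = leftover if (len(leftover) and not drop_remainder) else None
    final_batch_worker = 0  # the single worker that also processes the short batch
    return per_worker_batches, final_batch, final_batch_worker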
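
The loop above is the consumer half of a shared-memory handshake: each worker publishes a batch's shape through a pre-allocated shape array, sets its "ready" event, and waits for the parent's "loaded" event before tearing the block down; a negative shape is the end-of-data sentinel. Below is a minimal, self-contained sketch of that pattern, assuming nothing about the real worker_loop. All names (toy_worker, "toy_shape", "toy_data") are illustrative only.

import numpy as np
from multiprocessing import get_context
from multiprocessing.shared_memory import SharedMemory


def toy_worker(array_ready_event, array_loaded_event, shape_shm_name, data_shm_name):
    shape_shm = SharedMemory(name=shape_shm_name)
    shape_arr = np.ndarray((1,), dtype=np.int64, buffer=shape_shm.buf)
    for batch in (np.arange(3), np.arange(5)):  # stand-ins for collated numpy batches
        data_shm = SharedMemory(name=data_shm_name, create=True, size=batch.nbytes)
        np.ndarray(batch.shape, dtype=batch.dtype, buffer=data_shm.buf)[:] = batch
        shape_arr[0] = batch.shape[0]  # publish the shape of this batch
        array_ready_event.set()        # tell the parent a batch is ready
        array_loaded_event.wait()      # wait until the parent has copied it out
        array_loaded_event.clear()
        data_shm.close()
        data_shm.unlink()
    shape_arr[0] = -1                  # negative shape == no more data
    array_ready_event.set()
    shape_shm.close()


if __name__ == "__main__":
    ctx = get_context("spawn")
    ready, loaded = ctx.Event(), ctx.Event()
    shape_shm = SharedMemory(name="toy_shape", create=True, size=8)
    shape_arr = np.ndarray((1,), dtype=np.int64, buffer=shape_shm.buf)
    worker = ctx.Process(
        target=toy_worker, args=(ready, loaded, "toy_shape", "toy_data"), daemon=True
    )
    worker.start()
    while True:
        if not ready.wait(timeout=60):
            raise TimeoutError("Data loading worker timed out!")
        ready.clear()
        if shape_arr[0] < 0:  # end-of-data sentinel, mirroring the check in __iter__
            break
        data_shm = SharedMemory(name="toy_data")
        # Copy out of shared memory before the child unlinks the block
        batch = np.copy(np.ndarray((int(shape_arr[0]),), dtype=np.int64, buffer=data_shm.buf))
        data_shm.close()
        print(batch)
        loaded.set()
    worker.join()
    shape_shm.close()
    shape_shm.unlink()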
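
The string handling at the end of the loop relies on a numpy view trick: a string column arrives as a fixed-width character array (one "U1" code point per cell, padded along a trailing character axis), and viewing it as "U{width}" collapses that axis back into one fixed-width unicode string per example. A small standalone demonstration, with illustrative data rather than anything from the library:

import numpy as np

char_array = np.array([list("cat".ljust(5)), list("horse")], dtype="U1")  # shape (2, 5)
strings = char_array.view(f"U{char_array.shape[-1]}").squeeze(-1)         # shape (2,)
print(strings)  # ['cat  ' 'horse']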