in src/datasets/utils/tf_utils.py [0:0]
def __iter__(self):
    """Yield batches produced by background worker processes.

    Starts up to ``self.num_workers`` daemon processes (``spawn`` start method)
    running ``self.worker_loop``, capped at the number of batches so that no
    worker is started without work. Workers are then polled in index order,
    over and over, and each batch is copied out of shared memory before being
    yielded.

    Handshake with each worker, as implemented here:
      * the parent creates one shared ``int64`` shape array per column and worker;
      * the parent waits on the worker's ``array_ready_event``, then reads the
        shapes. Any negative entry means the worker has no more data;
      * the parent attaches to the batch arrays named ``{worker_name}_{col}``,
        copies them, yields, and then sets ``array_loaded_event``.

    Yields:
        dict: column name -> NumPy array copied out of shared memory. Columns
        listed in ``self.string_columns`` are viewed as fixed-width unicode
        strings, with the trailing character axis squeezed away.

    Raises:
        TimeoutError: if a worker does not signal a batch within 60 seconds.
    """
    # Make sure we only spawn workers if they have work to do
    num_workers = min(self.num_workers, int(ceil(len(self.dataset) / self.batch_size)))
    # Do the shuffling in iter so that it's done at the start of each epoch
    per_worker_batches, final_batch, final_batch_worker = self.distribute_batches(
        self.dataset, self.batch_size, self.drop_remainder, num_workers, self.shuffle
    )
    ctx = get_context("spawn")
    names = []
    shape_arrays = []
    workers = []
    array_ready_events = [ctx.Event() for _ in range(num_workers)]
    array_loaded_events = [ctx.Event() for _ in range(num_workers)]
    base_args = {
        "dataset": self.dataset,
        "cols_to_retain": self.cols_to_retain,
        "collate_fn": self.collate_fn,
        "collate_fn_args": self.collate_fn_args,
        "columns_to_np_types": self.columns_to_np_types,
        "columns_to_ranks": self.columns_to_ranks,
        "string_columns": self.string_columns,
    }
    with SharedMemoryContext() as shm_ctx:
        # Set to True only when the read loop ends because a worker sent the
        # end signal. Any other exit (early close of the generator, timeout,
        # failure while starting workers) leaves it False.
        finished_cleanly = False
        try:
            for i in range(num_workers):
                worker_random_id = str(uuid4())
                worker_name = f"dw_{i}_{worker_random_id}"[:10]
                names.append(worker_name)
                # One small shared array per column, used to communicate the
                # shape of the batch array for that column.
                worker_shape_arrays = {
                    col: shm_ctx.get_array(f"{worker_name}_{col}_shape", shape=(rank,), dtype=np.int64, create=True)
                    for col, rank in self.columns_to_ranks.items()
                }
                shape_arrays.append(worker_shape_arrays)
                # Only the designated worker receives the trailing extra batch
                is_final_batch_worker = i == final_batch_worker and final_batch is not None
                worker_kwargs = {
                    "worker_name": worker_name,
                    "indices": per_worker_batches[i],
                    "extra_batch": final_batch if is_final_batch_worker else None,
                    "array_ready_event": array_ready_events[i],
                    "array_loaded_event": array_loaded_events[i],
                    **base_args,
                }
                worker = ctx.Process(target=self.worker_loop, kwargs=worker_kwargs, daemon=True)
                worker.start()
                workers.append(worker)

            # With no workers there is nothing to wait for. Starting the flag
            # at False in that case would make the loop below spin forever,
            # because the inner `for` would never run and never set it.
            end_signal_received = num_workers < 1
            while not end_signal_received:
                for i in range(num_workers):
                    if not array_ready_events[i].wait(timeout=60):
                        raise TimeoutError("Data loading worker timed out!")
                    array_ready_events[i].clear()
                    array_shapes = shape_arrays[i]
                    if any(np.any(shape < 0) for shape in array_shapes.values()):
                        # Child processes send negative array shapes to indicate
                        # that no more data is going to be sent
                        end_signal_received = True
                        break
                    # Matt: Because array shapes are variable we recreate the shared memory each iteration.
                    # I suspect repeatedly opening lots of shared memory is the bottleneck for the parent process.
                    # A future optimization, at the cost of some code complexity, could be to reuse shared memory
                    # between iterations, but this would require knowing in advance the maximum size, or having
                    # a system to only create a new memory block when a new maximum size is seen.
                    # Another potential optimization would be to figure out which memory copies are necessary,
                    # or whether we can yield objects straight out of shared memory.
                    with SharedMemoryContext() as batch_shm_ctx:
                        # This memory context only lasts long enough to copy everything out of the batch
                        arrays = {
                            col: batch_shm_ctx.get_array(
                                f"{names[i]}_{col}",
                                shape=shape,
                                dtype=self.columns_to_np_types[col],
                                create=False,
                            )
                            for col, shape in array_shapes.items()
                        }
                        # Copy everything out of shm because the memory
                        # will be unlinked by the child process at some point
                        arrays = {col: np.copy(arr) for col, arr in arrays.items()}
                        # Now we convert any unicode char arrays to strings
                        for string_col in self.string_columns:
                            arrays[string_col] = (
                                arrays[string_col].view(f"U{arrays[string_col].shape[-1]}").squeeze(-1)
                            )
                    yield arrays
                    # Tell the worker its batch has been consumed
                    array_loaded_events[i].set()
            finished_cleanly = True
        finally:
            # Now we just do some cleanup
            # Shared memory is cleaned up by the context manager, so we just make sure workers finish.
            if not finished_cleanly:
                # The consumer stopped early or an error occurred. Workers may be
                # blocked waiting for a signal that will never come, so stop them
                # explicitly instead of leaving them around until interpreter exit.
                for worker in workers:
                    if worker.is_alive():
                        worker.terminate()
            for worker in workers:
                worker.join()