in src/datasets/utils/py_utils.py [0:0]
def _single_map_nested(args):
    """Apply a function recursively to each element of a nested data struct.

    All inputs arrive packed in the single tuple ``args`` (presumably so the
    function can be handed directly to a process pool's map -- confirm
    against callers). Its fields, in order:

    - ``function``: callable applied to leaves. When ``batched`` is true it is
      called with a list of leaves and its returned iterable is concatenated
      in order, so it should yield one result per input item.
    - ``data_struct``: a leaf, a ``dict``, or an instance of ``types``.
    - ``batched``: whether ``function`` receives batches (lists) of leaves.
    - ``batch_size``: batch size forwarded to ``iter_batched``.
    - ``types``: type or tuple of types treated as sequences to recurse into;
      anything that is neither a ``dict`` nor one of these is a leaf.
    - ``rank``: worker index or ``None``. When set, logging is reduced to
      warning level if it was more verbose, the progress bar is placed at
      ``position=rank`` and ``"#<rank>"`` is appended to its description.
    - ``disable_tqdm``: disable the progress bar.
    - ``desc``: progress bar description (may be ``None``).

    Returns:
        ``function`` applied to a leaf, or a structure mirroring
        ``data_struct``: a ``dict`` for dicts, a ``list`` for lists, a
        ``tuple`` for tuples and a ``np.ndarray`` for any other ``types``
        instance. The container type is preserved in both batched and
        unbatched mode. Nested levels are processed without their own
        progress bar.
    """
    function, data_struct, batched, batch_size, types, rank, disable_tqdm, desc = args

    # Singleton first to spare some computation
    if not isinstance(data_struct, (dict, types)):
        if batched:
            # Wrap the leaf in a one-item batch and unwrap the single result
            return function([data_struct])[0]
        return function(data_struct)

    if (
        batched
        and not isinstance(data_struct, dict)
        and all(not isinstance(v, (dict, types)) for v in data_struct)
    ):
        # Flat sequence of leaves (necessarily a ``types`` instance, since the
        # singleton check above returned for anything else): process it in
        # batches and concatenate the per-batch results in order.
        mapped = [mapped_item for batch in iter_batched(data_struct, batch_size) for mapped_item in function(batch)]
    else:
        # Reduce logging to keep things readable in multiprocessing with tqdm
        if rank is not None and logging.get_verbosity() < logging.WARNING:
            logging.set_verbosity_warning()
        # Print at least one thing to fix tqdm in notebooks in multiprocessing
        # see https://github.com/tqdm/tqdm/issues/485#issuecomment-473338308
        if rank is not None and not disable_tqdm and any("notebook" in tqdm_cls.__name__ for tqdm_cls in tqdm.__mro__):
            print(" ", end="", flush=True)
        # Iterate over (key, value) pairs for dicts, over the items otherwise
        pbar_iterable = data_struct.items() if isinstance(data_struct, dict) else data_struct
        if rank is not None:
            pbar_desc = (desc + " " if desc is not None else "") + "#" + str(rank)
        else:
            pbar_desc = desc
        with hf_tqdm(pbar_iterable, disable=disable_tqdm, position=rank, unit="obj", desc=pbar_desc) as pbar:
            if isinstance(data_struct, dict):
                # Recursive calls get no rank, no progress bar and no description
                return {
                    k: _single_map_nested((function, v, batched, batch_size, types, None, True, None)) for k, v in pbar
                }
            mapped = [_single_map_nested((function, v, batched, batch_size, types, None, True, None)) for v in pbar]

    # Give the mapped items the container type of ``data_struct``. The batched
    # fast path goes through here too, so nested tuples/arrays keep their type
    # instead of silently becoming lists in batched mode.
    if isinstance(data_struct, list):
        return mapped
    if isinstance(data_struct, tuple):
        return tuple(mapped)
    return np.array(mapped)