def _single_map_nested()

in src/datasets/utils/py_utils.py [0:0]


def _single_map_nested(args):
    """Apply a function recursively to each element of a nested data struct."""
    function, data_struct, batched, batch_size, types, rank, disable_tqdm, desc = args

    # Singleton first to spare some computation
    if not isinstance(data_struct, dict) and not isinstance(data_struct, types):
        if batched:
            return function([data_struct])[0]
        else:
            return function(data_struct)
    if (
        batched
        and not isinstance(data_struct, dict)
        and isinstance(data_struct, types)
        and all(not isinstance(v, (dict, types)) for v in data_struct)
    ):
        return [mapped_item for batch in iter_batched(data_struct, batch_size) for mapped_item in function(batch)]

    # Reduce logging to keep things readable in multiprocessing with tqdm
    if rank is not None and logging.get_verbosity() < logging.WARNING:
        logging.set_verbosity_warning()
    # Print at least one thing to fix tqdm in notebooks in multiprocessing
    # see https://github.com/tqdm/tqdm/issues/485#issuecomment-473338308
    if rank is not None and not disable_tqdm and any("notebook" in tqdm_cls.__name__ for tqdm_cls in tqdm.__mro__):
        print(" ", end="", flush=True)

    # Loop over single examples or batches and write to buffer/file if examples are to be updated
    pbar_iterable = data_struct.items() if isinstance(data_struct, dict) else data_struct
    pbar_desc = (desc + " " if desc is not None else "") + "#" + str(rank) if rank is not None else desc
    with hf_tqdm(pbar_iterable, disable=disable_tqdm, position=rank, unit="obj", desc=pbar_desc) as pbar:
        if isinstance(data_struct, dict):
            return {
                k: _single_map_nested((function, v, batched, batch_size, types, None, True, None)) for k, v in pbar
            }
        else:
            mapped = [_single_map_nested((function, v, batched, batch_size, types, None, True, None)) for v in pbar]
            if isinstance(data_struct, list):
                return mapped
            elif isinstance(data_struct, tuple):
                return tuple(mapped)
            else:
                return np.array(mapped)