in src/datasets/iterable_dataset.py [0:0]
def _iter(self):
    """Apply `self.function` to the examples of `self.ex_iterable` and yield the results.

    Examples are read from `self.ex_iterable` as `(key, example)` pairs and fed to
    `self.function` either one at a time or, when `self.batched` is set, as batches
    built with `_examples_to_batch`. When `self.function` is a coroutine function the
    calls are scheduled as tasks on an event loop and their results are yielded in
    the order the inputs were read.

    When `self._state_dict` is set, it is updated while iterating with the state of
    `self.ex_iterable` and with counters (`previous_state_example_idx`,
    `num_examples_since_previous_state`). If it already holds a `previous_state` on
    entry, that state is loaded back and the first
    `num_examples_since_previous_state` outputs are recomputed but not yielded.

    Yields:
        `(index, transformed_example)` pairs. `index` is the value produced by the
        input generators below: an int in the non-batched case, and the list of the
        batch indices (the same list for every example of the batch) when batched.
        NOTE(review): the keys coming from `self.ex_iterable` are carried in
        `key_example` but are not what gets yielded here - confirm this is intended.

    Raises:
        ValueError: in batched mode, when the columns returned by the function do
            not all have the same length.
        Any exception (or KeyboardInterrupt) raised while iterating is re-raised
        after the pending async tasks, if any, have been cancelled.
    """
    # Running example index; restored from the saved state when there is one.
    current_idx = self._state_dict["previous_state_example_idx"] if self._state_dict else 0
    if self._state_dict and self._state_dict["previous_state"]:
        # Resume: rewind the source to the last saved state and remember how many
        # outputs produced after that state must be skipped.
        self.ex_iterable.load_state_dict(self._state_dict["previous_state"])
        num_examples_to_skip = self._state_dict["num_examples_since_previous_state"]
    else:
        num_examples_to_skip = 0
    # Single shared iterator: both input generators below consume from it.
    iterator = iter(self.ex_iterable)

    # We use the same logic as in Dataset.map, but with less features/formatting
    # since they're handled by FormattedExamplesIterable

    if self.formatting:
        formatter = get_formatter(self.formatting.format_type)
        # Only tensor formatters provide a batch-level formatting step here.
        format_dict = formatter.recursive_tensorize if isinstance(formatter, TensorFormatter) else None
    else:
        format_dict = None

    def iter_batched_inputs():
        """Yield `(indices, (key, batch))` for each batch read from `iterator`.

        `indices` is the list of consecutive example indices of the batch and `key`
        is the "_"-joined string of the keys of its examples. Advances `current_idx`
        by the batch length.
        """
        nonlocal current_idx
        for key, example in iterator:
            # If `batched`, first build the batch, if `batch_size` is None or <=0, then the batch is the whole dataset
            # The `for` statement already consumed the first example of the batch,
            # hence `batch_size - 1` more are taken from the same iterator.
            iterator_batch = (
                iterator
                if self.batch_size is None or self.batch_size <= 0
                else islice(iterator, self.batch_size - 1)
            )
            key_examples_list = [(key, example)] + list(iterator_batch)
            keys, examples = zip(*key_examples_list)
            # the new key is the concatenation of the examples keys from the batch
            key = "_".join(str(key) for key in keys)
            if (
                self.drop_last_batch
                and self.batch_size is not None
                and self.batch_size > 0
                and len(examples) < self.batch_size
            ):  # ignore last batch
                # A short batch can only be the last one: stop the generator.
                return
            batch = _examples_to_batch(examples)
            # we need to format here in case we need to stack tensors together
            batch = format_dict(batch) if format_dict else batch
            indices = [current_idx + i for i in range(len(key_examples_list))]
            current_idx += len(indices)
            yield indices, (key, batch)

    def iter_inputs():
        """Yield `(index, (key, example))` for each example read from `iterator`.

        The example is a shallow `dict` copy of the source example. Advances
        `current_idx` by one per example.
        """
        nonlocal current_idx
        for key, example in iterator:
            # If not batched, we can apply the transform and yield the example directly
            # first copy the example, since we might drop some keys
            example = dict(example)
            # no need to do formatting here
            # Increment before yielding: `current_idx` is already the next index
            # while the consumer handles this example.
            current_idx += 1
            yield current_idx - 1, (key, example)

    def validate_function_output(processed_inputs):
        """Check that all columns of a batched function output have the same length.

        Nothing is checked when not batched or when `processed_inputs` is falsy.

        Raises:
            ValueError: if any column length differs from the first column's length.
        """
        if self.batched and processed_inputs:
            first_col = next(iter(processed_inputs))
            bad_cols = [
                col for col in processed_inputs if len(processed_inputs[col]) != len(processed_inputs[first_col])
            ]
            if bad_cols:
                raise ValueError(
                    f"Column lengths mismatch: columns {bad_cols} have length {[len(processed_inputs[col]) for col in bad_cols]} "
                    f"while {first_col} has length {len(processed_inputs[first_col])}."
                )

    def prepare_inputs(key_example, indices):
        """Build the call arguments for `self.function` from one input.

        Returns:
            `(inputs, fn_args, additional_args, fn_kwargs)` where `inputs` is a
            shallow copy of the example/batch, `fn_args` is the positional argument
            list (the whole example, or the selected `self.input_columns`, followed
            by `indices` when `self.with_indices` is set), `additional_args` is
            always an empty tuple here and `fn_kwargs` is `self.fn_kwargs`.
        """
        # `key` is unpacked but not used in this helper.
        key, example = key_example
        fn_args = [example] if self.input_columns is None else [example[col] for col in self.input_columns]
        additional_args = ()
        if self.with_indices:
            # In-place list extension with a 1-tuple: appends `indices` as one argument.
            fn_args += (indices,)
        inputs = dict(example)
        return inputs, fn_args, additional_args, self.fn_kwargs

    def prepare_outputs(key_example, inputs, processed_inputs):
        """Validate the function output and merge it with the original inputs.

        Columns listed in `self.remove_columns` are dropped from `inputs`; values in
        `processed_inputs` override those of `inputs` on key collisions.

        Returns:
            The merged `dict`.
        """
        validate_function_output(processed_inputs)
        # this logic mimics the one in Dataset.map
        if self.remove_columns:
            for c in self.remove_columns:
                if c in inputs:
                    del inputs[c]
                # If the function returned the very object it was given (identity
                # check), the removed column must be dropped from it as well,
                # otherwise the merge below would bring it back.
                if processed_inputs is key_example[1] and c in processed_inputs:
                    del processed_inputs[c]
        transformed_inputs = {**inputs, **processed_inputs}
        # no need to do features decoding here
        return transformed_inputs

    def apply_function(key_example, indices):
        """Call `self.function` on one input (example or batch) and return the merged output."""
        inputs, fn_args, additional_args, fn_kwargs = prepare_inputs(key_example, indices)
        processed_inputs = self.function(*fn_args, *additional_args, **fn_kwargs)
        return prepare_outputs(key_example, inputs, processed_inputs)

    async def async_apply_function(key_example, indices):
        """Same as `apply_function`, but awaits `self.function` (coroutine function)."""
        inputs, fn_args, additional_args, fn_kwargs = prepare_inputs(key_example, indices)
        processed_inputs = await self.function(*fn_args, *additional_args, **fn_kwargs)
        return prepare_outputs(key_example, inputs, processed_inputs)

    # Scheduled-but-not-yet-yielded tasks, oldest first. Only used in the async case.
    tasks: list[asyncio.Task] = []
    if inspect.iscoroutinefunction(self.function):
        try:
            # Reuse the loop of the current thread if one is running...
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # ...otherwise create a dedicated one.
            loop = asyncio.new_event_loop()
        # The same `tasks` list object is registered on the instance.
        # NOTE(review): presumably so the owner can cancel/close them later - the
        # consumer of `_owned_loops_and_tasks` is not visible here.
        self._owned_loops_and_tasks.append((loop, tasks))
    else:
        loop = None

    def iter_outputs():
        """Yield `(index, output)` for each input, in input order.

        `index` comes from `iter_inputs` / `iter_batched_inputs` and `output` is the
        merged dict returned by `prepare_outputs`. Also maintains `self._state_dict`
        when it is set.
        """
        nonlocal tasks, loop
        inputs_iterator = iter_batched_inputs() if self.batched else iter_inputs()
        if inspect.iscoroutinefunction(self.function):
            if self._state_dict:
                # Checkpoint bookkeeping: `previous_state` is a snapshot of the
                # source state, `previous_state_task` the task whose result must be
                # yielded before that snapshot is written to `self._state_dict`.
                previous_state = self.ex_iterable.state_dict()
                self._state_dict["previous_state"] = previous_state
                previous_state_task = None
                previous_state_example_idx = self._state_dict["previous_state_example_idx"]
            # Indices of the inputs of `tasks`, kept in the same order.
            indices: Union[list[int], list[list[int]]] = []
            for i, key_example in inputs_iterator:
                indices.append(i)
                tasks.append(loop.create_task(async_apply_function(key_example, i)))
                # keep the total active tasks under a certain number
                if len(tasks) >= self.max_num_running_async_map_functions_in_parallel:
                    # Run the loop until at least one task is done, and keep going
                    # while the number of pending tasks is still at the limit.
                    done, pending = loop.run_until_complete(
                        asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                    )
                    while tasks and len(pending) >= self.max_num_running_async_map_functions_in_parallel:
                        done, pending = loop.run_until_complete(
                            asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
                        )
                # `tasks` also holds finished tasks stuck behind an unfinished head;
                # once the backlog reaches 10x the limit, block on the oldest task so
                # the loop below can yield it.
                if len(tasks) >= 10 * self.max_num_running_async_map_functions_in_parallel:
                    loop.run_until_complete(tasks[0])
                # yield finished tasks
                # Only from the head of the queue, which preserves input order.
                while tasks and tasks[0].done():
                    i, task = indices.pop(0), tasks.pop(0)
                    yield i, task.result()
                    if self._state_dict and task is previous_state_task:
                        # The marker task has been yielded: write the snapshot to
                        # the state dict and reset the counter.
                        self._state_dict["previous_state"] = previous_state
                        self._state_dict["num_examples_since_previous_state"] = 0
                        self._state_dict["previous_state_example_idx"] = previous_state_example_idx
                        previous_state, previous_state_task = None, None
                # checkpoint
                # Take a new snapshot only when none is in flight, and mark the most
                # recently scheduled task as the one that will commit it.
                if self._state_dict and previous_state_task is None and tasks:
                    previous_state = self.ex_iterable.state_dict()
                    previous_state_task = tasks[-1]
                    previous_state_example_idx = current_idx
            # Source exhausted: wait for and yield the remaining tasks in order.
            while tasks:
                yield indices[0], loop.run_until_complete(tasks[0])
                indices.pop(0), tasks.pop(0)
        else:
            if self._state_dict:
                if self.batched:
                    # Batched: record the source state before the first batch...
                    self._state_dict["previous_state"] = self.ex_iterable.state_dict()
                    self._state_dict["num_examples_since_previous_state"] = 0
                    self._state_dict["previous_state_example_idx"] = current_idx
            for i, key_example in inputs_iterator:
                if self._state_dict:
                    if not self.batched:
                        # Not batched: only the example index is tracked here.
                        self._state_dict["previous_state_example_idx"] = current_idx
                yield i, apply_function(key_example, i)
                if self._state_dict:
                    if self.batched:
                        # ...and again after each batch has been yielded.
                        self._state_dict["previous_state"] = self.ex_iterable.state_dict()
                        self._state_dict["num_examples_since_previous_state"] = 0
                        self._state_dict["previous_state_example_idx"] = current_idx

    try:
        outputs = iter_outputs()
        if self.batched:
            # Flatten each transformed batch back into individual examples; every
            # example of a batch is paired with the same first element.
            outputs = (
                (key, transformed_example)
                for key, transformed_batch in outputs
                for transformed_example in _batch_to_examples(transformed_batch)
            )
        for key, transformed_example in outputs:
            # Count every output produced since the recorded state, including the
            # ones skipped below.
            if self._state_dict and self._state_dict["previous_state"] is not None:
                self._state_dict["num_examples_since_previous_state"] += 1
            if num_examples_to_skip > 0:
                # Resuming: this output was already yielded before the state was saved.
                num_examples_to_skip -= 1
                continue
            yield key, transformed_example
    except (Exception, KeyboardInterrupt):
        if loop:
            # Async case: cancel whatever is still scheduled before propagating.
            logger.debug(f"Canceling {len(tasks)} async tasks.")
            for task in tasks:
                # The same cancel message is used whatever exception triggered this.
                task.cancel(msg="KeyboardInterrupt")
            try:
                # Run the loop so the cancellations are actually processed.
                loop.run_until_complete(asyncio.gather(*tasks))
            except (asyncio.CancelledError, ValueError):
                # NOTE(review): why ValueError is tolerated here is not visible in
                # this block - confirm.
                logger.debug("Tasks canceled.")
        raise