src/datasets/iterable_dataset.py
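# Excerpt: one method of the Arrow re-batching ExamplesIterable in this
# module. It relies on the file's own imports and helpers (pyarrow as pa,
# Key, _convert_to_arrow) plus the instance attributes ex_iterable,
# batch_size, drop_last_batch, and _state_dict.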
def _iter_arrow(self) -> Iterator[tuple[Key, pa.Table]]:
    """Iterate over sub-tables of size `batch_size`."""
    if self._state_dict and self._state_dict["previous_state"]:
        self.ex_iterable.load_state_dict(self._state_dict["previous_state"])
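    # Prefer the wrapped iterable's native Arrow iteration; otherwise convert
    # its Python examples to one-row Arrow tables on the fly.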
    if self.ex_iterable.iter_arrow:
        iterator = self.ex_iterable.iter_arrow()
    else:
        iterator = _convert_to_arrow(self.ex_iterable, batch_size=1)
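    # A null or non-positive batch_size means "one single batch": concatenate
    # everything and yield it exactly once (or nothing when resuming after it
    # was already yielded).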
    if self.batch_size is None or self.batch_size <= 0:
        if self._state_dict and self._state_dict["batch_idx"] > 0:
            return
        all_pa_table = pa.concat_tables([pa_table for _, pa_table in iterator])
        if self._state_dict:
            self._state_dict["batch_idx"] = 1
        yield "all", all_pa_table
        return
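    # Accumulate chunks (with their keys) until exactly `batch_size` rows are
    # buffered. `num_chunks_to_skip` and `chunk_length_to_crop` replay a
    # checkpoint: chunks already yielded since `previous_state` are skipped,
    # and a partially consumed chunk is cropped.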
    keys_buffer = []
    chunks_buffer = []
    chunks_buffer_size = 0
    num_chunks_to_skip = self._state_dict["num_chunks_since_previous_state"] if self._state_dict else 0
    chunk_length_to_crop = self._state_dict["cropped_chunk_length"] if self._state_dict else 0
    if self._state_dict:
        previous_state = self.ex_iterable.state_dict()
        self._state_dict["previous_state"] = previous_state
    for key, pa_table in iterator:
        for num_chunks_since_previous_state, chunk in enumerate(pa_table.to_reader(max_chunksize=self.batch_size)):
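            # Resume handling: skip the chunks already emitted since the
            # checkpoint, then crop the leading rows of the one that was only
            # partially consumed.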
            if num_chunks_to_skip > 1:
                num_chunks_to_skip -= 1
                continue
            elif num_chunks_to_skip == 1 and chunk_length_to_crop == 0:
                num_chunks_to_skip -= 1
                continue
            elif num_chunks_to_skip == 1 and chunk_length_to_crop > 0:
                chunk = chunk.slice(chunk_length_to_crop, len(chunk) - chunk_length_to_crop)
                num_chunks_to_skip = 0
                chunk_length_to_crop = 0
            if len(chunk) == 0:
                continue
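            # Three cases: the chunk leaves the buffer short of `batch_size`,
            # fills it exactly, or overflows it.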
            if chunks_buffer_size + len(chunk) < self.batch_size:
                keys_buffer.append(key)
                chunks_buffer.append(chunk)
                chunks_buffer_size += len(chunk)
                continue
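            # Exact fill: yield the batch, then checkpoint just past this
            # chunk (hence the +1 below), with nothing left to crop.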
            elif chunks_buffer_size + len(chunk) == self.batch_size:
                keys_buffer.append(key)
                chunks_buffer.append(chunk)
                new_key = "_".join(str(_key) for _key in keys_buffer)
                if self._state_dict:
                    self._state_dict["batch_idx"] += 1
                    self._state_dict["num_chunks_since_previous_state"] += len(chunks_buffer)
                    self._state_dict["cropped_chunk_length"] = 0
                yield new_key, pa.Table.from_batches(chunks_buffer)
                keys_buffer = []
                chunks_buffer = []
                chunks_buffer_size = 0
                if self._state_dict:
                    self._state_dict["previous_state"] = previous_state
                    self._state_dict["num_chunks_since_previous_state"] = num_chunks_since_previous_state + 1
            else:
                cropped_chunk_length = self.batch_size - chunks_buffer_size
                keys_buffer.append(f"{key}[:{cropped_chunk_length}]")
                chunks_buffer.append(chunk.slice(0, cropped_chunk_length))
                new_key = "_".join(str(_key) for _key in keys_buffer)
                if self._state_dict:
                    self._state_dict["batch_idx"] += 1
                    self._state_dict["num_chunks_since_previous_state"] += len(chunks_buffer)
                    self._state_dict["cropped_chunk_length"] = cropped_chunk_length
                yield new_key, pa.Table.from_batches(chunks_buffer)
                keys_buffer = [f"{key}[{cropped_chunk_length}:]"]
                chunks_buffer = [chunk.slice(cropped_chunk_length, len(chunk) - cropped_chunk_length)]
                chunks_buffer_size = len(chunk) - cropped_chunk_length
                if self._state_dict:
                    self._state_dict["previous_state"] = previous_state
                    self._state_dict["num_chunks_since_previous_state"] = num_chunks_since_previous_state
        if self._state_dict:
            previous_state = self.ex_iterable.state_dict()
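    # Flush whatever is left as a final, smaller batch unless the caller
    # asked to drop it.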
    if not self.drop_last_batch and chunks_buffer:
        new_key = "_".join(str(_key) for _key in keys_buffer)
        if self._state_dict:
            self._state_dict["previous_state"] = previous_state
            self._state_dict["batch_idx"] += 1
            self._state_dict["num_chunks_since_previous_state"] = 0
            self._state_dict["cropped_chunk_length"] = 0
        yield new_key, pa.Table.from_batches(chunks_buffer)
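
# ---------------------------------------------------------------------------
# Minimal standalone sketch (illustration only, not library code) of the
# rebatching idea implemented above: regroup arbitrary-sized Arrow tables
# into fixed-size batches using plain pyarrow. The data and batch_size are
# made up for the example; keys and checkpointing are omitted.
#
#   import pyarrow as pa
#
#   tables = [pa.table({"x": [0, 1, 2]}), pa.table({"x": [3, 4, 5, 6, 7, 8]})]
#   batch_size, buffer, buffered = 4, [], 0
#   for table in tables:
#       for chunk in table.to_reader(max_chunksize=batch_size):
#           while buffered + len(chunk) >= batch_size:
#               take = batch_size - buffered
#               buffer.append(chunk.slice(0, take))
#               print(pa.Table.from_batches(buffer).to_pydict())
#               buffer, buffered = [], 0
#               chunk = chunk.slice(take, len(chunk) - take)
#           if len(chunk):
#               buffer.append(chunk)
#               buffered += len(chunk)
#   if buffer:  # trailing partial batch (what drop_last_batch would discard)
#       print(pa.Table.from_batches(buffer).to_pydict())
#   # -> {'x': [0, 1, 2, 3]}  {'x': [4, 5, 6, 7]}  {'x': [8]}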