in src/datasets/iterable_dataset.py [0:0]
def __iter__(self):
    """Yield examples drawn from ``self.ex_iterables``, one source at a time.

    The source to read from at each step is given by the indices produced by
    ``self._get_indices_iterator()``. Iteration stops as soon as
    ``self.bool_strategy_func`` returns a truthy value for the per-source
    exhaustion flags. A source that runs out gets a fresh iterator created
    over it, so it can keep being picked while the stopping criterion is
    not met yet.

    When ``self._state_dict`` is set, it is kept up to date while iterating:
    ``"previous_states"`` holds, for each source, a copy of its state taken
    before the look-ahead example was pulled, ``"is_exhausted"`` records which
    sources ran out at least once, and ``"ex_iterables"`` is re-initialised
    for a source when it is restarted.

    Yields:
        The examples produced by the underlying iterables, unchanged.
    """
    # Private markers for "nothing buffered" and "iterator exhausted". They are
    # distinct objects so that a source yielding None or False is neither
    # skipped nor mistaken for an exhausted iterator.
    not_buffered = object()
    end_of_source = object()

    # we buffer one example of each iterator to know if an iterator is exhausted
    nexts = [not_buffered] * len(self.ex_iterables)
    # because of that, we need to rewind 1 example when reloading the state dict
    if self._state_dict:
        for i, ex_iterable in enumerate(self.ex_iterables):
            previous_state = self._state_dict["previous_states"][i]
            if previous_state is not None:
                ex_iterable.load_state_dict(previous_state)
    iterators = [iter(ex_iterable) for ex_iterable in self.ex_iterables]

    indices_iterator = self._get_indices_iterator()

    # local working copy of the exhaustion flags, fed to the stopping criterion
    is_exhausted = (
        np.array(self._state_dict["is_exhausted"])
        if self._state_dict
        else np.full(len(self.ex_iterables), False)
    )
    for i in indices_iterator:
        # if the stopping criterion is met, break the main for loop
        if self.bool_strategy_func(is_exhausted):
            break
        # let's pick one example from the iterator at index i
        if nexts[i] is not_buffered:
            nexts[i] = next(iterators[i], end_of_source)
        result = nexts[i]
        # snapshot the source state before pulling the look-ahead example, so
        # that reloading this snapshot replays the look-ahead example
        if self._state_dict:
            self._state_dict["previous_states"][i] = deepcopy(self._state_dict["ex_iterables"][i])
        nexts[i] = next(iterators[i], end_of_source)

        # the iterator is exhausted
        if nexts[i] is end_of_source:
            is_exhausted[i] = True
            # we reset the buffer in case the stopping criterion isn't met yet
            nexts[i] = not_buffered
            if self._state_dict:
                self._state_dict["is_exhausted"][i] = True
                self._state_dict["ex_iterables"][i] = self.ex_iterables[i]._init_state_dict()
                self._state_dict["previous_states"][i] = None
            iterators[i] = iter(self.ex_iterables[i])

        # result is the end marker only when the source had nothing to give
        if result is not end_of_source:
            yield result