def __iter__()

in domainbed_measures/utils.py [0:0]


    def __iter__(self):
        """Yield up to ``len(self)`` merged minibatches drawn from all environments.

        Each item pulled from ``self._minibatches_iterator`` is a sequence with
        one entry per environment. An entry is either ``None`` (that
        environment contributed nothing at this step) or a tuple of
        ``(x, y)`` or ``(x, y, idx)``. The non-``None`` entries are
        concatenated with ``torch.cat``.

        Yields:
            ``(all_x, all_y)`` when the entries are pairs, or
            ``(all_x, all_y, all_idx)`` when they are triples. In the triple
            case each environment's ``idx`` is shifted by
            ``self._all_cumulative_datapoints[env_idx]``. Here ``env_idx`` is
            the environment's position in the *unfiltered* per-step sequence,
            so the offset stays attached to the right environment even when
            others are ``None``. (Presumably the offsets are cumulative
            per-environment dataset sizes — confirm where
            ``_all_cumulative_datapoints`` is built.)

        Steps in which every environment is ``None`` are skipped. If the
        underlying iterator is exhausted before ``len(self)`` steps, iteration
        ends instead of surfacing a ``RuntimeError`` (PEP 479).
        """
        for _ in range(len(self)):
            try:
                per_env = next(self._minibatches_iterator)
            except StopIteration:
                # A StopIteration escaping a generator body is converted to
                # RuntimeError (PEP 479); finish the iteration cleanly instead.
                return

            # Keep each environment's ORIGINAL position alongside its batch:
            # the offset lookup below must use it, not the position after
            # dropping ``None`` entries.
            present = [
                (env_idx, batch)
                for env_idx, batch in enumerate(per_env)
                if batch is not None
            ]
            if not present:
                # No environment produced data at this step; nothing to merge.
                continue

            batches = [batch for _, batch in present]
            all_x = torch.cat([batch[0] for batch in batches])
            all_y = torch.cat([batch[1] for batch in batches])

            # The first present entry decides the format: pairs carry no
            # indices, triples carry per-environment local indices.
            if len(batches[0]) == 2:
                yield all_x, all_y
            else:
                all_idx = torch.cat([
                    batch[2] + self._all_cumulative_datapoints[env_idx]
                    for env_idx, batch in present
                ])
                yield all_x, all_y, all_idx