in domainbed_measures/utils.py [0:0]
def __iter__(self):
    """Yield one merged minibatch per step, concatenated across environments.

    For each of ``len(self)`` steps, pull the next collection of
    per-environment minibatches from ``self._minibatches_iterator``, drop
    entries that are ``None``, and concatenate the remaining ones along
    dim 0 with ``torch.cat``.

    Yields:
        ``(all_x, all_y)`` when the per-environment minibatches are
        ``(x, y)`` pairs, or ``(all_x, all_y, all_idx)`` when they are
        ``(x, y, idx)`` triples. In the triple case each environment's
        ``idx`` is shifted by ``self._all_cumulative_datapoints[env_idx]``,
        where ``env_idx`` is the environment's position in the ORIGINAL
        (unfiltered) collection, so skipped environments do not shift the
        offsets of later ones.

    Raises:
        ValueError: if every per-environment minibatch in a step is ``None``.
    """
    for _ in range(len(self)):
        raw_minibatches = next(self._minibatches_iterator)
        # Pair each surviving minibatch with its original environment index:
        # filtering out ``None`` must not re-number the environments, otherwise
        # later environments would pick up an earlier one's cumulative offset.
        env_minibatches = [
            (env_idx, xyidx)
            for env_idx, xyidx in enumerate(raw_minibatches)
            if xyidx is not None
        ]
        if not env_minibatches:
            raise ValueError(
                "Every environment minibatch in this step is None; "
                "nothing to concatenate.")
        minibatches = [xyidx for _, xyidx in env_minibatches]
        # The first surviving minibatch decides the format: (x, y) or
        # (x, y, idx). Strict unpacking below rejects mixed formats.
        if len(minibatches[0]) == 2:
            all_x = torch.cat([x for x, _ in minibatches])
            all_y = torch.cat([y for _, y in minibatches])
            yield all_x, all_y
        else:
            all_x = torch.cat([x for x, _, _ in minibatches])
            all_y = torch.cat([y for _, y, _ in minibatches])
            # Shift per-environment local indices by that environment's
            # cumulative offset (looked up by ORIGINAL environment index).
            all_idx = torch.cat([
                idx + self._all_cumulative_datapoints[env_idx]
                for env_idx, (_, _, idx) in env_minibatches
            ])
            yield all_x, all_y, all_idx