in grok/data.py [0:0]
def __next__(self) -> Dict[str, Tensor]:
    """
    Returns one batch of data.

    Rows of ``self.dataset.data`` are served in the order given by
    ``self.permutation``, ``self.batchsize`` rows per call. Each selected row
    is split into a pair offset by one column: ``"text"`` holds every column
    except the last, ``"target"`` every column except the first. The final
    batch holds fewer than ``self.batchsize`` rows when the permutation
    length is not a multiple of it.

    :raises: StopIteration when we're out of data; ``self.reset_iteration()``
        is called first.
    :returns: dict with keys ``"text"`` and ``"target"``, each a tensor of
        shape (rows_in_batch, self.dataset.data.shape[1] - 1) moved to
        ``self.device``
    """
    batch_begin = self.index * self.batchsize
    # Every row has already been served: reset iterator state, then stop.
    if batch_begin >= len(self.dataset):
        self.reset_iteration()
        raise StopIteration
    # Slicing past the end of the permutation yields a short final batch.
    indices = self.permutation[batch_begin : batch_begin + self.batchsize]
    # Inputs and targets are the same rows shifted by one column.
    text = self.dataset.data[indices, :-1]
    target = self.dataset.data[indices, 1:]
    batch = {"text": text.to(self.device), "target": target.to(self.device)}
    self.index += 1
    return batch