# model/data_utils.py
def iterator():
    """Yield batches of token chunks drawn from the split's sequences.

    Yields:
        (data, batch_token_num): ``data`` is a ``(bptt, batch_size)``
        LongTensor moved to ``device`` (closure variable); slots with no
        remaining data are filled with ``self.vocab.pad_id``.
        ``batch_token_num`` counts the non-pad tokens in the batch.

    Behavior:
        * ``do_shuffle=True``  — endless iterator; sequence order is
          reshuffled each epoch and each chunk is a random crop of its
          sequence, all driven by a ``seed``-ed RNG for reproducibility.
        * ``do_shuffle=False`` — deterministic one-pass iterator that
          walks every sequence front-to-back in ``bptt``-sized chunks
          and returns when all sequences are consumed.

    NOTE(review): relies on enclosing-scope names (``total_sample_num``,
    ``bptt``, ``split_data``, ``split_seq_lengths``, ``device``, ...) —
    this is a closure inside a data-loader method.
    """
    perm = np.arange(total_sample_num)
    # Seeded RNG is created unconditionally so crop sampling below is
    # reproducible; it is only used for shuffling/cropping when enabled.
    rng = np.random.RandomState(seed)
    if do_shuffle:
        rng.shuffle(perm)
    if batch_size >= total_sample_num:
        # Was an `assert`, but asserts are stripped under `python -O`.
        raise ValueError(
            f"batch_size ({batch_size}) must be smaller than "
            f"total_sample_num ({total_sample_num})")
    # Each of the batch_size "slots" owns a cursor (idx into perm, pos
    # into that sequence); next_idx is the next unclaimed sequence.
    tracker_list = [(i, 0) for i in range(batch_size)]
    next_idx = batch_size
    data = torch.LongTensor(bptt, batch_size)
    while True:
        # Reset the whole batch to padding; untouched slots stay padded.
        data[:] = self.vocab.pad_id
        batch_token_num = 0
        for i in range(batch_size):
            idx, pos = tracker_list[i]
            while idx < total_sample_num:
                seq_id = perm[idx]
                seq_length = split_seq_lengths[seq_id]
                if pos + bptt > seq_length:
                    # Sequence exhausted (or shorter than bptt):
                    # claim the next unassigned sequence for this slot.
                    idx, pos = next_idx, 0
                    tracker_list[i] = (idx, pos)
                    next_idx += 1
                    continue
                # Shuffled mode samples a random crop (seeded rng);
                # unshuffled mode reads sequentially at the cursor.
                start = (rng.randint(0, seq_length - bptt + 1)
                         if do_shuffle else pos)
                data[:bptt, i] = split_data[seq_id][start: start + bptt]
                batch_token_num += bptt
                # Always advance the *sequential* cursor so every
                # sequence is eventually exhausted and epochs terminate.
                tracker_list[i] = (idx, pos + bptt)
                break
        if batch_token_num == 0:
            # Nothing left to fill: the epoch is over.
            if do_shuffle:
                rng.shuffle(perm)
            else:
                return  # One-pass dataloader when do_shuffle is False
            tracker_list = [(i, 0) for i in range(batch_size)]
            next_idx = batch_size
            continue
        yield data.to(device), batch_token_num