in model/data_utils.py [0:0]
def iterator():
    """Yield padded ``(data, target)`` batches streamed column by column.

    Each of the ``batch_size`` columns tracks one sequence at a time as an
    ``(idx, pos)`` pair: ``idx`` indexes the (optionally shuffled)
    permutation of sample ids and ``pos`` is the read offset inside that
    sequence. Every batch copies up to ``bptt`` tokens per column, with the
    target being the same span shifted one token ahead. A column that runs
    out of tokens moves on to the next unclaimed sequence and flags
    ``reset_mem`` for that column.

    The names ``self``, ``total_sample_num``, ``do_shuffle``, ``seed``,
    ``batch_size``, ``bptt``, ``split_seq_lengths``, ``split_data`` and
    ``device`` are not defined in this function; they are resolved from
    the surrounding scope.

    With ``do_shuffle`` false the generator makes a single pass and then
    returns; with ``do_shuffle`` true it reshuffles and keeps going.

    Yields:
        tuple: ``(data, target, reset_mem, batch_token_num, status_vec)``

        * ``data`` / ``target``: long tensors of shape
          ``(bptt, batch_size)``, padded with ``self.vocab.pad_id``.
        * ``reset_mem``: bool tensor of shape ``(batch_size,)``; true for
          columns that switched to a new sequence while this batch was
          being built.
        * ``batch_token_num``: number of non-pad tokens copied into
          ``data`` for this batch.
        * ``status_vec``: bool tensor of shape
          ``(bptt, batch_size, self._vocab.vec_len)`` when
          ``self.cfg.TRAIN.append_note_status`` is set, otherwise ``None``.

    NOTE(review): the batch buffers are allocated once and refilled in
    place. ``Tensor.to`` can hand back the very same tensor when no copy
    is needed, so consumers presumably must not keep a batch around
    across iterations -- confirm against callers.
    """
    perm = np.arange(total_sample_num)
    if do_shuffle:
        # ``rng`` only exists when shuffling; it is only used on that path.
        rng = np.random.RandomState(seed)
        rng.shuffle(perm)
    assert batch_size < total_sample_num

    # One (permutation index, offset inside sequence) pair per batch column.
    tracker_list = [(i, 0) for i in range(batch_size)]
    # Next permutation index that no column has claimed yet.
    next_idx = batch_size

    # Reusable batch buffers; fully overwritten at the top of every batch.
    data = torch.LongTensor(bptt, batch_size)
    target = torch.LongTensor(bptt, batch_size)
    reset_mem = torch.BoolTensor(batch_size)
    if self.cfg.TRAIN.append_note_status:
        status_vec = torch.zeros(
            (bptt, batch_size, self._vocab.vec_len), dtype=torch.bool)
    else:
        status_vec = None

    while True:
        # Start every batch from an all-padding state.
        data[:] = self.vocab.pad_id
        target[:] = self.vocab.pad_id
        reset_mem[:] = False
        batch_token_num = 0

        for i in range(batch_size):
            idx, pos = tracker_list[i]
            while idx < total_sample_num:
                seq_id = perm[idx]
                seq_length = split_seq_lengths[seq_id]

                if pos + 1 >= seq_length:
                    # No (input, target) pair left in this sequence:
                    # claim the next unclaimed one and try again.
                    idx, pos = next_idx, 0
                    tracker_list[i] = (idx, pos)
                    next_idx += 1
                    reset_mem[i] = True
                    continue

                if self.cfg.TRAIN.random_crop:
                    offset = 0
                    if self.cfg.TRAIN.mem_length == 0:
                        # Leave room for at least ``bptt`` tokens after
                        # the crop start.
                        offset = bptt
                    if pos == 0:
                        # ``randint``'s upper bound is exclusive and must
                        # be > 0. Sequences with
                        # ``seq_length <= offset + 1`` used to raise
                        # ValueError here; clamp so they are read from
                        # position 0 instead.
                        high = max(seq_length - 1 - offset, 1)
                        pos = np.random.randint(0, high)

                # Copy as many tokens as fit; target is shifted by one.
                n_new = min(seq_length - 1 - pos, bptt)
                data[:n_new, i] = split_data[seq_id][pos: pos + n_new]
                target[:n_new, i] = split_data[seq_id][
                    (pos + 1): (pos + 1 + n_new)]
                batch_token_num += n_new
                tracker_list[i] = (idx, pos + n_new)

                if (self.cfg.TRAIN.mem_length == 0
                        and self.cfg.TRAIN.random_crop):
                    # Move on if memlen==0: one random crop per sequence,
                    # so this column starts a fresh sequence next batch.
                    idx, pos = next_idx, 0
                    tracker_list[i] = (idx, pos)
                    next_idx += 1
                    reset_mem[i] = True
                # This column is filled for the current batch.
                break

        if batch_token_num == 0:
            # Haven't found anything to fill. This indicates we have
            # reached the end of the current pass.
            if do_shuffle:
                rng.shuffle(perm)
            else:
                return  # One pass dataloader when do_shuffle is False
            tracker_list = [(i, 0) for i in range(batch_size)]
            next_idx = batch_size
            continue

        if self.cfg.TRAIN.append_note_status:
            # Reset status vec for new midi file
            status_vec[:, reset_mem, :] = False
            # presumably fills ``status_vec`` in place from ``data``;
            # the helper is defined elsewhere -- verify.
            self._vocab.update_status_vec(data, status_vec)
            # The name is rebound, so later batches operate on whatever
            # ``.to(device)`` returned.
            status_vec = status_vec.to(device)

        yield (data.to(device), target.to(device), reset_mem.to(device),
               batch_token_num, status_vec)