in model/data_utils.py [0:0]
def iterator():
    # Reusable buffers: one bptt-length slice per sequence in the current batch.
    data = torch.LongTensor(bptt, batch_size)
    target = torch.LongTensor(bptt, batch_size)
    if self.cfg.TRAIN.append_note_status:
        status_vec = torch.zeros((bptt, batch_size, self._vocab.vec_len), dtype=torch.bool)
    else:
        status_vec = None
    for batch_begin in range(0, total_sample_num, batch_size):
        reset_all_mem = True
        batch_end = min(batch_begin + batch_size, total_sample_num)
        max_seq_length = max(split_seq_lengths[batch_begin:batch_end])
        # Walk the longest sequence in the batch in bptt-sized steps.
        for seq_begin in range(0, max_seq_length - 1, bptt):
            data[:] = self.vocab.pad_id
            target[:] = self.vocab.pad_id
            batch_token_num = 0
            for i in range(batch_begin, batch_end):
                if split_seq_lengths[i] > seq_begin + 1:
                    # Number of valid (non-padding) tokens in this slice.
                    n_new = min(seq_begin + bptt, split_seq_lengths[i] - 1) - seq_begin
                    data[:n_new, i - batch_begin] = split_data[i][seq_begin:seq_begin + n_new]
                    # Target is the input shifted one position to the right.
                    target[:n_new, i - batch_begin] = split_data[i][seq_begin + 1:seq_begin + n_new + 1]
                    batch_token_num += n_new
            if self.cfg.TRAIN.append_note_status:
                # Reset status vec for new midi file
                if reset_all_mem:
                    status_vec[:] = False
                self._vocab.update_status_vec(data, status_vec)
                status_vec = status_vec.to(device)
            yield (data.to(device), target.to(device),
                   reset_all_mem, batch_token_num, status_vec)
            reset_all_mem = False
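
A minimal consumption sketch (not from the repository) showing how the tuples yielded by iterator() could drive a truncated-BPTT training loop with a recurrent memory that is cleared whenever reset_all_mem is True. The model call signature, mems handling, and loss normalization here are assumptions for illustration; the project's actual training code may differ.

    import torch
    import torch.nn as nn

    def train_one_epoch(model, optimizer, data_iter, pad_id):
        mems = None  # carried-over memory, Transformer-XL style (assumed)
        criterion = nn.CrossEntropyLoss(ignore_index=pad_id, reduction='sum')
        total_loss, total_tokens = 0.0, 0
        for data, target, reset_all_mem, batch_token_num, status_vec in data_iter:
            if reset_all_mem:
                mems = None  # new group of sequences: drop stale memory
            optimizer.zero_grad()
            # Assumed signature: returns per-token logits and updated memory.
            logits, mems = model(data, mems=mems, status_vec=status_vec)
            loss = criterion(logits.view(-1, logits.size(-1)), target.view(-1))
            # Normalize by the number of real (non-padding) tokens in the step.
            (loss / max(batch_token_num, 1)).backward()
            optimizer.step()
            total_loss += loss.item()
            total_tokens += batch_token_num
        return total_loss / max(total_tokens, 1)

Padding positions are excluded twice: ignore_index skips them in the loss, and batch_token_num (reported by the iterator) gives the correct denominator for averaging.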