in model/train.py [0:0]
def generate_tokens(num_samples, temperature, batch_size=128, seq_len=2048):
    # Currently generates unconditionally only (src is assumed to be None)
    # The first token has to be expanded into one-hot form if the data is stored as indices
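    # Relies on module-level names from train.py: model, device, cfg, vocab (and torch)
    # Returns a LongTensor of shape (seq_len, num_samples); each column is one generated sequence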
    assert num_samples % batch_size == 0, "num_samples must be a multiple of batch_size"
    num_batches = num_samples // batch_size
    # Shape (seq_len, 0) so that torch.cat along dim=1 below is shape-consistent
    result = torch.empty((seq_len, 0), dtype=torch.long, device=device)
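    # Cache the generator's current tgt_len / mem_len so they can be restored after sampling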
cache_tgt_len, cache_mem_len = (
model.module.generator.tgt_len,
model.module.generator.mem_len,
)
    # Switch to step-by-step sampling: tgt_len=1 (one new token per forward pass),
    # mem_len=seq_len so the memory can span the whole generated sequence
    model.module.generator.reset_length(1, seq_len)
    for _ in range(num_batches):
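        # Start every sequence in the batch from token index 0 (the start token); shape (1, batch_size)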
seq = [torch.zeros(batch_size, dtype=torch.long, device=device)[None, :]]
sample_mems = None
status_vec = None
        for _ in range(seq_len - 1):  # the start token is already chosen and the target length is fixed
inp = seq[-1]
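            # status_vec presumably tracks which notes/tokens are currently sounding so the model can
            # condition on note on/off state; vocab.update_status_vec is assumed to update it in place from inp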
            if cfg.TRAIN.append_note_status:
                bptt, bsz = inp.shape
                if status_vec is None:
                    status_vec = inp.new_zeros((bptt, bsz, vocab.vec_len), dtype=torch.bool)
                else:
                    # Keep only the status of the most recent step
                    status_vec = status_vec[-1:, :, :]
                vocab.update_status_vec(inp, status_vec)
            logits, sample_mems = model.module.generator.forward_generate_gumbel(
                inp, temperature, sample_mems, status_vec=status_vec
            )
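            # If forward_generate_gumbel perturbs the logits with Gumbel noise (as its name suggests),
            # argmax here is equivalent to sampling from the temperature-scaled softmax (Gumbel-max trick)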
seq.append(torch.argmax(logits[0], dim=-1)[None, :])
result = torch.cat([result, torch.cat(seq, dim=0)], dim=1)
    # Restore the generator's original tgt_len / mem_len now that sampling is done
    model.module.generator.reset_length(cache_tgt_len, cache_mem_len)
return result
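
# A minimal usage sketch (hedged example; assumes model, device, cfg, and vocab have already been
# initialised by the surrounding training script):
#     samples = generate_tokens(num_samples=256, temperature=0.95)
#     # samples is a LongTensor of shape (seq_len, 256); each column is one generated token sequence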