def generate_tokens()

in model/train.py [0:0]


def generate_tokens(num_samples, temperature, batch_size=128, seq_len=2048):
    # Currently only generates unconditionally (src is assumed to be None)

    # First token has to be expanded into one-hot if data is in index form
    assert num_samples % batch_size == 0, "num_samples must be divisible by batch_size"

    num_samples = num_samples // batch_size  # reinterpret as the number of batches to generate
    result = torch.tensor([], dtype=torch.long).to(device)

    cache_tgt_len, cache_mem_len = (
        model.module.generator.tgt_len,
        model.module.generator.mem_len,
    )

    # Switch the generator to single-step decoding for sampling
    model.module.generator.reset_length(1, seq_len)  # tgt_len=1, mem_len=seq_len (bert_len)

    for _ in range(num_samples):
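        # Every sample starts from token index 0; shape (1, batch_size), so generated steps stack along dim 0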
        seq = [torch.zeros(batch_size, dtype=torch.long, device=device)[None, :]]
        sample_mems = None
        status_vec = None

        for _ in range(seq_len - 1):  # The start token is already chosen and the bert tgt len is fixed
            inp = seq[-1]
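            # Optionally maintain a running note-status vector fed to the model alongside the tokens
            # (updated in place by vocab.update_status_vec; only the last step's status is carried over)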
            if cfg.TRAIN.append_note_status:
                bptt, batch_size = inp.shape
                if status_vec is None:
                    status_vec = inp.new_zeros((bptt, batch_size, vocab.vec_len), dtype=torch.bool)
                else:
                    status_vec = status_vec[-1:, :, :]
                vocab.update_status_vec(inp, status_vec)

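            # One decoding step: feed only the newest token plus the cached memories (sample_mems)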
            ret = model.module.generator.forward_generate_gumbel(
                inp, temperature, sample_mems, status_vec=status_vec
            )

            logits, sample_mems = ret
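            # argmax over the Gumbel-perturbed logits (Gumbel-max trick) amounts to sampling at the given temperature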
            seq.append(torch.argmax(logits[0], dim=-1)[None, :])

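        # Each finished batch is a (seq_len, batch_size) block; batches are stacked side by side along dim=1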
        result = torch.cat([result, torch.cat(seq, dim=0)], dim=1)

    # Restore the cached tgt_len / mem_len now that sampling is done
    model.module.generator.reset_length(cache_tgt_len, cache_mem_len)

    return result
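
Usage sketch, assuming the module-level globals referenced above (model, device, cfg, vocab) are already initialized; the call below is hypothetical:

    # Draw 256 unconditional samples of 2048 tokens each.
    # The result has shape (seq_len, num_samples) = (2048, 256) and dtype torch.long.
    samples = generate_tokens(num_samples=256, temperature=1.0, batch_size=128, seq_len=2048)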