# custom/gpt2/run_gpt2.py
import torch
import torch.nn.functional as F


def sample_sequence(model, prefix_batch, prefix_length, continuation_length, top_k, top_p):
    continuation_logits = []
    context = prefix_batch
    assert context.size(1) == prefix_length
    prev = context
    output = context
    past = None
    for i in range(continuation_length):
        # Feed only the newest token(s); `past` caches attention keys/values
        # so earlier positions are not recomputed.
        logits, past = model(prev, past=past)
        logits = logits[:, -1, :]  # keep logits for the last position only
        if top_k == 1 and top_p == 0:
            # Greedy decoding.
            prev = logits.argmax(dim=1, keepdim=True)
        else:
            # Sample from the top-k / nucleus-filtered distribution.
            filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
            prev = F.softmax(filtered_logits, dim=-1).multinomial(num_samples=1)
        continuation_logits.append(logits)
        output = torch.cat((output, prev), dim=1)

    # (batch, continuation_length, vocab): unfiltered logits at each sampled step.
    continuation_logits = torch.stack(continuation_logits, 1)
    return output, continuation_logits
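
The `top_k_top_p_filtering` helper is not part of this excerpt. A minimal sketch of what it presumably does, following the widely circulated Hugging Face sampling snippet (top-k truncation plus nucleus/top-p filtering over a batch of logits), is:

# Assumed helper, not shown in the excerpt; based on the common
# Hugging Face top-k / top-p filtering implementation.
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('inf')):
    if top_k > 0:
        # Mask every logit below the k-th largest one.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value
    if top_p > 0.0:
        # Keep the smallest set of tokens whose cumulative probability
        # exceeds top_p; mask the rest.
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift right so the first token that crosses the threshold is kept.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits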
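
For context, a hypothetical call site. The model and tokenizer setup below is an assumption, using the pytorch-transformers-era GPT-2 API that the `model(prev, past=past)` call above implies:

# Hypothetical usage sketch; assumes the pytorch-transformers GPT2LMHeadModel,
# whose forward pass returns (logits, past) as sample_sequence expects.
from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

prefix = torch.tensor([tokenizer.encode("The meaning of life is")])
with torch.no_grad():
    # top_k=1, top_p=0 takes the greedy branch in sample_sequence.
    output, _ = sample_sequence(model, prefix, prefix.size(1),
                                continuation_length=50, top_k=1, top_p=0)
print(tokenizer.decode(output[0].tolist()))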