# src/modules/transformer_decoder.py
import torch
def sample(self,
features,
mask,
greedy=True,
temperature=1.0,
first_token_value=0,
replacement=True):
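    """Autoregressively sample a token sequence from image features.

    Args (shapes described from how this method uses them; they are not
    documented in the original source):
        features: encoder outputs, one row per batch element.
        mask: attention mask over `features`, passed through to forward().
        greedy: if True, take the arg-max token at each step; otherwise
            use temperature-scaled top-k sampling.
        temperature: softmax temperature for the sampling branch.
        first_token_value: id of the dummy start token.
        replacement: if False, tokens sampled at earlier steps are masked
            out and cannot be sampled again.

    Returns:
        (sampled_ids, logits): sampled token ids of shape
        (batch, seq_length) and the per-step logits.
    """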
    incremental_state = {}
    # Seed generation with a dummy "previous word" (the start token) for every
    # example in the batch; features.device keeps the tensor on the same
    # device as the inputs instead of relying on a global `device`.
    fs = features.size(0)
    first_word = torch.full((fs,), first_token_value,
                            dtype=torch.long, device=features.device)
    sampled_ids = [first_word]
    logits = []
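    # Autoregressive decoding loop: each iteration feeds everything sampled so
    # far back through the decoder and appends one new token per example.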
for i in range(self.seq_length):
        # incremental_state caches attention keys/values across iterations so
        # each forward call only computes the newest position.
        outputs = self.forward(features, mask, torch.stack(sampled_ids, 1),
                               incremental_state)
        outputs = outputs.squeeze(1)  # (batch, vocab_size) logits for this step
        if not replacement:
            # Additive mask that bans every token sampled at an earlier step
            # (token id 0 is treated as padding and is never masked).
            if i == 0:
                predicted_mask = torch.zeros_like(outputs)
            else:
                # Only mask rows whose previous sample was a real token.
                batch_ind = [j for j in range(fs) if sampled_ids[i][j] != 0]
                sampled_ids_new = sampled_ids[i][batch_ind]
                predicted_mask[batch_ind, sampled_ids_new] = float('-inf')
            # mask previously selected ids
            outputs += predicted_mask
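        # Note: a float('-inf') logit turns into an exact zero after softmax,
        # so tokens masked above can never be drawn by the sampling branch.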
# add outputs to list
logits.append(outputs)
        if greedy:
            # Greedy decoding: pick the arg-max token at every step.
            _, predicted = outputs.max(1)
            predicted = predicted.detach()
        else:
            # Temperature-scaled top-k sampling: keep only the k most likely
            # tokens, then draw one of them per example.
            k = 10
            prob_prev = torch.nn.functional.softmax(outputs / temperature,
                                                    dim=-1).detach()
            prob_prev_topk, indices = torch.topk(prob_prev, k=k, dim=1)
            # multinomial draws a position in [0, k) per row; gather maps each
            # row's own draw back to its vocabulary id.
            predicted = torch.multinomial(prob_prev_topk, 1)
            predicted = torch.gather(indices, 1, predicted).view(-1).detach()
sampled_ids.append(predicted)
    # Drop the dummy start token before returning.
    sampled_ids = torch.stack(sampled_ids[1:], 1)
    logits = torch.stack(logits, 1)
    return sampled_ids, logits
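
# --- Usage sketch (illustrative, not part of the original module) ------------
# A minimal example of how sample() might be called, assuming the decoder
# exposes seq_length and that `features` are encoder outputs with a matching
# attention mask. `build_decoder`, `encoder`, and `images` are hypothetical
# names used only for illustration.
#
#     decoder = build_decoder(vocab_size=..., seq_length=20)  # hypothetical
#     features = encoder(images)   # (batch, num_feats, embed_dim), assumed
#     mask = None                  # or a padding mask over `features`
#     ids, logits = decoder.sample(features, mask,
#                                  greedy=False, temperature=0.8,
#                                  replacement=False)
#     # ids: (batch, seq_length) token ids; logits: (batch, seq_length, vocab)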