in src/modules/rnn_decoder.py [0:0]
def sample(self,
           features,
           mask,
           greedy=True,
           temperature=1.0,
           first_token_value=0,
           replacement=True,
           top_k=10):
    """Generate token ids step by step from the given image features.

    The decoder is unrolled for ``self.seq_length`` steps. At each step
    ``self.core`` produces a hidden vector, ``self.linear`` turns it into
    logits, and the next token is chosen either greedily or by top-k
    sampling. The chosen token is embedded and fed to the next step.

    Args:
        features: Feature tensor whose first dimension is the batch and
            whose last dimension is averaged to build the initial decoder
            input (presumably ``(batch, channels, locations)`` -- confirm
            against the caller).
        mask: Unused by this method; kept so existing callers keep working.
        greedy: If True, pick the highest-scoring token at every step.
            Otherwise sample among the ``top_k`` most likely tokens.
        temperature: Divisor applied to the logits before the softmax when
            sampling. Ignored when ``greedy`` is True.
        first_token_value: Token id fed to the decoder at the first step.
        replacement: If False, every non-zero token already emitted by a
            batch row is masked with ``-inf`` for that row in later steps,
            so it cannot be selected again. Token id 0 is never masked.
        top_k: Number of candidates kept when sampling. Clamped to the
            vocabulary size. Ignored when ``greedy`` is True.

    Returns:
        Tuple ``(sampled_ids, logits)``: the per-step token ids stacked
        along dimension 1 (the start token is not included) and the
        per-step logits stacked along dimension 1. When ``replacement`` is
        False the returned logits include the ``-inf`` masking.

    Raises:
        ValueError: If sampling is requested with ``top_k`` < 1.
    """
    if not greedy and top_k < 1:
        raise ValueError("top_k must be a positive integer")

    # Allocate on the device of the inputs instead of hard-coding CUDA, so
    # the method also runs on CPU-only machines.
    device = features.device
    batch_size = features.size(0)

    inputs = torch.mean(features, dim=-1)
    states = None

    start_ids = torch.full((batch_size, 1), first_token_value,
                           dtype=torch.long, device=device)
    prev_word = self.embed(start_ids).squeeze(1)

    batch_rows = torch.arange(batch_size, device=device)
    predicted_mask = None
    predicted = None
    logits = []
    sampled_ids = []

    for _ in range(self.seq_length):
        hidden, states, _att_coeffs = self.core(
            inputs, features, prev_word, states)
        inputs = hidden
        outputs = self.linear(hidden).squeeze(1)

        if not replacement:
            if predicted_mask is None:
                # First step: nothing has been emitted yet (the start
                # token itself is deliberately not masked).
                predicted_mask = torch.zeros_like(outputs)
            else:
                # Forbid the token each row emitted at the previous step,
                # except id 0, which stays selectable.
                rows = batch_rows[predicted != 0]
                predicted_mask[rows, predicted[rows]] = float('-inf')
            # Out-of-place add: do not mutate the linear layer's output.
            outputs = outputs + predicted_mask

        logits.append(outputs)

        if greedy:
            _, predicted = outputs.max(1)
            predicted = predicted.detach()
        else:
            probs = torch.nn.functional.softmax(
                outputs / temperature, dim=-1).detach()
            # topk raises if k exceeds the vocabulary size.
            k = min(top_k, probs.size(-1))
            topk_probs, topk_ids = torch.topk(probs, k=k, dim=1)
            choice = torch.multinomial(topk_probs, 1)
            # gather maps each row's own sampled rank back to its token
            # id; index_select would reuse row 0's rank for every row.
            predicted = topk_ids.gather(1, choice).squeeze(1).detach()

        sampled_ids.append(predicted)
        prev_word = self.embed(predicted)

    return torch.stack(sampled_ids, 1), torch.stack(logits, 1)