in captioning/models/AttModel_orig.py [0:0]
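# Module-level imports assumed by this excerpt (the method below uses them as torch, F and np):
import torch
import torch.nn.functional as F
import numpy as np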
def _diverse_sample(self, fc_feats, att_feats, att_masks=None, opt={}):
sample_method = opt.get('sample_method', 'greedy')
beam_size = opt.get('beam_size', 1)
temperature = opt.get('temperature', 1.0)
group_size = opt.get('group_size', 1)
diversity_lambda = opt.get('diversity_lambda', 0.5)
decoding_constraint = opt.get('decoding_constraint', 0)
block_trigrams = opt.get('block_trigrams', 0)
remove_bad_endings = opt.get('remove_bad_endings', 0)
batch_size = fc_feats.size(0)
state = self.init_hidden(batch_size)
p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)
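# In this codebase _prepare_feature embeds the fc/att features and precomputes the projected
# attention inputs once, so the per-step decoder below can reuse them instead of recomputing.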
trigrams_table = [[] for _ in range(group_size)] # one list per group; each will hold batch_size dicts mapping a word pair to the words that have followed it
seq_table = [fc_feats.new_full((batch_size, self.seq_length), self.pad_idx, dtype=torch.long) for _ in range(group_size)]
seqLogprobs_table = [fc_feats.new_zeros(batch_size, self.seq_length) for _ in range(group_size)]
state_table = [self.init_hidden(batch_size) for _ in range(group_size)]
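# Diverse decoding: group_size decoders run in parallel, each with its own seq/logprob/state
# tables, staggered by one timestep. At outer step tt, group divm is at its own timestep
# t = tt - divm, so group g only fills position t after groups 0..g-1 have committed a word
# there; those earlier choices are penalized below via diversity_lambda to push groups apart.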
for tt in range(self.seq_length + group_size):
for divm in range(group_size):
t = tt - divm
seq = seq_table[divm]
seqLogprobs = seqLogprobs_table[divm]
trigrams = trigrams_table[divm]
if t >= 0 and t <= self.seq_length-1:
if t == 0: # input <bos>
it = fc_feats.new_full([batch_size], self.bos_idx, dtype=torch.long)
else:
it = seq[:, t-1] # changed
logprobs, state_table[divm] = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state_table[divm]) # changed
logprobs = F.log_softmax(logprobs / temperature, dim=-1)
# Add diversity
if divm > 0:
unaug_logprobs = logprobs.clone()
for prev_choice in range(divm):
prev_decisions = seq_table[prev_choice][:, t]
logprobs[:, prev_decisions] = logprobs[:, prev_decisions] - diversity_lambda
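# Note: the fancy indexing above subtracts diversity_lambda, for every image in the batch,
# from every token id chosen by any image in group prev_choice at this timestep.
# unaug_logprobs keeps an unpenalized copy of the logprobs (not used further here).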
if decoding_constraint and t > 0:
tmp = logprobs.new_zeros(logprobs.size())
tmp.scatter_(1, seq[:,t-1].data.unsqueeze(1), float('-inf'))
logprobs = logprobs + tmp
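# decoding_constraint: forbid emitting the same token twice in a row by sending the
# previous word's logprob to -inf.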
if remove_bad_endings and t > 0:
tmp = logprobs.new_zeros(logprobs.size())
prev_bad = np.isin(seq[:,t-1].data.cpu().numpy(), self.bad_endings_ix)
# make it impossible to generate a bad ending: if the previous word is in bad_endings_ix,
# force the logprob of index 0 (the end token here) to -inf so the caption cannot stop on it
tmp[torch.from_numpy(prev_bad.astype('bool')), 0] = float('-inf')
logprobs = logprobs + tmp
# Mess with trigrams
if block_trigrams and t >= 3:
# Store trigram generated at last step
prev_two_batch = seq[:,t-3:t-1]
for i in range(batch_size): # = seq.size(0)
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
current = seq[i][t-1]
if t == 3: # initialize
trigrams.append({prev_two: [current]}) # one dict per image: {(int, int): list of following tokens}
elif t > 3:
if prev_two in trigrams[i]: # add to list
trigrams[i][prev_two].append(current)
else: # create list
trigrams[i][prev_two] = [current]
# Block used trigrams at next step
prev_two_batch = seq[:,t-2:t]
mask = logprobs.new_zeros(logprobs.size()) # batch_size x vocab_size, same device/dtype as logprobs
for i in range(batch_size):
prev_two = (prev_two_batch[i][0].item(), prev_two_batch[i][1].item())
if prev_two in trigrams[i]:
for j in trigrams[i][prev_two]:
mask[i,j] += 1
# Apply mask to log probs
#logprobs = logprobs - (mask * 1e9)
alpha = 2.0 # = 4
logprobs = logprobs + (mask * -0.693 * alpha) # ln(1/2) * alpha (alpha -> infty works best)
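# Soft trigram blocking: mask counts how often each candidate token would repeat an already
# generated trigram; adding count * ln(1/2) * alpha multiplies its probability by
# 0.5**(alpha * count) (0.25 per occurrence for alpha = 2). The commented-out -1e9 variant
# would block repeated trigrams outright.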
it, sampleLogprobs = self.sample_next_word(logprobs, sample_method, 1)
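# sample_next_word picks the next token according to sample_method (greedy argmax or
# multinomial sampling); a temperature of 1 is passed since logprobs were already
# divided by temperature above.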
# stop when all finished
if t == 0:
unfinished = it != self.eos_idx
else:
unfinished = (seq[:,t-1] != self.pad_idx) & (seq[:,t-1] != self.eos_idx)
it[~unfinished] = self.pad_idx
unfinished = unfinished & (it != self.eos_idx) # changed
seq[:,t] = it
seqLogprobs[:,t] = sampleLogprobs.view(-1)
return torch.stack(seq_table, 1).reshape(batch_size * group_size, -1), torch.stack(seqLogprobs_table, 1).reshape(batch_size * group_size, -1)
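# Minimal usage sketch (not part of the original file); the model instance and feature
# shapes below are assumptions for illustration only:
# model = ...  # any AttModel subclass exposing _diverse_sample
# fc_feats = torch.randn(10, 2048)          # (batch, fc_feat_size), assumed dims
# att_feats = torch.randn(10, 36, 2048)     # (batch, num_boxes, att_feat_size), assumed dims
# seq, seq_logprobs = model._diverse_sample(
#     fc_feats, att_feats, att_masks=None,
#     opt={'sample_method': 'sample', 'group_size': 3, 'diversity_lambda': 0.5})
# # seq: (batch * group_size, seq_length) token ids, group_size diverse captions per image.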