in utils_nlp/models/transformers/bertsum/beam.py [0:0]
def advance(self, word_probs, attn_out):
    """
    Given the probabilities over words for every last beam `word_probs`
    and the attention `attn_out` at the last step: compute and update
    the beam search.
    Parameters:
    * `word_probs`- probs of advancing from the last step (K x words)
    * `attn_out`- attention at the last step
    Updates the beam state in place; `self.eos_top` is set once the
    top-of-beam hypothesis ends in EOS.
    """
    num_words = word_probs.size(1)
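    # If a stepwise penalty is configured, let the global scorer (e.g. a
    # coverage penalty) adjust the scores at every step, not only at the end.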
    if self.stepwise_penalty:
        self.global_scorer.update_score(self, attn_out)
    # force the output to be longer than self.min_length
    cur_len = len(self.next_ys)
    if cur_len < self.min_length:
        for k in range(len(word_probs)):
            word_probs[k][self._eos] = -1e20
    # Sum the previous scores.
    if len(self.prev_ks) > 0:
        beam_scores = word_probs + self.scores.unsqueeze(1).expand_as(word_probs)
        # Don't let EOS have children.
        for i in range(self.next_ys[-1].size(0)):
            if self.next_ys[-1][i] == self._eos:
                beam_scores[i] = -1e20
        # Block ngram repeats
        if self.block_ngram_repeat > 0:
            le = len(self.next_ys)
            for j in range(self.next_ys[-1].size(0)):
                hyp, _ = self.get_hyp(le - 1, j)
                ngrams = set()
                fail = False
                gram = []
                for i in range(le - 1):
                    # Last n tokens, n = block_ngram_repeat
                    gram = (gram + [hyp[i].item()])[-self.block_ngram_repeat :]
                    # Skip the blocking if it is in the exclusion list
                    if set(gram) & self.exclusion_tokens:
                        continue
                    if tuple(gram) in ngrams:
                        fail = True
                    ngrams.add(tuple(gram))
                if fail:
                    beam_scores[j] = -10e20
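            # Illustrative: with block_ngram_repeat=2, a hypothesis
            # [5, 7, 5, 7] produces grams (5,), (5, 7), (7, 5), (5, 7);
            # the repeated (5, 7) sets `fail` and the hypothesis is
            # effectively removed from contention.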
    else:
        # First step: every beam is identical, so only expand the first row.
        beam_scores = word_probs[0]
    flat_beam_scores = beam_scores.view(-1)
    best_scores, best_scores_id = flat_beam_scores.topk(self.size, 0, True, True)
    self.all_scores.append(self.scores)
    self.scores = best_scores
    # best_scores_id is flattened beam x word array, so calculate which
    # word and beam each score came from
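    # (e.g. with num_words=10, flat index 17 maps to beam 17 // 10 = 1
    # and word 17 - 1 * 10 = 7)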
    # Floor division keeps prev_k an integer tensor, as index_select
    # requires (plain / yields floats on recent PyTorch versions).
    prev_k = best_scores_id // num_words
    self.prev_ks.append(prev_k)
    self.next_ys.append(best_scores_id - prev_k * num_words)
    self.attn.append(attn_out.index_select(0, prev_k))
    self.global_scorer.update_global_state(self)
    for i in range(self.next_ys[-1].size(0)):
        if self.next_ys[-1][i] == self._eos:
            global_scores = self.global_scorer.score(self, self.scores)
            s = global_scores[i]
            self.finished.append((s, len(self.next_ys) - 1, i))
    # End condition is when top-of-beam is EOS and no global score.
    if self.next_ys[-1][0] == self._eos:
        self.all_scores.append(self.scores)
        self.eos_top = True
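

# ---------------------------------------------------------------------------
# Illustrative driver (a sketch, not part of this module): how a decoding
# loop typically feeds `advance`. `decode_step` is a hypothetical stand-in
# for the real one-step decoder; the loop only touches attributes that
# `advance` itself uses (`next_ys`, `finished`, `eos_top`, `get_hyp`).
# ---------------------------------------------------------------------------
import torch


def decode_step(last_tokens, vocab_size=32, src_len=8):
    # Hypothetical one-step decoder: random log-probs (beam_size x vocab)
    # plus a dummy attention row per beam.
    beam_size = last_tokens.size(0)
    log_probs = torch.log_softmax(torch.randn(beam_size, vocab_size), dim=-1)
    attn = torch.rand(beam_size, src_len)
    return log_probs, attn


def run_beam(beam, max_length=50):
    # Advance the beam until its best hypothesis ends in EOS (or give up).
    for _ in range(max_length):
        log_probs, attn = decode_step(beam.next_ys[-1])
        beam.advance(log_probs, attn)
        if beam.eos_top and beam.finished:
            break
    if beam.finished:
        # Highest-scoring finished hypothesis.
        _, t, k = max(beam.finished, key=lambda entry: float(entry[0]))
        hyp, _ = beam.get_hyp(t, k)
        return hyp
    # Fall back to the current top of the beam.
    hyp, _ = beam.get_hyp(len(beam.next_ys) - 1, 0)
    return hyp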