in pytorch_translate/research/multisource/multisource_decode.py [0:0]
def finalize_hypos(step, bbsz_idx, eos_scores, unfinalized_scores=None):
    """
    Finalize the given hypotheses at this step, while keeping the total
    number of finalized hypotheses per sentence <= beam_size.

    Note: the input must be in the desired finalization order, so that
    hypotheses that appear earlier in the input are preferred to those
    that appear later.

    Args:
        step: current time step
        bbsz_idx: A vector of indices in the range [0, bsz*beam_size),
            indicating which hypotheses to finalize
        eos_scores: A vector of the same size as bbsz_idx containing
            scores for each hypothesis
        unfinalized_scores: A vector containing scores for all
            unfinalized hypotheses
    """
    assert bbsz_idx.numel() == eos_scores.numel()

    # Clone the relevant token and attention tensors.
    tokens_clone = tokens.index_select(0, bbsz_idx)
    # Skip the first index, which is EOS.
    tokens_clone = tokens_clone[:, 1 : step + 2]
    # The EOS emitted at this step has not been written into tokens yet,
    # so set it explicitly in the clone.
    tokens_clone[:, step] = self.eos
    attn_clone = attn.index_select(0, bbsz_idx)[:, :, 1 : step + 2]
    # Compute scores per token position.
    pos_scores = scores.index_select(0, bbsz_idx)[:, : step + 1]
    pos_scores[:, step] = eos_scores
    # Convert from cumulative to per-position scores.
    pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
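    # e.g. a cumulative row [a, a+b, a+b+c] becomes per-token [a, b, c].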
    # Normalize sentence-level scores.
    if self.normalize_scores:
        eos_scores /= (step + 1) ** self.len_penalty
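    # Note that only the sentence-level score is length-normalized here;
    # the per-position scores above are left unnormalized.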

    sents_seen = set()
    for i, (idx, score) in enumerate(
        zip(bbsz_idx.tolist(), eos_scores.tolist())
    ):
        sent = idx // beam_size
        sents_seen.add(sent)

        def get_hypo():
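            # Hard alignment: for each target position, pick the source
            # position with the highest attention weight.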
            _, alignment = attn_clone[i].max(dim=0)
            return {
                "tokens": tokens_clone[i],
                "score": score,
                "attention": attn_clone[i],  # src_len x tgt_len
                "alignment": alignment,
                "positional_scores": pos_scores[i],
            }

        if len(finalized[sent]) < beam_size:
            finalized[sent].append(get_hypo())
        elif not self.stop_early and score > worst_finalized[sent]["score"]:
            # Replace the worst hypo for this sentence with the new/better one.
            worst_idx = worst_finalized[sent]["idx"]
            if worst_idx is not None:
                finalized[sent][worst_idx] = get_hypo()
            # Find the new worst finalized hypo for this sentence.
            idx, s = min(
                enumerate(finalized[sent]), key=lambda r: r[1]["score"]
            )
            worst_finalized[sent] = {"score": s["score"], "idx": idx}

    # Return the number of hypotheses finished at this step.
    num_finished = 0
    for sent in sents_seen:
        # Check termination conditions for this sentence.
        if not finished[sent] and is_finished(sent, step, unfinalized_scores):
            finished[sent] = True
            num_finished += 1
    return num_finished
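
Below is a minimal, self-contained sketch (not part of the original file) of
the score bookkeeping performed above, on toy tensors. The names step,
cumulative, eos_score, and len_penalty are illustrative stand-ins for the
closure state and self.len_penalty.

import torch

step = 2                  # hypothesis ends at time step 2 (three tokens)
len_penalty = 1.0         # stand-in for self.len_penalty
cumulative = torch.tensor([[-0.5, -1.2, -2.0]])  # running sums of log-probs
eos_score = cumulative[:, step].clone()

# Convert from cumulative to per-position scores, as in finalize_hypos.
pos_scores = cumulative.clone()
pos_scores[:, 1:] = pos_scores[:, 1:] - pos_scores[:, :-1]
print(pos_scores)         # tensor([[-0.5000, -0.7000, -0.8000]])

# Length-normalize only the sentence-level score.
eos_score /= (step + 1) ** len_penalty
print(eos_score)          # tensor([-0.6667])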