in bertabs/modeling_bertabs.py [0:0]
def _fast_translate_batch(self, batch, max_length, min_length=0):
    """Run beam search over the encoder inputs contained in `batch`.

    Args:
        batch: object exposing `batch_size`, `src`, `segs` and `mask_src`;
            the last three are passed as-is to `self.model.bert`.
        max_length: maximum number of decoding steps. On the last step every
            remaining beam is treated as finished.
        min_length: number of initial steps during which `self.end_token`
            is effectively forbidden (its log-probability is set to -1e20).

    Returns:
        A dict with the keys:
            "predictions": one list per batch entry, holding the token ids of
                the best hypothesis (start token stripped) as a tensor;
            "scores": one list per batch entry, holding the matching
                length-normalised score;
            "gold_score": a list of `batch_size` zeros;
            "batch": the `batch` object that was passed in.
    """
    # The batch object carries its own size instead of us deriving it from
    # the tensors it holds.
    beam_size = self.beam_size
    batch_size = batch.batch_size
    src = batch.src
    segs = batch.segs
    mask_src = batch.mask_src

    src_features = self.model.bert(src, segs, mask_src)
    dec_states = self.model.decoder.init_decoder_state(src, src_features, with_cache=True)
    device = src_features.device

    # Tile states and memory beam_size times.
    dec_states.map_batch_fn(lambda state, dim: tile(state, beam_size, dim=dim))
    src_features = tile(src_features, beam_size, dim=0)

    # batch_offset[i] is the original batch position of the i-th entry that
    # is still being decoded; beam_offset[i] is the first row of entry i in
    # the flattened (entries * beam_size) layout.
    batch_offset = torch.arange(batch_size, dtype=torch.long, device=device)
    beam_offset = torch.arange(0, batch_size * beam_size, step=beam_size, dtype=torch.long, device=device)
    alive_seq = torch.full([batch_size * beam_size, 1], self.start_token, dtype=torch.long, device=device)

    # Give full probability to the first beam on the first step.
    topk_log_probs = torch.tensor([0.0] + [float("-inf")] * (beam_size - 1), device=device).repeat(batch_size)

    # Structure that holds finished hypotheses, indexed by original batch position.
    hypotheses = [[] for _ in range(batch_size)]

    results = {
        "predictions": [[] for _ in range(batch_size)],
        "scores": [[] for _ in range(batch_size)],
        "gold_score": [0] * batch_size,
        "batch": batch,
    }

    for step in range(max_length):
        # Feed only the most recent token of every live beam, one per row.
        decoder_input = alive_seq[:, -1].view(1, -1).transpose(0, 1)

        # Decoder forward.
        dec_out, dec_states = self.model.decoder(decoder_input, src_features, dec_states, step=step)

        # Generator forward.
        log_probs = self.generator(dec_out.transpose(0, 1).squeeze(0))
        vocab_size = log_probs.size(-1)

        if step < min_length:
            log_probs[:, self.end_token] = -1e20

        # Add the running log-probability of each beam (a product of
        # probabilities in log space).
        log_probs += topk_log_probs.view(-1).unsqueeze(1)

        alpha = self.global_scorer.alpha
        length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha

        # Length-normalised scores used for ranking.
        curr_scores = log_probs / length_penalty

        if self.args.block_trigram and alive_seq.size(1) > 3:
            for row in range(alive_seq.size(0)):
                tokens = [self.vocab.ids_to_tokens[int(w)] for w in alive_seq[row]]
                # Glue "##" continuation pieces back onto the preceding token.
                words = " ".join(tokens).replace(" ##", "").split()
                if len(words) <= 3:
                    continue
                trigrams = list(zip(words, words[1:], words[2:]))
                # Kill the beam if its latest trigram already occurred earlier.
                if trigrams[-1] in trigrams[:-1]:
                    curr_scores[row] = -10e20

        # Flatten each entry's (beam, vocab) scores into one list of candidates.
        curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
        topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)

        # Recover log probs.
        topk_log_probs = topk_scores * length_penalty

        # Resolve beam origin and true word ids. Floor division keeps the
        # result an integer tensor; plain `div` would produce floats, which
        # cannot be used as indices for `index_select` below.
        topk_beam_index = torch.div(topk_ids, vocab_size, rounding_mode="floor")
        topk_ids = topk_ids.fmod(vocab_size)

        # Map beam_index to batch_index in the flat representation.
        batch_index = topk_beam_index + beam_offset[: topk_beam_index.size(0)].unsqueeze(1)
        select_indices = batch_index.view(-1)

        # Append last prediction.
        alive_seq = torch.cat([alive_seq.index_select(0, select_indices), topk_ids.view(-1, 1)], -1)

        is_finished = topk_ids.eq(self.end_token)
        if step + 1 == max_length:
            # Out of budget: force every remaining beam to finish.
            is_finished.fill_(1)
        # End condition is top beam is finished.
        end_condition = is_finished[:, 0].eq(1)

        # Save finished hypotheses.
        if is_finished.any():
            predictions = alive_seq.view(-1, beam_size, alive_seq.size(-1))
            for i in range(is_finished.size(0)):
                b = int(batch_offset[i])
                if end_condition[i]:
                    # The entry is done, so record all of its beams.
                    is_finished[i].fill_(1)
                # Store finished hypotheses for this batch entry, dropping
                # the leading start token.
                for j in is_finished[i].nonzero().view(-1):
                    hypotheses[b].append((topk_scores[i, j], predictions[i, j, 1:]))
                # If the entry reached the end, keep its best-scoring hypothesis.
                if end_condition[i]:
                    score, pred = max(hypotheses[b], key=lambda hyp: hyp[0])
                    results["scores"][b].append(score)
                    results["predictions"][b].append(pred)

            non_finished = end_condition.eq(0).nonzero().view(-1)
            # If all sentences are translated, no need to go further.
            if non_finished.numel() == 0:
                break
            # Remove finished batches for the next step.
            topk_log_probs = topk_log_probs.index_select(0, non_finished)
            batch_index = batch_index.index_select(0, non_finished)
            batch_offset = batch_offset.index_select(0, non_finished)
            alive_seq = predictions.index_select(0, non_finished).view(-1, alive_seq.size(-1))

        # Reorder states so that they line up with the surviving beams.
        select_indices = batch_index.view(-1)
        src_features = src_features.index_select(0, select_indices)
        dec_states.map_batch_fn(lambda state, dim: state.index_select(dim, select_indices))

    return results