in pytorch_translate/ensemble_export.py
def generate(self, models, src_tokens, src_lengths, prefix_tokens=None):
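    """Decode a batch by iterative refinement with a single model.

    The source is encoded once; the decoder output is then refined for up
    to ``max_iter + 1`` steps, and sentences are dropped from the batch as
    soon as they terminate. Returns per-sentence lists of finalized tokens,
    scores, attentions, and alignments.
    """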
    # TODO: support model ensembles
    assert len(models) == 1, "only a single model is supported"
model = models[0]
bsz, src_len = src_tokens.size()
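    # Map positions in the (shrinking) active batch back to original sentence indices.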
sent_idxs = torch.arange(bsz)
    # Encode the source once; the encoder output is reused (and pruned) across refinement steps.
encoder_out = model.encoder(src_tokens, src_lengths)
    # Initialize the decoder output buffers (model specific: with or without length prediction).
prev_decoder_out = model.initialize_output_tokens(encoder_out, src_tokens)
prev_output_tokens = prev_decoder_out.output_tokens.clone()
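    # Per-sentence result buffers; the scalar placeholders are overwritten
    # as sentences finalize.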
finalized_tokens_list = [torch.tensor(0) for _ in range(bsz)]
finalized_scores_list = [torch.tensor(0) for _ in range(bsz)]
finalized_attns_list = [torch.tensor(0) for _ in range(bsz)]
finalized_alignments_list = [torch.tensor(0) for _ in range(bsz)]
for step in range(self.max_iter + 1):
prev_decoder_out = prev_decoder_out._replace(
step=step, max_step=self.max_iter + 1
)
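        # One refinement pass over the current hypotheses, subject to the
        # EOS penalty and the maximum source/target length ratio.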
decoder_out = model.forward_decoder(
prev_decoder_out,
encoder_out,
eos_penalty=self.eos_penalty,
max_ratio=self.max_ratio,
decoding_format=self.decoding_format,
)
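        # A sentence is treated as terminated once its output stops changing
        # between steps; is_a_loop also pads the previous and current outputs
        # to a common length before comparing them.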
terminated, output_tokens, output_scores, output_attn = is_a_loop(
self.pad,
prev_output_tokens,
decoder_out.output_tokens,
decoder_out.output_scores,
decoder_out.attn,
)
decoder_out = decoder_out._replace(
output_tokens=output_tokens,
output_scores=output_scores,
attn=output_attn,
)
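        # Force termination for anything still running on the last iteration.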
terminated = last_step(step, self.max_iter, terminated)
# collect finalized sentences
finalized_idxs = sent_idxs[terminated]
finalized_tokens = decoder_out.output_tokens[terminated]
finalized_scores = decoder_out.output_scores[terminated]
finalized_attn = (
None if decoder_out.attn is None else decoder_out.attn[terminated]
)
finalized_tokens_list = finalize_hypos_loop_tokens(
finalized_tokens_list,
finalized_idxs,
self.pad,
finalized_tokens,
finalized_scores,
)
finalized_scores_list = finalize_hypos_loop_scores(
finalized_scores_list,
finalized_idxs,
self.pad,
finalized_tokens,
finalized_scores,
)
finalized_attns_list, finalized_alignments_list = finalize_hypos_loop_attns(
finalized_attns_list,
finalized_alignments_list,
finalized_idxs,
self.pad,
finalized_tokens,
finalized_scores,
finalized_attn,
)
        # Prepare pruned state for the next refinement step.
not_terminated = ~terminated
prev_decoder_out = decoder_out._replace(
output_tokens=script_skip_tensor(
decoder_out.output_tokens, not_terminated
),
output_scores=script_skip_tensor(
decoder_out.output_scores, not_terminated
),
attn=decoder_out.attn,
step=decoder_out.step,
max_step=decoder_out.max_step,
)
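        # Prune the encoder state down to the sentences that are still
        # decoding; the padding mask and per-layer states are dropped here.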
        encoder_out = EncoderOut(
            encoder_out=script_skip_tensor(encoder_out.encoder_out, not_terminated),
            encoder_padding_mask=None,
            encoder_embedding=script_skip_tensor(
                encoder_out.encoder_embedding, not_terminated
            ),
            encoder_states=None,
            src_tokens=None,
            src_lengths=None,
        )
sent_idxs = script_skip_tensor(sent_idxs, not_terminated)
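        # Snapshot the surviving outputs for the next iteration's convergence check.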
prev_output_tokens = prev_decoder_out.output_tokens.clone()
return (
finalized_tokens_list,
finalized_scores_list,
finalized_attns_list,
finalized_alignments_list,
)
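
A minimal driver sketch, assuming `generator` is an instance of the class that
owns this method and `model` is a compatible single model; the vocabulary IDs
and tensor shapes below are hypothetical, and only `generate`'s signature comes
from the code above.

    import torch

    # Hypothetical batch of two source sentences, right-padded to length 5
    # (2 = EOS, 1 = pad in this made-up vocabulary).
    src_tokens = torch.tensor([[4, 10, 11, 12, 2], [5, 13, 14, 2, 1]])
    src_lengths = torch.tensor([5, 4])

    tokens, scores, attns, alignments = generator.generate(
        [model], src_tokens, src_lengths
    )
    for i, (toks, sc) in enumerate(zip(tokens, scores)):
        print(f"sentence {i}: {toks.tolist()} (avg score {sc.mean().item():.3f})")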