in mm_action_prediction/models/decoder.py [0:0]
def forward_beamsearch_single(self, batch, encoder_output, mode_params):
"""Evaluates the model using beam search with batch size 1.
NOTE: Current implementation only supports beam search for batch size 1
and for RNN text_encoder (will be extended for transformers)
Args:
batch: Dictionary of inputs, with batch size of 1
beam_size: Number of beams
Returns:
top_beams: Dictionary of top beams
"""
# Initializations and aliases.
# self.host aliases tensor constructors to CPU or CUDA variants.
LENGTH_NORM = True
self.host = torch.cuda if self.params["use_gpu"] else torch
end_token = self.params["end_token"]
start_token = self.params["start_token"]
beam_size = mode_params["beam_size"]
max_decoder_len = self.params["max_decoder_len"]
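# Encoder-side setup: for the transformer branch, hidden states and the
# padding mask are expanded along the beam dimension so every beam
# attends to the same (single) source sequence.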
if self.params["text_encoder"] == "transformer":
hidden_state = encoder_output["hidden_states_all"].transpose(0, 1)
max_enc_len, batch_size, enc_embed_size = hidden_state.shape
hidden_state_expand = hidden_state.expand(
max_enc_len, beam_size, enc_embed_size
)
enc_pad_mask = batch["user_utt"] == batch["pad_token"]
enc_pad_mask = support.flatten(enc_pad_mask, 1, 1)
enc_pad_mask_expand = enc_pad_mask.expand(beam_size, max_enc_len)
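# (Re)build the causal no-peek mask whenever its size does not match the
# maximum decoder length, so decoder positions cannot attend to future
# tokens.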
if (
self.no_peek_mask is None
or self.no_peek_mask.size(0) != max_decoder_len
):
self.no_peek_mask = self._generate_no_peek_mask(max_decoder_len)
elif self.params["text_encoder"] == "lstm":
hidden_state = encoder_output["hidden_state"]
if (
self.params["use_bahdanau_attention"]
and self.params["encoder"] != "tf_idf"
and self.params["text_encoder"] == "lstm"
):
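# Project the encoder states once; the projections are reused at every
# decoding step to score Bahdanau attention.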
encoder_states = encoder_output["hidden_states_all"]
encoder_states_proj = self.attention_net(encoder_states)
enc_mask = (batch["user_utt"] == batch["pad_token"]).unsqueeze(-1)
enc_mask = support.flatten(enc_mask, 1, 1)
# Per instance initializations.
# Copy the hidden state beam_size number of times.
if hidden_state is not None:
hidden_state = [ii.repeat(1, beam_size, 1) for ii in hidden_state]
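# Per-beam bookkeeping: beams[t] holds the tokens chosen at step t,
# beam_scores the running scores, finished_beams flags beams that already
# emitted <END>, zero_tensor indexes the <END> column for score updates,
# and reverse_inds[t] stores backpointers for later backtracking.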
beams = {-1: self.host.LongTensor(1, beam_size).fill_(start_token)}
beam_scores = self.host.FloatTensor(beam_size, 1).fill_(0.)
finished_beams = self.host.ByteTensor(beam_size, 1).fill_(False)
zero_tensor = self.host.LongTensor(beam_size, 1).fill_(end_token)
reverse_inds = {}
# Generate beams until max_len time steps.
for step in range(max_decoder_len - 1):
if self.params["text_encoder"] == "transformer":
beams, tokens_list = self._backtrack_beams(beams, reverse_inds)
beam_tokens = torch.cat(tokens_list, dim=0).transpose(0, 1)
beam_tokens_embed = self.word_embed_net(beam_tokens)
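# The pretrained decoder variant is assumed to add its own positional
# embeddings; the non-pretrained one needs the explicit positional encoder.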
if self.params["encoder"] != "pretrained_transformer":
dec_embeds = self.pos_encoder(beam_tokens_embed).transpose(0, 1)
output = self.decoder_unit(
dec_embeds,
hidden_state_expand,
tgt_mask=self.no_peek_mask[: step + 1, : step + 1],
memory_key_padding_mask=enc_pad_mask_expand,
)
logits = self.inv_word_net(output[-1])
else:
outputs = self.decoder_unit(
inputs_embeds=beam_tokens_embed,
encoder_hidden_states=hidden_state_expand.transpose(0, 1),
encoder_attention_mask=~enc_pad_mask_expand,
)
logits = outputs[0][:, -1, :]
elif self.params["text_encoder"] == "lstm":
beam_tokens = beams[step - 1].t()
beam_tokens_embed = self.word_embed_net(beam_tokens)
# Append dialog context if exists.
if self.params["encoder"] in self.DIALOG_CONTEXT_ENCODERS:
dialog_context = encoder_output["dialog_context"]
beam_tokens_embed = torch.cat(
[dialog_context.repeat(beam_size, 1, 1), beam_tokens_embed],
dim=-1,
)
# Use bahdanau attention over encoder hidden states.
if (
self.params["use_bahdanau_attention"]
and self.params["encoder"] != "tf_idf"
):
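# Score each encoder position against the previous top-layer decoder
# state, mask out <pad> positions, and form a weighted context vector.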
previous_state = hidden_state[0][-1].unsqueeze(1)
att_logits = previous_state * encoder_states_proj
att_logits = att_logits.sum(dim=-1, keepdim=True)
# Use encoder mask to replace <pad> with -Inf.
att_logits.masked_fill_(enc_mask, float("-Inf"))
att_wts = nn.functional.softmax(att_logits, dim=1)
context = (encoder_states * att_wts).sum(1, keepdim=True)
# Run through LSTM.
step_in = torch.cat([context, beam_tokens_embed], dim=-1)
decoder_output, new_state = self.decoder_unit(
step_in, hidden_state
)
output = torch.cat([decoder_output, context], dim=-1)
else:
output, new_state = self.decoder_unit(
beam_tokens_embed, hidden_state
)
logits = self.inv_word_net(output).squeeze(1)
log_probs = nn.functional.log_softmax(logits, dim=-1)
# Compute the new beam scores.
alive = finished_beams.eq(0).float()
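# With length normalization the score is a running mean of token
# log-probs: score_t = (t * score_{t-1} + logp_t) / (t + 1); "alive"
# zeroes out finished beams so they are handled separately below.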
if LENGTH_NORM:
# Add (current log probs / (step + 1))
cur_weight = alive / (step + 1)
# Add (previous log probs * (step / (step + 1))) <- running-mean update
prev_weight = alive * step / (step + 1)
else:
# No length normalization.
cur_weight = alive
prev_weight = alive
# Compute the new beam extensions.
if step == 0:
# For the first step, make all but first beam
# probabilities -inf.
log_probs[1:, :] = float("-inf")
new_scores = log_probs * cur_weight + beam_scores * prev_weight
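# A finished beam contributes exactly one candidate: its frozen score is
# written onto the <END> column via scatter_add_, and all its remaining
# (zero) entries are pushed to -inf just below.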
finished_beam_scores = beam_scores * finished_beams.float()
new_scores.scatter_add_(1, zero_tensor, finished_beam_scores)
# Finished beams scores are set to -inf for all words but one.
new_scores.masked_fill_(new_scores.eq(0), float("-inf"))
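# Flatten the (beam, vocab) score matrix and take a single top-k; each
# flat index decodes into a parent beam (index // vocab_size) and a
# token id (index % vocab_size).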
num_candidates = new_scores.shape[-1]
new_scores_flat = new_scores.view(1, -1)
beam_scores, top_inds_flat = torch.topk(new_scores_flat, beam_size)
beam_scores = beam_scores.t()
top_beam_inds = (top_inds_flat // num_candidates).squeeze(0)
top_tokens = top_inds_flat % num_candidates
# Prepare for next step.
beams[step] = top_tokens
reverse_inds[step] = top_beam_inds
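# Reorder per-beam state (finished flags and, for the LSTM, the hidden
# state) to follow the parent beams selected by top-k.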
finished_beams = finished_beams[top_beam_inds]
if self.params["text_encoder"] == "lstm":
hidden_state = tuple(
ii.index_select(1, top_beam_inds) for ii in new_state
)
# Update if any of the latest beams are finished, i.e., emitted <END>.
new_finished_beams = beams[step].eq(end_token).type(self.host.ByteTensor)
finished_beams = finished_beams | new_finished_beams.t()
if torch.sum(finished_beams).item() == beam_size:
break
# Backtrack the beam through indices.
beams, tokens_list = self._backtrack_beams(beams, reverse_inds)
# Add an <END> token at the end.
tokens_list.append(self.host.LongTensor(1, beam_size).fill_(end_token))
sorted_beam_tokens = torch.cat(tokens_list, 0).t()
sorted_beam_lengths = sorted_beam_tokens.ne(end_token).long().sum(dim=1)
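# Finished beams are padded with <END>, so the non-<END> count equals
# <START> plus the generated words; the slice below drops <START>.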
# Trim all the top beams.
top_beams = []
for index in range(beam_size):
beam_length = sorted_beam_lengths[index].view(-1)
beam = sorted_beam_tokens[index].view(-1, 1)[1:beam_length]
top_beams.append(beam)
return {"top_beams": top_beams}