in mm_action_prediction/models/decoder.py [0:0]
def forward(self, batch, encoder_output):
"""Forward pass through the decoder.
Args:
batch: Dict of batch variables.
encoder_output: Dict of outputs from the encoder.
Returns:
decoder_outputs: Dict of outputs from the forward pass.
"""
    # Flatten for history_agnostic encoder.
    batch_size, num_rounds, max_length = batch["assist_in"].shape
    decoder_in = support.flatten(batch["assist_in"], batch_size, num_rounds)
    decoder_out = support.flatten(batch["assist_out"], batch_size, num_rounds)
    decoder_len = support.flatten(batch["assist_in_len"], batch_size, num_rounds)
    word_embeds_dec = self.word_embed_net(decoder_in)
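    # After flattening, the decoder tensors are shaped [batch_size * num_rounds, ...],
    # i.e. every dialog round is treated as an independent sequence.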
if self.params["encoder"] in self.DIALOG_CONTEXT_ENCODERS:
dialog_context = support.flatten(
encoder_output["dialog_context"], batch_size, num_rounds
).unsqueeze(1)
dialog_context = dialog_context.expand(-1, max_length, -1)
decoder_steps_in = torch.cat([dialog_context, word_embeds_dec], -1)
else:
decoder_steps_in = word_embeds_dec
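    # decoder_steps_in holds the per-step decoder inputs: token embeddings, optionally
    # concatenated with the dialog context tiled over the max_length time steps.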
    # Encoder states conditioned on action outputs, if need be.
    if self.params["use_action_output"]:
        action_out = encoder_output["action_output_all"].unsqueeze(1)
        time_steps = encoder_output["hidden_states_all"].shape[1]
        fusion_out = torch.cat(
            [
                encoder_output["hidden_states_all"],
                action_out.expand(-1, time_steps, -1),
            ],
            dim=-1,
        )
        encoder_output["hidden_states_all"] = self.action_fusion_net(fusion_out)
if self.params["text_encoder"] == "transformer":
# Check the status of no_peek_mask.
if self.no_peek_mask is None or self.no_peek_mask.size(0) != max_length:
self.no_peek_mask = self._generate_no_peek_mask(max_length)
hidden_state = encoder_output["hidden_states_all"]
enc_pad_mask = batch["user_utt"] == batch["pad_token"]
enc_pad_mask = support.flatten(enc_pad_mask, batch_size, num_rounds)
dec_pad_mask = batch["assist_in"] == batch["pad_token"]
dec_pad_mask = support.flatten(dec_pad_mask, batch_size, num_rounds)
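        # In both pad masks, True marks <pad> positions. The first branch below passes
        # them as *_key_padding_mask (True = ignore), while the pretrained branch
        # expects the inverted convention (True = attend), hence the ~ operators there.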
if self.params["encoder"] != "pretrained_transformer":
dec_embeds = self.pos_encoder(decoder_steps_in).transpose(0, 1)
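            # This decoder appears to follow nn.TransformerDecoder's
            # (seq_len, batch, dim) layout, hence the transposes around the call.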
            outputs = self.decoder_unit(
                dec_embeds,
                hidden_state.transpose(0, 1),
                memory_key_padding_mask=enc_pad_mask,
                tgt_mask=self.no_peek_mask,
                tgt_key_padding_mask=dec_pad_mask,
            )
            outputs = outputs.transpose(0, 1)
        else:
            outputs = self.decoder_unit(
                inputs_embeds=decoder_steps_in,
                attention_mask=~dec_pad_mask,
                encoder_hidden_states=hidden_state,
                encoder_attention_mask=~enc_pad_mask,
            )
            outputs = outputs[0]
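            # The pretrained decoder presumably returns a tuple; its first element is
            # used directly as vocabulary logits in the pretrained_transformer branch
            # at the end of this method.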
    else:
        hidden_state = encoder_output["hidden_state"]
        if self.params["encoder"] == "tf_idf":
            hidden_state = None
        # If Bahdanau attention is to be used.
        if (
            self.params["use_bahdanau_attention"]
            and self.params["encoder"] != "tf_idf"
        ):
            encoder_states = encoder_output["hidden_states_all"]
            max_decoder_len = min(
                decoder_in.shape[1], self.params["max_decoder_len"]
            )
            encoder_states_proj = self.attention_net(encoder_states)
            enc_mask = (batch["user_utt"] == batch["pad_token"]).unsqueeze(-1)
            enc_mask = support.flatten(enc_mask, batch_size, num_rounds)
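            # Step-wise decoding with attention: at each step the previous decoder
            # hidden state scores the projected encoder states, a context vector is
            # formed from the softmax weights, and the context is concatenated with
            # the current token embedding before the recurrent update.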
            outputs = []
            for step in range(max_decoder_len):
                previous_state = hidden_state[0][-1].unsqueeze(1)
                att_logits = previous_state * encoder_states_proj
                att_logits = att_logits.sum(dim=-1, keepdim=True)
                # Use encoder mask to replace <pad> with -Inf.
                att_logits.masked_fill_(enc_mask, float("-Inf"))
                att_wts = nn.functional.softmax(att_logits, dim=1)
                context = (encoder_states * att_wts).sum(1, keepdim=True)
                # Run through LSTM.
                concat_in = [context, decoder_steps_in[:, step : step + 1, :]]
                step_in = torch.cat(concat_in, dim=-1)
                decoder_output, hidden_state = self.decoder_unit(
                    step_in, hidden_state
                )
                concat_out = torch.cat([decoder_output, context], dim=-1)
                outputs.append(concat_out)
            outputs = torch.cat(outputs, dim=1)
        else:
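            # No attention: run the recurrent decoder over the whole padded batch in
            # one call; rnn.dynamic_rnn is assumed to be a project helper that
            # respects decoder_len when unrolling.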
            outputs = rnn.dynamic_rnn(
                self.decoder_unit,
                decoder_steps_in,
                decoder_len,
                init_state=hidden_state,
            )
if self.params["encoder"] == "pretrained_transformer":
output_logits = outputs
else:
# Logits over vocabulary.
output_logits = self.inv_word_net(outputs)
# Mask out the criterion while summing.
pad_mask = support.flatten(batch["assist_mask"], batch_size, num_rounds)
loss_token = self.criterion(output_logits.transpose(1, 2), decoder_out)
loss_token.masked_fill_(pad_mask, 0.0)
return {"loss_token": loss_token, "pad_mask": pad_mask}