def forward_beamsearch_single()

in mm_action_prediction/models/decoder.py


    def forward_beamsearch_single(self, batch, encoder_output, mode_params):
        """Evaluates the model using beam search with batch size 1.

        NOTE: Current implementation only supports beam search for batch
              size 1; both the LSTM and transformer text encoders are
              handled below.

        Args:
            batch: Dictionary of inputs, with a batch size of 1
            encoder_output: Dictionary of encoder outputs (hidden states and,
                optionally, dialog context)
            mode_params: Dictionary of decoding parameters; must contain
                beam_size (number of beams)

        Returns:
            Dictionary with the key "top_beams": a list of decoded token
            tensors, one per beam, highest-scoring first
        """
        # Initializations and aliases.
        # self.host aliases torch.cuda or torch so new tensors are created
        # on the same device (GPU or CPU) as the rest of the model.
        LENGTH_NORM = True
        self.host = torch.cuda if self.params["use_gpu"] else torch
        end_token = self.params["end_token"]
        start_token = self.params["start_token"]
        beam_size = mode_params["beam_size"]
        max_decoder_len = self.params["max_decoder_len"]

        if self.params["text_encoder"] == "transformer":
            hidden_state = encoder_output["hidden_states_all"].transpose(0, 1)
            max_enc_len, batch_size, enc_embed_size = hidden_state.shape
            hidden_state_expand = hidden_state.expand(
                max_enc_len, beam_size, enc_embed_size
            )
            enc_pad_mask = batch["user_utt"] == batch["pad_token"]
            enc_pad_mask = support.flatten(enc_pad_mask, 1, 1)
            enc_pad_mask_expand = enc_pad_mask.expand(beam_size, max_enc_len)
            if (
                self.no_peek_mask is None
                or self.no_peek_mask.size(0) != max_decoder_len
            ):
                self.no_peek_mask = self._generate_no_peek_mask(max_decoder_len)
        elif self.params["text_encoder"] == "lstm":
            hidden_state = encoder_output["hidden_state"]

        if (
            self.params["use_bahdanau_attention"]
            and self.params["encoder"] != "tf_idf"
            and self.params["text_encoder"] == "lstm"
        ):
            encoder_states = encoder_output["hidden_states_all"]
            encoder_states_proj = self.attention_net(encoder_states)
            enc_mask = (batch["user_utt"] == batch["pad_token"]).unsqueeze(-1)
            enc_mask = support.flatten(enc_mask, 1, 1)

        # Per instance initializations.
        # Copy the hidden state beam_size times (one copy per beam).
        if hidden_state is not None:
            hidden_state = [ii.repeat(1, beam_size, 1) for ii in hidden_state]
        beams = {-1: self.host.LongTensor(1, beam_size).fill_(start_token)}
        beam_scores = self.host.FloatTensor(beam_size, 1).fill_(0.)
        finished_beams = self.host.ByteTensor(beam_size, 1).fill_(False)
        zero_tensor = self.host.LongTensor(beam_size, 1).fill_(end_token)
        reverse_inds = {}
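        # beams[step] stores the token chosen by each beam at that step;
        # reverse_inds[step] stores the parent beam each choice came from,
        # so _backtrack_beams() can reconstruct full hypotheses later.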
        # Generate beams until max_len time steps.
        for step in range(max_decoder_len - 1):
            if self.params["text_encoder"] == "transformer":
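                # A transformer decoder keeps no recurrent state, so the full
                # token prefix of every beam is rebuilt and re-decoded at each
                # step; only the logits at the final position are used.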
                beams, tokens_list = self._backtrack_beams(beams, reverse_inds)
                beam_tokens = torch.cat(tokens_list, dim=0).transpose(0, 1)
                beam_tokens_embed = self.word_embed_net(beam_tokens)

                if self.params["encoder"] != "pretrained_transformer":
                    dec_embeds = self.pos_encoder(beam_tokens_embed).transpose(0, 1)
                    output = self.decoder_unit(
                        dec_embeds,
                        hidden_state_expand,
                        tgt_mask=self.no_peek_mask[: step + 1, : step + 1],
                        memory_key_padding_mask=enc_pad_mask_expand,
                    )
                    logits = self.inv_word_net(output[-1])
                else:
                    outputs = self.decoder_unit(
                        inputs_embeds=beam_tokens_embed,
                        encoder_hidden_states=hidden_state_expand.transpose(0, 1),
                        encoder_attention_mask=~enc_pad_mask_expand,
                    )
                    logits = outputs[0][:, -1, :]

            elif self.params["text_encoder"] == "lstm":
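                # The LSTM only needs the previous token plus its carried
                # hidden state, so a single decoding step suffices here.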
                beam_tokens = beams[step - 1].t()
                beam_tokens_embed = self.word_embed_net(beam_tokens)
                # Append dialog context if exists.
                if self.params["encoder"] in self.DIALOG_CONTEXT_ENCODERS:
                    dialog_context = encoder_output["dialog_context"]
                    beam_tokens_embed = torch.cat(
                        [dialog_context.repeat(beam_size, 1, 1), beam_tokens_embed],
                        dim=-1,
                    )
                # Use Bahdanau attention over the encoder hidden states.
                if (
                    self.params["use_bahdanau_attention"]
                    and self.params["encoder"] != "tf_idf"
                ):
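                    # Score each encoder position by a dot product between the
                    # previous decoder state and the projected encoder states,
                    # softmax over time, and mix the resulting context vector
                    # into the LSTM input.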
                    previous_state = hidden_state[0][-1].unsqueeze(1)
                    att_logits = previous_state * encoder_states_proj
                    att_logits = att_logits.sum(dim=-1, keepdim=True)
                    # Use encoder mask to replace <pad> with -Inf.
                    att_logits.masked_fill_(enc_mask, float("-Inf"))
                    att_wts = nn.functional.softmax(att_logits, dim=1)
                    context = (encoder_states * att_wts).sum(1, keepdim=True)
                    # Run through LSTM.
                    step_in = torch.cat([context, beam_tokens_embed], dim=-1)
                    decoder_output, new_state = self.decoder_unit(
                        step_in, hidden_state
                    )
                    output = torch.cat([decoder_output, context], dim=-1)
                else:
                    output, new_state = self.decoder_unit(
                        beam_tokens_embed, hidden_state
                    )
                logits = self.inv_word_net(output).squeeze(1)
            log_probs = nn.functional.log_softmax(logits, dim=-1)
            # Compute the new beam scores.
            alive = finished_beams.eq(0).float()
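            # With length normalization the beam score is a running mean of the
            # per-step log probs:
            #   score_t = (step * score_{t-1} + log_prob_t) / (step + 1)
            # Finished beams get zero weight, so their scores stop changing.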
            if LENGTH_NORM:
                # Add (current log probs / (step + 1))
                cur_weight = alive / (step + 1)
                # Add (previous log probs * step / (step + 1)) <- mean update
                prev_weight = alive * step / (step + 1)
            else:
                # No length normalization.
                cur_weight = alive
                prev_weight = alive

            # Compute the new beam extensions.
            if step == 0:
                # For the first step, make all but first beam
                # probabilities -inf.
                log_probs[1:, :] = float("-inf")
            new_scores = log_probs * cur_weight + beam_scores * prev_weight
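            # Freeze finished beams: copy each finished beam's previous score
            # onto its <END> column, then turn the remaining zeros into -inf so
            # topk can only re-select the frozen hypothesis.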
            finished_beam_scores = beam_scores * finished_beams.float()
            new_scores.scatter_add_(1, zero_tensor, finished_beam_scores)
            # Scores of finished beams are set to -inf for every word except <END>.
            new_scores.masked_fill_(new_scores.eq(0), float("-inf"))

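            # Flatten the (beam_size, vocab) scores and take the global top-k:
            # each flat index decodes to a parent beam (floor division) and a
            # token id (modulo).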
            num_candidates = new_scores.shape[-1]
            new_scores_flat = new_scores.view(1, -1)
            beam_scores, top_inds_flat = torch.topk(new_scores_flat, beam_size)
            beam_scores = beam_scores.t()
            top_beam_inds = (top_inds_flat // num_candidates).squeeze(0)
            top_tokens = top_inds_flat % num_candidates

            # Prepare for next step.
            beams[step] = top_tokens
            reverse_inds[step] = top_beam_inds
            finished_beams = finished_beams[top_beam_inds]
            if self.params["text_encoder"] == "lstm":
                hidden_state = tuple(
                    ii.index_select(1, top_beam_inds) for ii in new_state
                )

            # Update if any of the latest beams are finished, i.e., emit <END>.
            new_finished_beams = beams[step].eq(end_token).type(self.host.ByteTensor)
            finished_beams = finished_beams | new_finished_beams.t()
            if torch.sum(finished_beams).item() == beam_size:
                break

        # Backtrack the beam through indices.
        beams, tokens_list = self._backtrack_beams(beams, reverse_inds)
        # Add an <END> token at the end.
        tokens_list.append(self.host.LongTensor(1, beam_size).fill_(end_token))

        sorted_beam_tokens = torch.cat(tokens_list, 0).t()
        sorted_beam_lengths = sorted_beam_tokens.ne(end_token).long().sum(dim=1)
        # Trim all the top beams.
        top_beams = []
        for index in range(beam_size):
            beam_length = sorted_beam_lengths[index].view(-1)
            beam = sorted_beam_tokens[index].view(-1, 1)[1:beam_length]
            top_beams.append(beam)
        return {"top_beams": top_beams}