def generate_suggestions_for_one_sample()

in Models/exprsynth/seqdecoder.py [0:0]


    def generate_suggestions_for_one_sample(self,
                                            test_sample: Dict[str, Any],
                                            test_sample_encoded: tf.Tensor,
                                            beam_size: int=3,
                                            max_decoding_steps: int=100) -> ModelTestResult:
        """Beam-search decode a token sequence for one encoded test sample.

        Decoding starts from START_TOKEN with the decoder state built from
        `test_sample_encoded`. At each step, every beam that does not yet end in
        END_TOKEN is fed (last token + RNN state) through the one-step decoder ops
        and expanded with the candidate next tokens returned by
        `pick_indices_from_probs`. Beams that already end in END_TOKEN are carried
        forward unchanged. The `beam_size` beams with the highest accumulated
        log-probability are kept.

        Args:
            test_sample: Sample dict. Only its 'target_tokens' entry is read here;
                it is logged as the ground truth and passed into the result.
            test_sample_encoded: Encoder output used to build the initial decoder state.
            beam_size: Number of beams kept after each step; also passed to
                `pick_indices_from_probs` as the number of candidates per beam.
            max_decoding_steps: Upper bound on the number of decoding steps. Decoding
                stops earlier once every beam ends in END_TOKEN.

        Returns:
            ModelTestResult built from `test_sample['target_tokens']` and a list of
            (predicted_tokens, probability) pairs in beam order. START_TOKEN and a
            trailing END_TOKEN (if present) are stripped from each prediction.
        """
        token_vocab = self.metadata['decoder_token_vocab']

        def expand_sequence(decoder_info: SeqDecodingInformation) -> List[SeqDecodingInformation]:
            """Return the one-token extensions of `decoder_info`, or itself if already finished."""
            last_tok = decoder_info.sequence[-1]
            if last_tok == END_TOKEN:
                # Finished beams keep competing for a top-K slot with their final score.
                return [decoder_info]

            last_tok_id = token_vocab.get_id_or_unk(last_tok)
            rnn_one_step_data_dict = {
                self.placeholders['rnn_hidden_state']: decoder_info.rnn_state,
                self.placeholders['rnn_input_tok_id']: [last_tok_id],
                self.placeholders['dropout_keep_rate']: 1.0,
            }

            (output_probs, next_state) = \
                self.sess.run([self.ops['one_rnn_decoder_step_output'], self.ops['one_rnn_decoder_step_state']],
                              feed_dict=rnn_one_step_data_dict)
            next_tok_indices = pick_indices_from_probs(output_probs[0, :], beam_size)

            result = []
            for next_tok_idx in next_tok_indices:
                next_tok = token_vocab.id_to_token[next_tok_idx]
                next_tok_prob = output_probs[0, next_tok_idx]
                # Copy the sequence so sibling beams never share (and mutate) one list.
                new_decoder_info = SeqDecodingInformation(rnn_state=next_state,
                                                          sequence=list(decoder_info.sequence) + [next_tok],
                                                          seq_logprob=decoder_info.seq_logprob + np.log(next_tok_prob))
                result.append(new_decoder_info)

            return result

        rnn_cell = make_rnn_cell(self.hyperparameters['decoder_rnn_layer_num'],
                                 self.hyperparameters['decoder_rnn_cell_type'],
                                 hidden_size=self.hyperparameters['decoder_rnn_hidden_size'],
                                 dropout_keep_rate=1,
                                 )
        initial_cell_state = self.__make_decoder_rnn_initial_state(test_sample_encoded, rnn_cell)
        initial_decoder_info = SeqDecodingInformation(rnn_state=initial_cell_state,
                                                      sequence=[START_TOKEN],
                                                      seq_logprob=0.0)
        beams = [initial_decoder_info]  # type: List[SeqDecodingInformation]
        # Bounded loop: the previous while-loop never advanced its step counter, so
        # it could run forever when some beam never produced END_TOKEN.
        for _ in range(max_decoding_steps):
            if all(b.sequence[-1] == END_TOKEN for b in beams):
                break
            new_beams = [new_beam
                         for beam in beams
                         for new_beam in expand_sequence(beam)]
            beams = sorted(new_beams, key=lambda b: -b.seq_logprob)[:beam_size]  # Pick top K beams

        self.test_log("Groundtruth: %s" % (" ".join(test_sample['target_tokens']),))

        all_predictions = []  # type: List[Tuple[List[str], float]]
        any_beam_finished = False
        for (k, beam_info) in enumerate(beams):
            # Slice instead of pop(): the beam's own list is left untouched, and
            # a beam cut off by max_decoding_steps keeps its last real token.
            kth_result = beam_info.sequence[1:]  # Drop START_TOKEN
            if kth_result and kth_result[-1] == END_TOKEN:
                kth_result = kth_result[:-1]
                any_beam_finished = True
            kth_prob = np.exp(beam_info.seq_logprob)
            all_predictions.append((kth_result, kth_prob))
            self.test_log("  @%i Prob. %.3f: %s" % (k+1, kth_prob, " ".join(kth_result)))

        if not any_beam_finished:
            print("No beams finished!")

        return ModelTestResult(test_sample['target_tokens'], all_predictions)