in Models/exprsynth/seqdecoder.py [0:0]
def generate_suggestions_for_one_sample(self,
                                        test_sample: Dict[str, Any],
                                        test_sample_encoded: tf.Tensor,
                                        beam_size: int = 3,
                                        max_decoding_steps: int = 100) -> ModelTestResult:
    """Beam-search decode a token sequence for a single encoded test sample.

    Runs the decoder RNN one step at a time through the TF session, keeping the
    `beam_size` highest-log-probability hypotheses at each step, until every
    beam has emitted END_TOKEN or `max_decoding_steps` is reached.

    Args:
        test_sample: Raw sample dict; must contain 'target_tokens' (the
            ground-truth token list, used only for logging and the result).
        test_sample_encoded: Encoder output for this sample; fed to
            `self.__make_decoder_rnn_initial_state` to seed the decoder state.
        beam_size: Number of hypotheses kept after each decoding step.
        max_decoding_steps: Hard cap on decoding iterations; beams that have
            not emitted END_TOKEN by then are returned as-is (truncated).

    Returns:
        ModelTestResult pairing the ground-truth tokens with up to
        `beam_size` (token_list, probability) predictions, best first.
    """
    def expand_sequence(decoder_info: SeqDecodingInformation) -> List[SeqDecodingInformation]:
        """Advance one hypothesis by a single RNN step.

        A finished hypothesis (last token == END_TOKEN) is passed through
        unchanged so it keeps competing with still-growing beams.
        """
        last_tok = decoder_info.sequence[-1]
        if last_tok == END_TOKEN:
            return [decoder_info]
        last_tok_id = self.metadata['decoder_token_vocab'].get_id_or_unk(last_tok)
        rnn_one_step_data_dict = {
            self.placeholders['rnn_hidden_state']: decoder_info.rnn_state,
            self.placeholders['rnn_input_tok_id']: [last_tok_id],
            self.placeholders['dropout_keep_rate']: 1.0,  # no dropout at test time
        }
        (output_probs, next_state) = \
            self.sess.run([self.ops['one_rnn_decoder_step_output'], self.ops['one_rnn_decoder_step_state']],
                          feed_dict=rnn_one_step_data_dict)
        # Take the top-`beam_size` next-token candidates for this hypothesis.
        next_tok_indices = pick_indices_from_probs(output_probs[0, :], beam_size)
        result = []
        for next_tok_idx in next_tok_indices:
            next_tok = self.metadata['decoder_token_vocab'].id_to_token[next_tok_idx]
            next_tok_prob = output_probs[0, next_tok_idx]
            # Copy the sequence so sibling expansions don't share the list.
            new_decoder_info = SeqDecodingInformation(
                rnn_state=next_state,
                sequence=list(decoder_info.sequence) + [next_tok],
                seq_logprob=decoder_info.seq_logprob + np.log(next_tok_prob))
            result.append(new_decoder_info)
        return result

    rnn_cell = make_rnn_cell(self.hyperparameters['decoder_rnn_layer_num'],
                             self.hyperparameters['decoder_rnn_cell_type'],
                             hidden_size=self.hyperparameters['decoder_rnn_hidden_size'],
                             dropout_keep_rate=1,  # deterministic decoding: no dropout
                             )
    initial_cell_state = self.__make_decoder_rnn_initial_state(test_sample_encoded, rnn_cell)
    initial_decoder_info = SeqDecodingInformation(rnn_state=initial_cell_state,
                                                 sequence=[START_TOKEN],
                                                 seq_logprob=0.0)
    beams = [initial_decoder_info]  # type: List[SeqDecodingInformation]
    number_of_steps = 0
    # Expand until every surviving beam is finished or the step budget runs out.
    while number_of_steps < max_decoding_steps and any(b.sequence[-1] != END_TOKEN for b in beams):
        number_of_steps += 1  # NOTE(review): original relied on this counter too; verify it was incremented upstream of this chunk
        new_beams = [new_beam
                     for beam in beams
                     for new_beam in expand_sequence(beam)]
        beams = sorted(new_beams, key=lambda b: -b.seq_logprob)[:beam_size]  # Pick top K beams

    self.test_log("Groundtruth: %s" % (" ".join(test_sample['target_tokens']),))
    all_predictions = []  # type: List[Tuple[List[str], float]]
    for (k, beam_info) in enumerate(beams):
        # BUG FIX: strip the sentinel tokens only when actually present. A beam
        # that hit max_decoding_steps never emitted END_TOKEN, and the previous
        # unconditional pop() deleted a genuine predicted token instead.
        if beam_info.sequence and beam_info.sequence[-1] == END_TOKEN:
            beam_info.sequence.pop()   # Remove END_TOKEN
        if beam_info.sequence and beam_info.sequence[0] == START_TOKEN:
            beam_info.sequence.pop(0)  # Remove START_TOKEN
        kth_result = beam_info.sequence
        all_predictions.append((kth_result, np.exp(beam_info.seq_logprob)))
        self.test_log(" @%i Prob. %.3f: %s" % (k + 1, np.exp(beam_info.seq_logprob), " ".join(kth_result)))
    if len(beams) == 0:
        print("No beams finished!")
    return ModelTestResult(test_sample['target_tokens'], all_predictions)