in Models/exprsynth/seqdecoder.py [0:0]
def __make_train_model(self) -> None:
    """Build the teacher-forced training graph for the sequence decoder.

    Unrolls an RNN cell for up to ``decoder_max_target_length`` steps using
    ``tf.while_loop``. At each step the input is the *ground-truth* target
    token of the previous step (teacher forcing), not the model's own
    prediction.

    Registers in ``self.ops``:
      * ``decoder_output_logits``: float32 [batch, time, vocab] logits.
      * ``decoder_output_probs``: softmax over the logits.
      * ``log_probs``: negated sum of the masked per-token cross-entropy.
      * ``loss``: masked cross-entropy sum, normalised by batch size.
    """
    rnn_cell = make_rnn_cell(self.hyperparameters['decoder_rnn_layer_num'],
                             self.hyperparameters['decoder_rnn_cell_type'],
                             hidden_size=self.hyperparameters['decoder_rnn_hidden_size'],
                             dropout_keep_rate=self.placeholders['dropout_keep_rate'],
                             )
    # Seed the decoder RNN state from the encoder's output representation.
    initial_cell_state = self.__make_decoder_rnn_initial_state(self.ops['decoder_initial_state'], rnn_cell)

    # Target token ids arrive as [batch, time]; transpose to [time, batch] so
    # each while_loop step can read one full time-slice from a TensorArray.
    target_tokens_by_time = tf.transpose(self.placeholders['target_token_ids'], perm=[1, 0])
    target_tokens_ta = tf.TensorArray(dtype=tf.int32,
                                      size=self.hyperparameters['decoder_max_target_length'],
                                      name="target_tokens_embedded_ta",
                                      )
    target_tokens_ta = target_tokens_ta.unstack(target_tokens_by_time)

    # First, initialise loop variables:
    # A [batch]-shaped vector of ones, used to broadcast scalar token ids
    # (START/END) to one id per sample.
    one_one_per_sample = tf.ones_like(self.placeholders['target_token_ids'][:,0])
    # Every sample's first decoder input is the embedded START token.
    initial_input = one_one_per_sample * self.metadata['decoder_token_vocab'].get_id_or_unk(START_TOKEN)
    initial_input = tf.nn.embedding_lookup(self.parameters['decoder_token_embedding'], initial_input)
    end_token = one_one_per_sample * self.metadata['decoder_token_vocab'].get_id_or_unk(END_TOKEN)
    # Per-step logits are accumulated here, indexed by time step.
    empty_output_logits_ta = tf.TensorArray(dtype=tf.float32,
                                            size=self.hyperparameters['decoder_max_target_length'],
                                            name="output_logits_ta",
                                            )

    def condition(time_unused, output_logits_ta_unused, decoder_state_unused, last_output_tok_embedded_unused, finished):
        # Keep looping until every sample in the batch is marked finished.
        return tf.logical_not(tf.reduce_all(finished))

    def body(step, output_logits_ta, decoder_state, last_output_tok_embedded, finished):
        """One decode step: advance the RNN, record logits, pick next input."""
        next_step = step + 1
        # Use the RNN to decode one more tok:
        cur_output, next_decoder_state = rnn_cell(last_output_tok_embedded, decoder_state)
        cur_rnn_output_logits = self.parameters['decoder_output_projection'](cur_output)
        # Decide if we're done everywhere. Note that in this training loop
        # 'finished' is driven purely by the step count reaching the maximum
        # target length, not by emitted END tokens.
        next_finished = tf.logical_or(finished, next_step >= self.hyperparameters['decoder_max_target_length'])
        # Decide next token: If in training, use the next target token...
        all_next_finished = tf.reduce_all(next_finished)
        # When everything is finished the next input is never consumed, so we
        # feed the END token; otherwise teacher-force the ground-truth token
        # at position `step` (the token the logits just written should predict).
        cur_output_tok = tf.cond(all_next_finished,
                                 lambda: end_token,
                                 lambda: target_tokens_ta.read(step))
        cur_output_tok_embedded = tf.nn.embedding_lookup(self.parameters['decoder_token_embedding'],
                                                         cur_output_tok)
        # Write out the collected wisdom:
        output_logits_ta = output_logits_ta.write(step, cur_rnn_output_logits)
        return (next_step, output_logits_ta, next_decoder_state, cur_output_tok_embedded, next_finished)

    # parallel_iterations=1 forces strictly sequential execution — each step
    # depends on the previous step's RNN state and token anyway.
    (_, final_output_logits_ta, _, _, _) = \
        tf.while_loop(condition,
                      body,
                      loop_vars=[tf.constant(0, dtype=tf.int32),
                                 empty_output_logits_ta,
                                 initial_cell_state,
                                 initial_input,
                                 tf.zeros_like(self.placeholders['target_token_ids'][:,0], dtype=tf.bool),
                                 ],
                      parallel_iterations=1
                      )
    # Stack [time, batch, vocab] and transpose back to [batch, time, vocab].
    output_logits_by_time = final_output_logits_ta.stack()
    self.ops['decoder_output_logits'] = tf.transpose(output_logits_by_time, perm=[1, 0, 2])
    self.ops['decoder_output_probs'] = tf.nn.softmax(self.ops['decoder_output_logits'])

    # Produce loss: per-token sparse cross-entropy, masked so padding tokens
    # contribute nothing, then summed over the whole batch.
    outputs_correct_crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.placeholders['target_token_ids'],
                                                                              logits=self.ops['decoder_output_logits'])
    masked_outputs_correct_crossent = outputs_correct_crossent * self.placeholders['target_token_ids_mask']
    decoder_loss = tf.reduce_sum(masked_outputs_correct_crossent)
    # Summed masked cross-entropy is the negative log-likelihood of the targets.
    self.ops['log_probs'] = -decoder_loss
    # Normalize by batch size:
    # NOTE(review): tf.to_float is deprecated in newer TF releases in favour of
    # tf.cast(..., tf.float32) — consider migrating when the file is updated.
    self.ops['loss'] = decoder_loss / tf.to_float(self.placeholders['batch_size'])