in models_mnist/generator_attnet.py [0:0]
def _build_decoder(self, use_gt_layout, gt_layout_batch, scope='decoder',
                   reuse=None):
    # Build the attentional decoder graph: an LSTM driven by tf.nn.raw_rnn
    # that emits one token per step for T_max = self.T_decoder steps, and
    # store the results on self (atts, predicted_tokens, token_probs,
    # neg_entropy, word_vecs).
    #
    # The main difference from before is that the decoder now takes another
    # input (the attention) when computing the next step.
    # T_max is the maximum length of decoded sequence (including <eos>)
    #
    # This function is for decoding only. It performs greedy search or sampling.
    # The first input is <go> (its embedding vector) and the subsequent inputs
    # are the outputs from the previous time step.
    # num_vocab does not include <go>
    #
    # Args:
    #   use_gt_layout: None or a bool tensor. When it is not None and is true,
    #     the decoder predicts exactly the tokens in gt_layout_batch,
    #     regardless of actual probability.
    #   gt_layout_batch: a tensor with shape [T_max, N]; only read when
    #     use_gt_layout is not None.
    #   scope: name of the variable scope that holds the decoder variables.
    #   reuse: forwarded to tf.variable_scope.
    #
    # When the ground-truth layout is not used: if self.decoder_sampling is
    # True, sample from the token probability; if it is False, do greedy
    # decoding (beam size 1).
    N = self.N
    encoder_states = self.encoder_states
    T_max = self.T_decoder
    lstm_dim = self.lstm_dim
    num_layers = self.num_layers
    apply_dropout = self.decoder_dropout
    # NOTE(review): EOS_token is not referenced again in this function.
    EOS_token = self.EOS_token
    sampling = self.decoder_sampling
    with tf.variable_scope(scope, reuse=reuse):
        embedding_mat = tf.get_variable('embedding_mat',
            [self.decoder_num_vocab, self.decoder_embed_dim])
        # we use a separate embedding for <go>, as it is only used in the
        # beginning of the sequence
        go_embedding = tf.get_variable('go_embedding', [1, self.decoder_embed_dim])
        # The parameters to predict the attention over the input sequence
        with tf.variable_scope('att_prediction'):
            v = tf.get_variable('v', [lstm_dim])
            W_a = tf.get_variable('weights', [lstm_dim, lstm_dim],
                initializer=tf.contrib.layers.xavier_initializer())
            b_a = tf.get_variable('biases', lstm_dim,
                initializer=tf.constant_initializer(0.))
        # The parameters to predict the next token
        # (input is the cell output concatenated with the attended vector,
        # hence lstm_dim*2 rows)
        with tf.variable_scope('token_prediction'):
            W_y = tf.get_variable('weights', [lstm_dim*2, self.decoder_num_vocab],
                initializer=tf.contrib.layers.xavier_initializer())
            b_y = tf.get_variable('biases', self.decoder_num_vocab,
                initializer=tf.constant_initializer(0.))
        # Attentional decoding
        # Loop function is called at time t BEFORE the cell execution at time t,
        # and its next_input is used as the input at time t (not t+1)
        # c.f. https://www.tensorflow.org/api_docs/python/tf/nn/raw_rnn
        #
        # mask_range has shape [1, num_vocab] and holds 0..num_vocab-1; comparing
        # it against a [N, 1] column of token ids yields a one-hot bool mask.
        mask_range = tf.reshape(tf.range(self.decoder_num_vocab, dtype=tf.int32),
            [1, -1])
        if use_gt_layout is not None:
            # integer 0/1 multipliers used below to select between the
            # ground-truth token and the predicted token arithmetically
            gt_layout_mult = tf.cast(use_gt_layout, tf.int32)
            pred_layout_mult = 1 - gt_layout_mult
        def loop_fn(time, cell_output, cell_state, loop_state):
            # raw_rnn loop function. At time t > 0, cell_output is the output
            # of the cell at step t-1, so everything computed from it below
            # (attention, token, probability) belongs to step t-1.
            if cell_output is None: # time == 0
                # start from the encoder's final states and feed <go>
                next_cell_state = encoder_states
                next_input = tf.tile(go_embedding, to_T([N, 1]))
            else: # time > 0
                next_cell_state = cell_state
                # compute the attention map over the input sequence
                # att_raw has shape [T, N, 1]
                att_raw = tf.reduce_sum(
                    tf.tanh(tf.nn.xw_plus_b(cell_output, W_a, b_a) +
                            self.encoder_h_transformed) * v,
                    axis=2, keep_dims=True)
                # softmax along the first dimension (T) over not finished examples
                # att has shape [T, N, 1]
                # NOTE(review): self.seq_not_finished is presumably a 0/1 mask
                # over encoder positions - confirm where it is built.
                att = tf.nn.softmax(att_raw, dim=0)*self.seq_not_finished
                # re-normalize after masking; the 1e-10 keeps the denominator
                # away from zero
                att = att / tf.reduce_sum(att + 1e-10, axis=0, keep_dims=True)
                # d2 has shape [N, lstm_dim]
                d2 = tf.reduce_sum(att*self.encoder_outputs, axis=0)
                # token_scores has shape [N, num_vocab]
                token_scores = tf.nn.xw_plus_b(
                    tf.concat([cell_output, d2], axis=1),
                    W_y, b_y)
                decoding_state = loop_state[2]
                # token_validity has shape [N, num_vocab]
                token_validity = _get_valid_tokens(decoding_state, self.W, self.b)
                token_validity.set_shape([None, self.decoder_num_vocab])
                if use_gt_layout is not None:
                    # when the ground-truth layout is used (use_gt_layout is
                    # true), treat all tokens as valid, so the probabilities
                    # below are not re-normalized over a valid subset
                    token_validity = tf.logical_or(token_validity, use_gt_layout)
                validity_mult = tf.cast(token_validity, tf.float32)
                # predict the next token (behavior depending on parameters)
                if sampling:
                    # lower the scores of invalid tokens by 50 so that they
                    # are very unlikely to be sampled
                    token_scores_valid = token_scores - (1-validity_mult) * 50
                    # TODO:debug
                    sampled_token = tf.cast(tf.reshape(
                        tf.multinomial(token_scores_valid/self.temperature, 1), [-1]), tf.int32)
                    # make sure that the predictions are ALWAYS valid
                    # (it can be invalid with very small prob)
                    # If not, fall back to the greedy (max valid score) choice
                    # sampled_mask has shape [N, num_vocab]
                    sampled_mask = tf.equal(mask_range, tf.reshape(sampled_token, [-1, 1]))
                    is_sampled_valid = tf.reduce_any(
                        tf.logical_and(sampled_mask, token_validity),
                        axis=1)
                    # Fall back to max score (no sampling): invalid tokens get
                    # a score below the global minimum so argmax never picks
                    # them while a valid token exists
                    # NOTE(review): same computation as the greedy branch below
                    min_score = tf.reduce_min(token_scores)
                    token_scores_valid = tf.where(token_validity, token_scores,
                        tf.ones_like(token_scores)*(min_score-1))
                    max_score_token = tf.cast(tf.argmax(token_scores_valid, 1), tf.int32)
                    predicted_token = tf.where(is_sampled_valid, sampled_token, max_score_token)
                else:
                    # greedy decoding: argmax over the valid tokens only
                    min_score = tf.reduce_min(token_scores)
                    token_scores_valid = tf.where(token_validity, token_scores,
                        tf.ones_like(token_scores)*(min_score-1))
                    # predicted_token has shape [N]
                    predicted_token = tf.cast(tf.argmax(token_scores_valid, 1), tf.int32)
                if use_gt_layout is not None:
                    # arithmetic select: exactly one of the two multipliers is
                    # 1. The index is time-1 because this prediction belongs
                    # to the previous step.
                    predicted_token = (gt_layout_batch[time-1] * gt_layout_mult
                                       + predicted_token * pred_layout_mult)
                # a robust version of softmax: zero out invalid tokens, then
                # re-normalize over the remaining ones
                # all_token_probs has shape [N, num_vocab]
                all_token_probs = tf.nn.softmax(token_scores) * validity_mult
                # tf.check_numerics(all_token_probs, 'NaN/Inf before div')
                all_token_probs = all_token_probs / tf.reduce_sum(all_token_probs + 1e-10, axis=1, keep_dims=True)
                # tf.check_numerics(all_token_probs, 'NaN/Inf after div')
                # mask has shape [N, num_vocab] (one-hot at the predicted token)
                mask = tf.equal(mask_range, tf.reshape(predicted_token, [-1, 1]))
                # token_prob has shape [N], the probability of the predicted token
                # although token_prob is not needed for predicting the next token
                # it is needed in output (for policy gradient training)
                # (the masked product has shape [N, num_vocab] before the sum)
                token_prob = tf.reduce_sum(all_token_probs * tf.cast(mask, tf.float32), axis=1)
                # tf.assert_positive(token_prob)
                # sum of p*log(p) over the vocabulary; invalid tokens have
                # p == 0, and adding (1-validity_mult) makes their log argument
                # about 1, so they contribute nothing
                neg_entropy = tf.reduce_sum(
                    all_token_probs * tf.log(all_token_probs + (1-validity_mult) + 1e-10),
                    axis=1)
                # update states
                updated_decoding_state = _update_decoding_state(
                    decoding_state, predicted_token, self.P)
                # the prediction is from the cell output of the last step
                # timestep (t-1), feed it as input into timestep t
                next_input = tf.nn.embedding_lookup(embedding_mat, predicted_token)
            # depends only on time: decoding runs until time reaches T_max,
            # with no per-example early stop
            elements_finished = tf.greater_equal(time, T_max)
            # loop_state is a 5-tuple, representing
            # 1) the predicted_tokens
            # 2) the prob of predicted_tokens
            # 3) the decoding state (used for validity)
            # 4) the negative entropy of policy (accumulated across timesteps)
            # 5) the attention
            if loop_state is None: # time == 0
                # Create the arrays that the predictions are written into
                predicted_token_array = tf.TensorArray(dtype=tf.int32, size=T_max,
                    infer_shape=False)
                token_prob_array = tf.TensorArray(dtype=tf.float32, size=T_max,
                    infer_shape=False)
                # every example starts from the row [0, 0, T_max]; the meaning
                # of the three columns is defined by _get_valid_tokens and
                # _update_decoding_state (not visible here)
                init_decoding_state = tf.tile(to_T([[0, 0, T_max]], dtype=tf.int32), to_T([N, 1]))
                att_array = tf.TensorArray(dtype=tf.float32, size=T_max,
                    infer_shape=False)
                next_loop_state = (predicted_token_array,
                                   token_prob_array,
                                   init_decoding_state,
                                   tf.zeros(to_T([N]), dtype=tf.float32),
                                   att_array)
            else: # time > 0
                # results computed above belong to step time-1
                t_write = time-1
                next_loop_state = (loop_state[0].write(t_write, predicted_token),
                                   loop_state[1].write(t_write, token_prob),
                                   updated_decoding_state,
                                   loop_state[3] + neg_entropy,
                                   loop_state[4].write(t_write, att))
            return (elements_finished, next_input, next_cell_state, cell_output,
                    next_loop_state)
        # The RNN
        cell = _get_lstm_cell(num_layers, lstm_dim, apply_dropout)
        # the third value returned by raw_rnn is the final loop_state, i.e. the
        # 5-tuple described in loop_fn
        _, _, decodes_ta = tf.nn.raw_rnn(cell, loop_fn, scope='lstm')
        predicted_tokens = decodes_ta[0].stack()
        token_probs = decodes_ta[1].stack()
        # already accumulated over time steps inside loop_fn
        neg_entropy = decodes_ta[3]
        # atts has shape [T_decoder, T_encoder, N, 1]
        atts = decodes_ta[4].stack()
        # static dimension recast
        atts = tf.reshape(atts, [self.T_decoder, self.T_encoder, -1, 1])
        self.atts = atts
        # attention-weighted sum of the embedded input words over T_encoder
        # word_vecs has shape [T_decoder, N, encoder_embed_dim]
        word_vecs = tf.reduce_sum(atts*self.embedded_input_seq, axis=1)
        predicted_tokens.set_shape([None, None])
        token_probs.set_shape([None, None])
        neg_entropy.set_shape([None])
        # static shapes
        word_vecs.set_shape([self.T_decoder, None, self.encoder_embed_dim])
        self.predicted_tokens = predicted_tokens
        self.token_probs = token_probs
        self.neg_entropy = neg_entropy
        self.word_vecs = word_vecs