def _build_decoder()

in models_vd/generator_attnet.py


  def _build_decoder(self, use_gt_layout, gt_layout_batch, scope='decoder',
    reuse=None):
    # The main difference from before is that the decoder now takes another
    # input (the attention) when computing the next step
    # T_max is the maximum length of decoded sequence (including <eos>)
    #
    # This function is for decoding only. It performs greedy search or sampling.
    # The first input is <go> (its embedding vector) and each subsequent input
    # is the embedding of the token predicted at the previous time step
    # num_vocab does not include <go>
    #
    # use_gt_layout is None or a bool tensor, and gt_layout_batch is a tensor
    # with shape [T_max, N].
    # If use_gt_layout is not None, then whenever use_gt_layout is true, predict
    # exactly the tokens in gt_layout_batch, regardless of the predicted probability.
    # Otherwise, if sampling is True, sample from the token probabilities;
    # if sampling is False, do greedy decoding (beam size 1)
    N = self.N
    encoder_states = self.encoder_states
    T_max = self.T_decoder
    lstm_dim = self.lstm_dim
    num_layers = self.num_layers
    apply_dropout = self.decoder_dropout
    EOS_token = self.EOS_token
    sampling = self.decoder_sampling

    with tf.variable_scope(scope, reuse=reuse):
      embedding_mat = tf.get_variable('embedding_mat',
        [self.decoder_num_vocab, self.decoder_embed_dim])
      # we use a separate embedding for <go>, as it is only used at the
      # beginning of the sequence
      go_embedding = tf.get_variable('go_embedding', [1, self.decoder_embed_dim])

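      # The parameters to predict the attention over the encoder time steps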
      with tf.variable_scope('att_prediction'):
        v = tf.get_variable('v', [lstm_dim])
        W_a = tf.get_variable('weights', [lstm_dim, lstm_dim],
          initializer=tf.contrib.layers.xavier_initializer())
        b_a = tf.get_variable('biases', lstm_dim,
          initializer=tf.constant_initializer(0.))

      # The parameters to predict the next token
      with tf.variable_scope('token_prediction'):
        W_y = tf.get_variable('weights', [lstm_dim*2, self.decoder_num_vocab],
          initializer=tf.contrib.layers.xavier_initializer())
        b_y = tf.get_variable('biases', self.decoder_num_vocab,
          initializer=tf.constant_initializer(0.))

      # Attentional decoding
      # The loop function is called at time t BEFORE the cell execution at time t,
      # and its next_input is used as the input at time t (not t+1)
      # c.f. https://www.tensorflow.org/api_docs/python/tf/nn/raw_rnn
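      # mask_range is a [1, num_vocab] row of vocabulary indices; it is compared
      # against predicted token ids below to build one-hot masks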
      mask_range = tf.reshape(tf.range(self.decoder_num_vocab, dtype=tf.int32),
                              [1, -1])

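      # 0/1 switches used at the end of loop_fn: take the token from
      # gt_layout_batch when use_gt_layout is true, else keep the prediction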
      if use_gt_layout is not None:
        gt_layout_mult = tf.cast(use_gt_layout, tf.int32)
        pred_layout_mult = 1 - gt_layout_mult
      def loop_fn(time, cell_output, cell_state, loop_state):
        if cell_output is None:  # time == 0
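          # at t == 0, start from the encoder's final states and feed the
          # <go> embedding to every example in the batch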
          next_cell_state = encoder_states
          next_input = tf.tile(go_embedding, to_T([N, 1]))
        else:  # time > 0
          next_cell_state = cell_state

          # compute the attention map over the input sequence
          # att_raw has shape [T, N, 1]
          att_raw = tf.reduce_sum(
            tf.tanh(tf.nn.xw_plus_b(cell_output, W_a, b_a) +
                self.encoder_h_transformed) * v,
            axis=2, keep_dims=True)
          # softmax along the first dimension (T), masked by seq_not_finished
          # att has shape [T, N, 1]
          att = tf.nn.softmax(att_raw, dim=0)*self.seq_not_finished
          att = att / tf.reduce_sum(att + 1e-10, axis=0, keep_dims=True)
          # d2 has shape [N, lstm_dim]
          d2 = tf.reduce_sum(att*self.encoder_outputs, axis=0)

          # token_scores has shape [N, num_vocab]
          token_scores = tf.nn.xw_plus_b(
            tf.concat([cell_output, d2], axis=1),
            W_y, b_y)

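          # decoding_state (kept in loop_state) summarizes the layout decoded so
          # far; _get_valid_tokens uses it (with self.W, self.b) to mask out
          # tokens that would lead to an invalid layout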
          decoding_state = loop_state[2]
          # token_validity has shape [N, num_vocab]
          token_validity = _get_valid_tokens(decoding_state, self.W, self.b)
          token_validity.set_shape([None, self.decoder_num_vocab])
          if use_gt_layout is not None:
            # when there's ground-truth layout, do not re-normalize prob
            # and treat all tokens as valid
            token_validity = tf.logical_or(token_validity, use_gt_layout)

          validity_mult = tf.cast(token_validity, tf.float32)

          # predict the next token (behavior depending on parameters)
          if sampling:
            token_scores_valid = token_scores - (1-validity_mult) * 50
            # TODO:debug
            sampled_token = tf.cast(tf.reshape(
                tf.multinomial(token_scores_valid/self.temperature, 1), [-1]), tf.int32)

            # make sure that the predictions are ALWAYS valid
            # (a sampled token can still be invalid, with very small probability);
            # if not, fall back to the highest-scoring valid token below
            # sampled_mask has shape [N, num_vocab]
            sampled_mask = tf.equal(mask_range, tf.reshape(sampled_token, [-1, 1]))
            is_sampled_valid = tf.reduce_any(
              tf.logical_and(sampled_mask, token_validity),
              axis=1)

            # Fall back to max score (no sampling)
            min_score = tf.reduce_min(token_scores)
            token_scores_valid = tf.where(token_validity, token_scores,
                           tf.ones_like(token_scores)*(min_score-1))
            max_score_token = tf.cast(tf.argmax(token_scores_valid, 1), tf.int32)
            predicted_token = tf.where(is_sampled_valid, sampled_token, max_score_token)
          else:
            min_score = tf.reduce_min(token_scores)
            token_scores_valid = tf.where(token_validity, token_scores,
                           tf.ones_like(token_scores)*(min_score-1))
            # predicted_token has shape [N]
            predicted_token = tf.cast(tf.argmax(token_scores_valid, 1), tf.int32)
          if use_gt_layout is not None:
            predicted_token = (gt_layout_batch[time-1] * gt_layout_mult
                     + predicted_token * pred_layout_mult)

          # a robust version of softmax
          # all_token_probs has shape [N, num_vocab]
          all_token_probs = tf.nn.softmax(token_scores) * validity_mult
          # tf.check_numerics(all_token_probs, 'NaN/Inf before div')
          all_token_probs = all_token_probs / tf.reduce_sum(all_token_probs + 1e-10, axis=1, keep_dims=True)
          # tf.check_numerics(all_token_probs, 'NaN/Inf after div')

          # mask has shape [N, num_vocab]
          mask = tf.equal(mask_range, tf.reshape(predicted_token, [-1, 1]))
          # token_prob has shape [N]: the probability of the predicted token.
          # Although token_prob is not needed for predicting the next token,
          # it is needed in the output (for policy gradient training).
          # The [N, num_vocab] masked probabilities are summed over the vocab axis.
          token_prob = tf.reduce_sum(all_token_probs * tf.cast(mask, tf.float32), axis=1)
          # tf.assert_positive(token_prob)
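          # negative entropy of the masked token distribution at this step;
          # accumulated into loop_state[3] across timesteps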
          neg_entropy = tf.reduce_sum(
            all_token_probs * tf.log(all_token_probs + (1-validity_mult) + 1e-10),
            axis=1)

          # update states
          updated_decoding_state = _update_decoding_state(
            decoding_state, predicted_token, self.P)

          # the prediction comes from the cell output of the previous
          # timestep (t-1); feed its embedding as the input at timestep t
          next_input = tf.nn.embedding_lookup(embedding_mat, predicted_token)

        elements_finished = tf.greater_equal(time, T_max)

        # loop_state is a 5-tuple, representing
        #   1) the predicted_tokens
        #   2) the prob of predicted_tokens
        #   3) the decoding state (used for validity)
        #   4) the negative entropy of policy (accumulated across timesteps)
        #   5) the attention
        if loop_state is None:  # time == 0
          # Write the predicted token into the output
          predicted_token_array = tf.TensorArray(dtype=tf.int32, size=T_max,
            infer_shape=False)
          token_prob_array = tf.TensorArray(dtype=tf.float32, size=T_max,
            infer_shape=False)
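          # init_decoding_state has shape [N, 3]; each row starts as [0, 0, T_max]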
          init_decoding_state = tf.tile(to_T([[0, 0, T_max]], dtype=tf.int32), to_T([N, 1]))
          att_array = tf.TensorArray(dtype=tf.float32, size=T_max,
            infer_shape=False)
          next_loop_state = (predicted_token_array,
                   token_prob_array,
                   init_decoding_state,
                   tf.zeros(to_T([N]), dtype=tf.float32),
                   att_array)
        else:  # time > 0
          t_write = time-1
          next_loop_state = (loop_state[0].write(t_write, predicted_token),
                   loop_state[1].write(t_write, token_prob),
                   updated_decoding_state,
                   loop_state[3] + neg_entropy,
                   loop_state[4].write(t_write, att))
        return (elements_finished, next_input, next_cell_state, cell_output,
            next_loop_state)

      # The RNN
      cell = _get_lstm_cell(num_layers, lstm_dim, apply_dropout)
      _, _, decodes_ta = tf.nn.raw_rnn(cell, loop_fn, scope='lstm')
      predicted_tokens = decodes_ta[0].stack()
      token_probs = decodes_ta[1].stack()
      neg_entropy = decodes_ta[3]
      # atts has shape [T_decoder, T_encoder, N, 1]
      atts = decodes_ta[4].stack()
      # static dimension recast
      atts = tf.reshape(atts, [self.T_decoder, self.T_encoder, -1, 1])
      self.atts = atts
      # word_vecs has shape [T_decoder, N, encoder_embed_dim]
      word_vecs = tf.reduce_sum(atts*self.embedded_input_seq, axis=1)

      predicted_tokens.set_shape([None, None])
      token_probs.set_shape([None, None])
      neg_entropy.set_shape([None])
      #word_vecs.set_shape([None, None, self.encoder_embed_dim])
      # static shapes
      word_vecs.set_shape([self.T_decoder, None, self.encoder_embed_dim])

      self.predicted_tokens = predicted_tokens
      self.token_probs = token_probs
      self.neg_entropy = neg_entropy
      self.word_vecs = word_vecs
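
A minimal standalone sketch (illustrative only, not part of the repository) of the
"sample, then fall back to greedy" rule used in the sampling branch above: invalid
tokens are pushed down by a large negative offset before sampling, and if the
sampled token is still invalid it is replaced by the highest-scoring valid token.
The function name pick_tokens and the NumPy setup are assumptions for
illustration; the model itself does this with tf.multinomial, tf.argmax and tf.where.

import numpy as np

def pick_tokens(token_scores, token_validity, temperature=1.0, rng=None):
  # token_scores: [N, num_vocab] float; token_validity: [N, num_vocab] bool
  if rng is None:
    rng = np.random.default_rng(0)
  # push invalid tokens down by 50 before sampling, as in the code above
  penalized = token_scores - (1 - token_validity) * 50.0
  probs = np.exp(penalized / temperature)
  probs /= probs.sum(axis=1, keepdims=True)
  sampled = np.array([rng.choice(len(p), p=p) for p in probs])
  # fall back to the best valid token when the sample is (rarely) invalid
  min_score = token_scores.min()
  valid_scores = np.where(token_validity, token_scores, min_score - 1)
  greedy = valid_scores.argmax(axis=1)
  sampled_is_valid = token_validity[np.arange(len(sampled)), sampled]
  return np.where(sampled_is_valid, sampled, greedy)

token_scores = np.random.default_rng(1).normal(size=(2, 5))
token_validity = np.array([[True, True, False, False, True],
                           [False, True, True, True, False]])
print(pick_tokens(token_scores, token_validity))  # one valid token id per row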