def __make_train_model()

in Models/exprsynth/seqdecoder.py [0:0]


    def __make_train_model(self):
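        """Build the training-time decoder graph: a teacher-forced RNN unrolled with
        tf.while_loop, emitting per-step output logits and a masked cross-entropy loss."""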
        rnn_cell = make_rnn_cell(self.hyperparameters['decoder_rnn_layer_num'],
                                 self.hyperparameters['decoder_rnn_cell_type'],
                                 hidden_size=self.hyperparameters['decoder_rnn_hidden_size'],
                                 dropout_keep_rate=self.placeholders['dropout_keep_rate'],
                                 )
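        # Derive the decoder's initial RNN state from the precomputed 'decoder_initial_state' op: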
        initial_cell_state = self.__make_decoder_rnn_initial_state(self.ops['decoder_initial_state'], rnn_cell)

        # Reorganise the target token ids from [batch, time] to [time, batch], and build a corresponding tensor array:
        target_tokens_by_time = tf.transpose(self.placeholders['target_token_ids'], perm=[1, 0])
        target_tokens_ta = tf.TensorArray(dtype=tf.int32,
                                          size=self.hyperparameters['decoder_max_target_length'],
                                          name="target_tokens_embedded_ta",
                                          )
        target_tokens_ta = target_tokens_ta.unstack(target_tokens_by_time)

        # First, initialise loop variables:
        one_one_per_sample = tf.ones_like(self.placeholders['target_token_ids'][:,0])
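        # Every sample's first decoder input is the embedded START token: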
        initial_input = one_one_per_sample * self.metadata['decoder_token_vocab'].get_id_or_unk(START_TOKEN)
        initial_input = tf.nn.embedding_lookup(self.parameters['decoder_token_embedding'], initial_input)
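        # END token id per sample; it is fed as input once every sample is finished: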
        end_token = one_one_per_sample * self.metadata['decoder_token_vocab'].get_id_or_unk(END_TOKEN)
        empty_output_logits_ta = tf.TensorArray(dtype=tf.float32,
                                                size=self.hyperparameters['decoder_max_target_length'],
                                                name="output_logits_ta",
                                                )

        def condition(time_unused, output_logits_ta_unused, decoder_state_unused, last_output_tok_embedded_unused, finished):
            # Keep decoding while at least one sample in the batch is unfinished:
            return tf.logical_not(tf.reduce_all(finished))

        def body(step, output_logits_ta, decoder_state, last_output_tok_embedded, finished):
            next_step = step + 1

            # Use the RNN to decode one more token:
            cur_output, next_decoder_state = rnn_cell(last_output_tok_embedded, decoder_state)
            cur_rnn_output_logits = self.parameters['decoder_output_projection'](cur_output)

            # Mark samples as finished once the maximum target length is reached:
            next_finished = tf.logical_or(finished, next_step >= self.hyperparameters['decoder_max_target_length'])

            # Teacher forcing: feed this step's ground-truth target token as the next input
            # (fall back to the END token once all samples are finished):
            all_next_finished = tf.reduce_all(next_finished)
            cur_output_tok = tf.cond(all_next_finished,
                                     lambda: end_token,
                                     lambda: target_tokens_ta.read(step))

            cur_output_tok_embedded = tf.nn.embedding_lookup(self.parameters['decoder_token_embedding'],
                                                             cur_output_tok)

            # Record the logits produced at this step:
            output_logits_ta = output_logits_ta.write(step, cur_rnn_output_logits)
            return (next_step, output_logits_ta, next_decoder_state, cur_output_tok_embedded, next_finished)

        # Run the decoder loop; parallel_iterations=1 forces the loop iterations to run strictly in sequence:
        (_, final_output_logits_ta, _, _, _) = \
            tf.while_loop(condition,
                          body,
                          loop_vars=[tf.constant(0, dtype=tf.int32),
                                     empty_output_logits_ta,
                                     initial_cell_state,
                                     initial_input,
                                     tf.zeros_like(self.placeholders['target_token_ids'][:,0], dtype=tf.bool),
                                     ],
                          parallel_iterations=1
                          )

        # Reassemble the per-step logits and transpose back to [batch, time, vocab]:
        output_logits_by_time = final_output_logits_ta.stack()
        self.ops['decoder_output_logits'] = tf.transpose(output_logits_by_time, perm=[1, 0, 2])
        self.ops['decoder_output_probs'] = tf.nn.softmax(self.ops['decoder_output_logits'])

        # Produce loss:
        outputs_correct_crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.placeholders['target_token_ids'],
                                                                                  logits=self.ops['decoder_output_logits'])
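        # Zero out positions that should not contribute to the loss (e.g. padding):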
        masked_outputs_correct_crossent = outputs_correct_crossent * self.placeholders['target_token_ids_mask']


        # The summed masked cross-entropy is the negative log-probability of the target sequences:
        decoder_loss = tf.reduce_sum(masked_outputs_correct_crossent)
        self.ops['log_probs'] = -decoder_loss

        # Normalize by batch size:
        self.ops['loss'] = decoder_loss / tf.to_float(self.placeholders['batch_size'])
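
The loss block above boils down to a masked, length-summed cross-entropy normalised by batch
size. A minimal, self-contained sketch of the same computation (TF 1.x; the toy shapes and
variable names below are illustrative only, not part of the model):

    import tensorflow as tf

    batch_size, max_len, vocab_size = 2, 4, 10
    logits = tf.random_normal([batch_size, max_len, vocab_size])           # [batch, time, vocab]
    targets = tf.constant([[3, 5, 1, 0], [2, 1, 0, 0]], dtype=tf.int32)    # [batch, time]
    mask = tf.constant([[1., 1., 1., 0.], [1., 1., 0., 0.]])               # 1.0 for real tokens, 0.0 for padding

    # Per-token cross-entropy, masked so padding positions contribute nothing,
    # summed over all tokens and normalised by batch size:
    per_token_xent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=targets, logits=logits)
    loss = tf.reduce_sum(per_token_xent * mask) / tf.to_float(batch_size)

    with tf.Session() as sess:
        print(sess.run(loss))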