def __make_train_model()

in Models/exprsynth/nagdecoder.py
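
Builds the training part of the decoder's TensorFlow graph: initial expansion-graph node representations are taken from the context graph where available (token embeddings otherwise), refined by an asynchronous GGNN, and then used to score grammar, variable, and literal production choices; the per-decision cross-entropy losses are normalized and summed into the training objective.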


    def __make_train_model(self):
        # Pick the context graph (CG) representation where available, and fall back to the token embedding otherwise:
        eg_node_label_embeddings = \
            tf.nn.embedding_lookup(self.parameters['eg_token_embeddings'],
                                   self.placeholders['eg_node_token_ids'])
        eg_initial_node_representations = \
            tf.where(condition=self.ops['eg_node_representation_use_from_context'],
                     x=self.ops['eg_node_representations_from_context'],
                     y=eg_node_label_embeddings)
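        # tf.where selects row-wise here: nodes flagged in the condition keep their
        # context-graph representation, all others use their token-label embedding.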

        # ----- (3) Compute representations of expansion graph using an async GNN submodel:
        eg_h_dim = self.hyperparameters['eg_hidden_size']
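        # Collect the decoder's "eg_"-prefixed hyperparameters under generic names for the AsyncGGNN submodel: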
        eg_hypers = {name.replace("eg_", "", 1): value
                     for (name, value) in self.hyperparameters.items()
                     if name.startswith("eg_")}
        eg_hypers['propagation_rounds'] = 1
        eg_hypers['num_labeled_edge_types'] = len(self.__expansion_labeled_edge_types)
        eg_hypers['num_unlabeled_edge_types'] = len(self.__expansion_unlabeled_edge_types)
        with tf.variable_scope("ExpansionGraph"):
            eg_model = AsyncGGNN(eg_hypers)

            # Note that we only use a single async schedule here, so every argument is wrapped in
            # [] to use the generic code supporting many schedules:
            eg_node_representations = \
                eg_model.async_ggnn_layer(
                    eg_initial_node_representations,
                    [self.placeholders['eg_initial_node_ids']],
                    [self.placeholders['eg_sending_node_ids']],
                    [self.__embed_edge_labels(self.hyperparameters['eg_propagation_substeps'])],
                    [self.placeholders['eg_msg_target_node_ids']],
                    [self.placeholders['eg_receiving_node_ids']],
                    [self.placeholders['eg_receiving_node_nums']])
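            # Presumed result: one vector per expansion-graph node, i.e. shape [num_eg_nodes, eg_h_dim].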

        # ----- (4) Finally, try to predict the right productions:
        # === Grammar productions:
        eg_production_node_representations = \
            tf.gather(params=eg_node_representations,
                      indices=self.placeholders['eg_production_nodes'])  # Shape [num_choice_nodes, D]

        if self.hyperparameters['eg_use_vars_for_production_choice']:
            variable_representations_at_prod_choice = \
                tf.gather(params=eg_node_representations,
                          indices=self.placeholders['eg_production_var_last_use_node_ids'])
            variable_representations_at_prod_choice = \
                tf.unsorted_segment_mean(
                    data=variable_representations_at_prod_choice,
                    segment_ids=self.placeholders['eg_production_var_last_use_node_ids_target_ids'],
                    num_segments=tf.shape(eg_production_node_representations)[0])
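            # unsorted_segment_mean averages the last-use representations of all variables
            # relevant to each production choice (the segment ids map variable uses to
            # production nodes), yielding one summary vector per production node.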
        else:
            variable_representations_at_prod_choice = None

        eg_production_choice_logits = \
            self.__make_production_choice_logits_model(
                eg_production_node_representations,
                variable_representations_at_prod_choice,
                self.ops.get('context_token_representations'),
                self.placeholders.get('context_token_mask'),
                self.placeholders.get('eg_production_to_context_id'))
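        # The .get() lookups return None when the model is configured without context
        # attention/copying; the logits submodel is expected to handle the no-context case.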

        # === Variable productions
        eg_varproduction_node_representations = \
            tf.gather(params=eg_node_representations,
                      indices=self.placeholders['eg_varproduction_nodes'])  # Shape: [VP, D]
        eg_varproduction_options_nodes_flat = \
            tf.reshape(self.placeholders['eg_varproduction_options_nodes'],
                       shape=[-1])  # Shape [VP * eg_max_variable_choices]
        eg_varproduction_options_representations = \
            tf.reshape(tf.gather(params=eg_node_representations,
                                 indices=eg_varproduction_options_nodes_flat
                                 ),  # Shape: [VP * eg_max_variable_choices, D]
                       shape=[-1, self.hyperparameters['eg_max_variable_choices'], eg_h_dim]
                       )  # Shape: [VP, eg_max_variable_choices, D]
        eg_varproduction_choice_logits = \
            self.__make_variable_choice_logits_model(
                eg_varproduction_node_representations,
                eg_varproduction_options_representations,
                )  # Shape: [VP, eg_max_variable_choices]
        # Mask out unused choice options:
        eg_varproduction_choice_logits += \
            (1.0 - self.placeholders['eg_varproduction_options_mask']) * -BIG_NUMBER
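        # Adding -BIG_NUMBER pushes masked-out logits towards -inf, so softmax assigns
        # them (near-)zero probability.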

        # === Literal productions
        literal_logits = {}
        for literal_kind in LITERAL_NONTERMINALS:
            eg_litproduction_representation = \
                tf.gather(params=eg_node_representations,
                          indices=self.placeholders['eg_litproduction_nodes'][literal_kind]
                          )  # Shape: [LP, D]
            eg_litproduction_to_context_id, eg_litproduction_choice_normalizer = None, None
            if self.hyperparameters['eg_use_literal_copying']:
                eg_litproduction_to_context_id = \
                    self.placeholders['eg_litproduction_to_context_id'][literal_kind]
                eg_litproduction_choice_normalizer = \
                    self.placeholders['eg_litproduction_choice_normalizer'][literal_kind]
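            # Both placeholders are per-literal-kind and only come into play when copying
            # literals from the context is enabled.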

            literal_logits[literal_kind] = \
                self.__make_literal_choice_logits_model(
                    literal_kind,
                    eg_litproduction_representation,
                    self.ops.get('context_token_representations'),
                    self.placeholders.get('context_token_mask'),
                    eg_litproduction_to_context_id,
                    eg_litproduction_choice_normalizer,
                    )

        # (5) Compute loss:
        raw_prod_loss = \
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=eg_production_choice_logits,
                labels=self.placeholders['eg_production_node_choices'])
        raw_var_loss = \
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=eg_varproduction_choice_logits,
                labels=self.placeholders['eg_varproduction_node_choices'])

        # Normalize each loss by the number of decisions actually made, which differs from batch to batch.
        # We can't use tf.reduce_mean because these tensors can be empty, and reduce_mean yields NaN for empty inputs:
        prod_loss = tf.reduce_sum(raw_prod_loss) / (tf.cast(tf.size(raw_prod_loss), dtype=tf.float32) + SMALL_NUMBER)
        var_loss = tf.reduce_sum(raw_var_loss) / (tf.cast(tf.size(raw_var_loss), dtype=tf.float32) + SMALL_NUMBER)
        if len(LITERAL_NONTERMINALS) > 0:
            raw_lit_loss = [tf.nn.sparse_softmax_cross_entropy_with_logits(
                                logits=literal_logits[literal_kind],
                                labels=self.placeholders['eg_litproduction_node_choices'][literal_kind])
                            for literal_kind in LITERAL_NONTERMINALS]
            raw_lit_loss = tf.concat(raw_lit_loss, axis=0)
            lit_loss = tf.reduce_sum(raw_lit_loss) / (tf.cast(tf.size(raw_lit_loss), dtype=tf.float32) + SMALL_NUMBER)
        else:
            raw_lit_loss = [0.0]
            lit_loss = 0.0
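            # With no literal kinds there are no literal decisions: the raw loss
            # (reused for log_probs below) and the normalized loss both default to zero.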

        self.ops['loss'] = prod_loss + var_loss + lit_loss

        # TODO: If we want to batch this per sample, we need an extra placeholder mapping productions/variables
        # to samples, and unsorted_segment_sum to gather together the logprobs from all productions per sample.
        self.ops['log_probs'] = -tf.reduce_sum(raw_prod_loss) - tf.reduce_sum(raw_var_loss) - tf.reduce_sum(raw_lit_loss)
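
The normalization in step (5) relies on a small pattern worth isolating: tf.reduce_mean on an empty tensor returns NaN (a zero sum divided by a zero count), so the code divides a reduce_sum by the element count plus SMALL_NUMBER instead. Below is a minimal sketch of the difference, in TF 1.x style; the SMALL_NUMBER value is a stand-in for the repo's constant:

    import tensorflow as tf

    SMALL_NUMBER = 1e-7  # assumption: stand-in for the repo's SMALL_NUMBER constant

    raw_loss = tf.placeholder(tf.float32, shape=[None])  # per-decision losses; may be empty
    naive_mean = tf.reduce_mean(raw_loss)  # NaN for an empty batch (0 / 0)
    safe_mean = tf.reduce_sum(raw_loss) / (tf.cast(tf.size(raw_loss), tf.float32) + SMALL_NUMBER)

    with tf.Session() as sess:
        print(sess.run([naive_mean, safe_mean], feed_dict={raw_loss: []}))          # [nan, 0.0]
        print(sess.run([naive_mean, safe_mean], feed_dict={raw_loss: [2.0, 4.0]}))  # [3.0, ~3.0]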