in Models/exprsynth/nagdecoder.py
def __make_train_model(self):
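    """Build the training model.

    Selects initial node representations for the expansion graph, runs an
    asynchronous GGNN over it (step 3), computes choice logits for grammar,
    variable, and literal productions (step 4), and combines the per-decision
    cross-entropy losses (step 5). Steps (1) and (2), the context encoding,
    presumably live elsewhere in this class.
    """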
    # Pick the context graph (CG) representation where possible, and use a learned
    # token embedding otherwise:
    eg_node_label_embeddings = \
        tf.nn.embedding_lookup(self.parameters['eg_token_embeddings'],
                               self.placeholders['eg_node_token_ids'])
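    # `eg_node_representation_use_from_context` is a per-node boolean vector; with a
    # rank-1 condition, tf.where selects whole rows, so node i keeps its context graph
    # representation iff its flag is set and falls back to the embedding otherwise: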
    eg_initial_node_representations = \
        tf.where(condition=self.ops['eg_node_representation_use_from_context'],
                 x=self.ops['eg_node_representations_from_context'],
                 y=eg_node_label_embeddings)

    # ----- (3) Compute representations of expansion graph using an async GNN submodel:
    eg_h_dim = self.hyperparameters['eg_hidden_size']
    eg_hypers = {name.replace("eg_", "", 1): value
                 for (name, value) in self.hyperparameters.items()
                 if name.startswith("eg_")}
    eg_hypers['propagation_rounds'] = 1
    eg_hypers['num_labeled_edge_types'] = len(self.__expansion_labeled_edge_types)
    eg_hypers['num_unlabeled_edge_types'] = len(self.__expansion_unlabeled_edge_types)
    with tf.variable_scope("ExpansionGraph"):
        eg_model = AsyncGGNN(eg_hypers)

        # Note that we only use a single async schedule here, so every argument is
        # wrapped in [] to use the generic code supporting many schedules:
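        # Roughly, the per-substep placeholders spell out the asynchronous schedule:
        # for each of the eg_propagation_substeps, which nodes send messages, the
        # embedded labels of the edges they send along, which message slots feed
        # which target nodes, and which nodes receive messages (and are updated):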
        eg_node_representations = \
            eg_model.async_ggnn_layer(
                eg_initial_node_representations,
                [self.placeholders['eg_initial_node_ids']],
                [self.placeholders['eg_sending_node_ids']],
                [self.__embed_edge_labels(self.hyperparameters['eg_propagation_substeps'])],
                [self.placeholders['eg_msg_target_node_ids']],
                [self.placeholders['eg_receiving_node_ids']],
                [self.placeholders['eg_receiving_node_nums']])

    # ----- (4) Finally, try to predict the right productions:
    # === Grammar productions:
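    # `eg_production_nodes` lists the nodes at which a grammar production has to be
    # chosen; their representations are the inputs to the production choice model: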
    eg_production_node_representations = \
        tf.gather(params=eg_node_representations,
                  indices=self.placeholders['eg_production_nodes'])  # Shape: [num_choice_nodes, D]
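    # Optionally also condition the production choice on the variables in scope: look
    # up the representation of each variable's last use and average those per choice
    # node. E.g., with last-use representations [v1, v2, v3] and target ids [0, 0, 1],
    # tf.unsorted_segment_mean yields mean(v1, v2) for choice node 0 and v3 for node 1.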
    if self.hyperparameters['eg_use_vars_for_production_choice']:
        variable_representations_at_prod_choice = \
            tf.gather(params=eg_node_representations,
                      indices=self.placeholders['eg_production_var_last_use_node_ids'])
        variable_representations_at_prod_choice = \
            tf.unsorted_segment_mean(
                data=variable_representations_at_prod_choice,
                segment_ids=self.placeholders['eg_production_var_last_use_node_ids_target_ids'],
                num_segments=tf.shape(eg_production_node_representations)[0])
    else:
        variable_representations_at_prod_choice = None
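    # The context token representations/mask are optional (the .get() calls yield None
    # when the context encoder does not provide them); they presumably let the
    # production choice attend to (and copy from) the encoded context: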
    eg_production_choice_logits = \
        self.__make_production_choice_logits_model(
            eg_production_node_representations,
            variable_representations_at_prod_choice,
            self.ops.get('context_token_representations'),
            self.placeholders.get('context_token_mask'),
            self.placeholders.get('eg_production_to_context_id'))

    # === Variable productions
    eg_varproduction_node_representations = \
        tf.gather(params=eg_node_representations,
                  indices=self.placeholders['eg_varproduction_nodes'])  # Shape: [VP, D]
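    # Each variable-choice node comes with up to eg_max_variable_choices candidate
    # nodes. Flatten the [VP, eg_max_variable_choices] id matrix, gather the candidate
    # representations, and then restore the per-node grouping: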
    eg_varproduction_options_nodes_flat = \
        tf.reshape(self.placeholders['eg_varproduction_options_nodes'],
                   shape=[-1])  # Shape: [VP * eg_max_variable_choices]
    eg_varproduction_options_representations = \
        tf.reshape(tf.gather(params=eg_node_representations,
                             indices=eg_varproduction_options_nodes_flat
                             ),  # Shape: [VP * eg_max_variable_choices, D]
                   shape=[-1, self.hyperparameters['eg_max_variable_choices'], eg_h_dim]
                   )  # Shape: [VP, eg_max_variable_choices, D]
    eg_varproduction_choice_logits = \
        self.__make_variable_choice_logits_model(
            eg_varproduction_node_representations,
            eg_varproduction_options_representations,
        )  # Shape: [VP, eg_max_variable_choices]
    # Mask out unused choice options:
    eg_varproduction_choice_logits += \
        (1.0 - self.placeholders['eg_varproduction_options_mask']) * -BIG_NUMBER
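    # (1 - mask) is 1.0 exactly for the padding slots, so their logits are pushed
    # towards -BIG_NUMBER and receive negligible probability under the softmax below.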

    # === Literal productions
    literal_logits = {}
    for literal_kind in LITERAL_NONTERMINALS:
        eg_litproduction_representation = \
            tf.gather(params=eg_node_representations,
                      indices=self.placeholders['eg_litproduction_nodes'][literal_kind]
                      )  # Shape: [LP, D]
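        # When literal copying is enabled, two extra per-kind placeholders wire up the
        # copy mechanism over the encoded context (the normalizer presumably collapses
        # repeated context tokens into a single copy choice):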
        eg_litproduction_to_context_id, eg_litproduction_choice_normalizer = None, None
        if self.hyperparameters['eg_use_literal_copying']:
            eg_litproduction_to_context_id = \
                self.placeholders['eg_litproduction_to_context_id'][literal_kind]
            eg_litproduction_choice_normalizer = \
                self.placeholders['eg_litproduction_choice_normalizer'][literal_kind]
        literal_logits[literal_kind] = \
            self.__make_literal_choice_logits_model(
                literal_kind,
                eg_litproduction_representation,
                self.ops.get('context_token_representations'),
                self.placeholders.get('context_token_mask'),
                eg_litproduction_to_context_id,
                eg_litproduction_choice_normalizer,
            )

    # ----- (5) Compute loss:
    raw_prod_loss = \
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=eg_production_choice_logits,
            labels=self.placeholders['eg_production_node_choices'])
    raw_var_loss = \
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=eg_varproduction_choice_logits,
            labels=self.placeholders['eg_varproduction_node_choices'])
    # Normalize each loss by the number of decisions actually made, which differs from
    # batch to batch. We can't use tf.reduce_mean because these can be empty, and
    # reduce_mean yields NaN in that case:
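    # (Adding SMALL_NUMBER keeps the divisions below finite when a batch contains no
    # decisions of the respective kind.)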
    prod_loss = tf.reduce_sum(raw_prod_loss) / (tf.cast(tf.size(raw_prod_loss), dtype=tf.float32) + SMALL_NUMBER)
    var_loss = tf.reduce_sum(raw_var_loss) / (tf.cast(tf.size(raw_var_loss), dtype=tf.float32) + SMALL_NUMBER)
    if len(LITERAL_NONTERMINALS) > 0:
        raw_lit_loss = [tf.nn.sparse_softmax_cross_entropy_with_logits(
                            logits=literal_logits[literal_kind],
                            labels=self.placeholders['eg_litproduction_node_choices'][literal_kind])
                        for literal_kind in LITERAL_NONTERMINALS]
        raw_lit_loss = tf.concat(raw_lit_loss, axis=0)
        lit_loss = tf.reduce_sum(raw_lit_loss) / (tf.cast(tf.size(raw_lit_loss), dtype=tf.float32) + SMALL_NUMBER)
    else:
        raw_lit_loss = [0.0]
        lit_loss = 0.0
    self.ops['loss'] = prod_loss + var_loss + lit_loss

    # TODO: If we want to batch this per sample, then we need an extra placeholder that
    # maps productions/variables to samples and use unsorted_segment_sum to gather
    # together all the logprobs from all productions.
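    # Summed (not averaged) negative cross-entropies, negated: the joint log-probability
    # of all ground-truth choices in the batch.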
    self.ops['log_probs'] = -tf.reduce_sum(raw_prod_loss) - tf.reduce_sum(raw_var_loss) - tf.reduce_sum(raw_lit_loss)