in Models/exprsynth/nagdecoder.py [0:0]
def __extend_minibatch_by_expansion_graph_train_from_sample(self, batch_data: Dict[str, Any],
                                                            sample: Dict[str, Any]) -> None:
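    """Append one sample's expansion graph training data to the current minibatch.

    This covers the per-step message-passing schedule as well as the production,
    variable and literal choices, with all node ids shifted by the batch-wide
    'eg_node_offset'.
    """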
    this_sample_id = batch_data['samples_in_batch'] - 1  # Counter already incremented when we get called
    total_edge_types = len(self.__expansion_labeled_edge_types) + len(self.__expansion_unlabeled_edge_types)
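
    # ----- Data related to the message-passing schedule, one entry per propagation step: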
    for (step_num, schedule_step) in enumerate(sample['eg_schedule']):
        eg_node_id_to_step_target_id = OrderedDict()
        for edge_type in range(total_edge_types):
            for (source, target) in schedule_step[edge_type]:
                batch_data['eg_sending_node_ids'][step_num][edge_type].append(source + batch_data['eg_node_offset'])
                step_target_id = eg_node_id_to_step_target_id.get(target)
                if step_target_id is None:
                    step_target_id = batch_data['next_step_target_node_id'][step_num]
                    batch_data['next_step_target_node_id'][step_num] += 1
                    eg_node_id_to_step_target_id[target] = step_target_id
                batch_data['eg_msg_target_node_ids'][step_num][edge_type].append(step_target_id)
        for edge_type in range(len(self.__expansion_labeled_edge_types)):
            batch_data['eg_edge_label_ids'][step_num][edge_type].extend(sample['eg_edge_label_ids'][step_num][edge_type])
        for eg_target_node_id in eg_node_id_to_step_target_id.keys():
            batch_data['eg_receiving_node_ids'][step_num].append(eg_target_node_id + batch_data['eg_node_offset'])
        batch_data['eg_receiving_node_nums'][step_num] += len(eg_node_id_to_step_target_id)

    # ----- Data related to the production choices:
    batch_data['eg_production_nodes'].extend(sample['eg_node_id_to_prod_id'][:, 0] + batch_data['eg_node_offset'])
    batch_data['eg_production_node_choices'].extend(sample['eg_node_id_to_prod_id'][:, 1])
    if self.hyperparameters['eg_use_context_attention']:
        batch_data['eg_production_to_context_id'].extend([this_sample_id] * sample['eg_node_id_to_prod_id'].shape[0])

    if self.hyperparameters['eg_use_vars_for_production_choice']:
        for (prod_index, prod_node_id) in enumerate(sample['eg_node_id_to_prod_id'][:, 0]):
            var_last_uses_at_prod_node_id = sample['eg_production_node_id_to_var_last_use_node_ids'][prod_node_id]
            batch_data['eg_production_var_last_use_node_ids'].extend(var_last_uses_at_prod_node_id + batch_data['eg_node_offset'])
            overall_prod_index = prod_index + batch_data['eg_prod_idx_offset']
            batch_data['eg_production_var_last_use_node_ids_target_ids'].extend([overall_prod_index] * len(var_last_uses_at_prod_node_id))
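
    # ----- Data related to variable choices:
    # For each variable-choosing node, the correct variable's node comes first, followed by
    # up to eg_max_variable_choices - 1 shuffled distractors; the remaining slots are padded
    # and masked out.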
    for (eg_varproduction_node_id, eg_varproduction_options_node_ids, chosen_id) in sample['eg_node_id_to_varchoice']:
        batch_data['eg_varproduction_nodes'].append(eg_varproduction_node_id + batch_data['eg_node_offset'])
        # Restrict to the number of choices that we want to allow, making sure we keep the correct one:
        eg_varproduction_correct_node_id = eg_varproduction_options_node_ids[chosen_id]
        eg_varproduction_distractor_node_ids = eg_varproduction_options_node_ids[:chosen_id] + eg_varproduction_options_node_ids[chosen_id + 1:]
        np.random.shuffle(eg_varproduction_distractor_node_ids)
        eg_varproduction_options_node_ids = [eg_varproduction_correct_node_id]
        eg_varproduction_options_node_ids.extend(eg_varproduction_distractor_node_ids[:self.hyperparameters['eg_max_variable_choices'] - 1])
        num_of_options = len(eg_varproduction_options_node_ids)
        if num_of_options == 0:
            raise Exception("Sample is choosing a variable from an empty set.")
        num_padding = self.hyperparameters['eg_max_variable_choices'] - num_of_options
        eg_varproduction_options_mask = [1.] * num_of_options + [0.] * num_padding
        eg_varproduction_options_node_ids = np.array(eg_varproduction_options_node_ids + [0] * num_padding)
        batch_data['eg_varproduction_options_nodes'].append(eg_varproduction_options_node_ids + batch_data['eg_node_offset'])
        batch_data['eg_varproduction_options_mask'].append(eg_varproduction_options_mask)
        batch_data['eg_varproduction_node_choices'].append(0)  # We've reordered so that the correct choice is always first
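
    # ----- Data related to literal choices:
    # For each literal kind, record the choice nodes and the chosen literal ids; with copying
    # enabled, also prepare per-choice segment ids for the copy-logit normalizer.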
    for literal_kind in LITERAL_NONTERMINALS:
        # Shape [num_choice_nodes, 2], with (v, c) meaning that at eg node v, we want to choose literal c:
        literal_choices = sample['eg_node_id_to_literal_choice'][literal_kind]
        if self.hyperparameters['eg_use_literal_copying']:
            # Prepare the normalizer. We'll use an unsorted_segment_sum on the model side, which only
            # operates on a flattened shape, so we repeat the normalizer map once per choice here and
            # shift each copy into its own block of segment ids:
            normalizer_map = sample['eg_literal_choice_normalizer_maps'][literal_kind]
            num_choices_so_far = sum(choice_nodes.shape[0] for choice_nodes in batch_data['eg_litproduction_nodes'][literal_kind])
            num_choices_this_sample = literal_choices.shape[0]
            repeated_normalizer_map = np.tile(np.expand_dims(normalizer_map, axis=0),
                                              reps=[num_choices_this_sample, 1])
            flattened_normalizer_offsets = np.repeat((np.arange(num_choices_this_sample) + num_choices_so_far) * len(normalizer_map),
                                                     repeats=len(normalizer_map))
            normalizer_offsets = np.reshape(flattened_normalizer_offsets, [-1, len(normalizer_map)])
            batch_data['eg_litproduction_choice_normalizer'][literal_kind].append(
                np.reshape(repeated_normalizer_map + normalizer_offsets, -1))
            batch_data['eg_litproduction_to_context_id'][literal_kind].append([this_sample_id] * literal_choices.shape[0])
        batch_data['eg_litproduction_nodes'][literal_kind].append(literal_choices[:, 0] + batch_data['eg_node_offset'])
        batch_data['eg_litproduction_node_choices'][literal_kind].append(literal_choices[:, 1])
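
The literal-copying branch above packs, for every literal choice in the batch, one block of
segment ids that the model later feeds into an unsorted_segment_sum. Below is a minimal sketch
of that index arithmetic with made-up values; the variable names and numbers are illustrative,
not taken from the repository:

import numpy as np

normalizer_map = np.array([0, 1, 1, 3])   # token 2 is a duplicate of token 1
num_choices_so_far = 2                    # literal choices already in the batch
num_choices_this_sample = 2
map_len = len(normalizer_map)

repeated = np.tile(np.expand_dims(normalizer_map, axis=0),
                   reps=[num_choices_this_sample, 1])
offsets = np.reshape(
    np.repeat((np.arange(num_choices_this_sample) + num_choices_so_far) * map_len,
              repeats=map_len),
    [-1, map_len])
flat_normalizer = np.reshape(repeated + offsets, -1)
print(flat_normalizer)
# [ 8  9  9 11 12 13 13 15] -- each choice gets its own block of map_len segment ids,
# so duplicate context tokens are merged per choice without mixing different choices.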