def __extend_minibatch_by_expansion_graph_train_from_sample()

in Models/exprsynth/nagdecoder.py [0:0]


    def __extend_minibatch_by_expansion_graph_train_from_sample(self, batch_data: Dict[str, Any],
                                                                sample: Dict[str, Any]) -> None:
        """Append one sample's expansion-graph training data to the minibatch in ``batch_data``.

        Every expansion-graph node id taken from ``sample`` is shifted by
        ``batch_data['eg_node_offset']`` before it is stored. ``batch_data`` is mutated in
        place; this method advances neither ``batch_data['eg_node_offset']`` nor
        ``batch_data['eg_prod_idx_offset']`` (presumably the caller does so after this
        method returns -- verify against the call site).

        Four groups of data are added, in this order:
          1. the per-step message passing schedule (sending / receiving nodes, edge labels),
          2. grammar production choices (optionally with context ids and variable last uses),
          3. variable choices, reordered so that the correct option is always at index 0,
          4. literal choices for each kind in ``LITERAL_NONTERMINALS``.

        Args:
            batch_data: Accumulator for the minibatch under construction.
                ``batch_data['samples_in_batch']`` must already count this sample.
            sample: Pre-tensorised data of a single sample.

        Raises:
            Exception: If an entry of ``sample['eg_node_id_to_varchoice']`` has no options.
        """
        this_sample_id = batch_data['samples_in_batch'] - 1  # Counter already incremented when we get called
        node_offset = batch_data['eg_node_offset']  # Not modified in this method, so read it once
        hypers = self.hyperparameters

        # ----- Data related to the message passing schedule:
        num_labeled_edge_types = len(self.__expansion_labeled_edge_types)
        total_edge_types = num_labeled_edge_types + len(self.__expansion_unlabeled_edge_types)
        for (step_num, schedule_step) in enumerate(sample['eg_schedule']):
            # Maps each sample-local node id receiving a message in this step to the index it
            # was assigned among the batch's receiving nodes of this step. Insertion order
            # matters: it fixes the order of 'eg_receiving_node_ids' below.
            eg_node_id_to_step_target_id = OrderedDict()
            for edge_type in range(total_edge_types):
                for (source, target) in schedule_step[edge_type]:
                    batch_data['eg_sending_node_ids'][step_num][edge_type].append(source + node_offset)
                    if target not in eg_node_id_to_step_target_id:
                        # First message to this node in this step: hand out the next free index
                        eg_node_id_to_step_target_id[target] = batch_data['next_step_target_node_id'][step_num]
                        batch_data['next_step_target_node_id'][step_num] += 1
                    batch_data['eg_msg_target_node_ids'][step_num][edge_type].append(eg_node_id_to_step_target_id[target])
            # Only the labeled edge types carry edge label ids:
            for edge_type in range(num_labeled_edge_types):
                batch_data['eg_edge_label_ids'][step_num][edge_type].extend(sample['eg_edge_label_ids'][step_num][edge_type])
            batch_data['eg_receiving_node_ids'][step_num].extend(
                eg_target_node_id + node_offset for eg_target_node_id in eg_node_id_to_step_target_id)
            batch_data['eg_receiving_node_nums'][step_num] += len(eg_node_id_to_step_target_id)

        # ----- Data related to the production choices:
        # Indexed as [:, 0] (eg node id) and [:, 1] (chosen production id):
        node_id_to_prod_id = sample['eg_node_id_to_prod_id']
        batch_data['eg_production_nodes'].extend(node_id_to_prod_id[:, 0] + node_offset)
        batch_data['eg_production_node_choices'].extend(node_id_to_prod_id[:, 1])
        if hypers['eg_use_context_attention']:
            batch_data['eg_production_to_context_id'].extend([this_sample_id] * node_id_to_prod_id.shape[0])

        if hypers['eg_use_vars_for_production_choice']:
            for (prod_index, prod_node_id) in enumerate(node_id_to_prod_id[:, 0]):
                var_last_uses_at_prod_node_id = sample['eg_production_node_id_to_var_last_use_node_ids'][prod_node_id]
                batch_data['eg_production_var_last_use_node_ids'].extend(var_last_uses_at_prod_node_id + node_offset)
                # Each last-use node is tagged with the batch-wide index of the production it belongs to:
                overall_prod_index = prod_index + batch_data['eg_prod_idx_offset']
                batch_data['eg_production_var_last_use_node_ids_target_ids'].extend([overall_prod_index] * len(var_last_uses_at_prod_node_id))

        # ----- Data related to the variable choices:
        for (eg_varproduction_node_id, candidate_node_ids, chosen_id) in sample['eg_node_id_to_varchoice']:
            # Check for an empty candidate set *before* indexing into it. Once the correct
            # node has been put into the options list, that list can never be empty, so a
            # check after that point would be dead code. len() is used rather than
            # truthiness so the check does not depend on the container type.
            if len(candidate_node_ids) == 0:
                raise Exception("Sample is choosing a variable from an empty set.")

            batch_data['eg_varproduction_nodes'].append(eg_varproduction_node_id + node_offset)

            # Restrict to number of choices that we want to allow, make sure we keep the correct one:
            max_variable_choices = hypers['eg_max_variable_choices']
            correct_node_id = candidate_node_ids[chosen_id]
            distractor_node_ids = candidate_node_ids[:chosen_id] + candidate_node_ids[chosen_id + 1:]
            np.random.shuffle(distractor_node_ids)  # In place; decides which distractors survive the cut
            option_node_ids = [correct_node_id, *distractor_node_ids[:max_variable_choices - 1]]

            # Pad up to the fixed width; padded slots hold node id 0 (plus offset) and are
            # marked with 0. in the mask.
            num_of_options = len(option_node_ids)
            num_padding = max_variable_choices - num_of_options
            padded_option_node_ids = np.array(option_node_ids + [0] * num_padding)
            batch_data['eg_varproduction_options_nodes'].append(padded_option_node_ids + node_offset)
            batch_data['eg_varproduction_options_mask'].append([1.] * num_of_options + [0.] * num_padding)
            batch_data['eg_varproduction_node_choices'].append(0)  # We've reordered so that the correct choice is always first

        # ----- Data related to the literal choices:
        for literal_kind in LITERAL_NONTERMINALS:
            # Shape [num_choice_nodes, 2], with (v, c) meaning that at eg node v, we want to choose literal c:
            literal_choices = sample['eg_node_id_to_literal_choice'][literal_kind]

            if hypers['eg_use_literal_copying']:
                # Prepare normalizer. We'll use an unsorted_segment_sum on the model side, and that only operates on a flattened shape
                # So here, we repeat the normalizer an appropriate number of times, but shifting by the number of choices
                normalizer_map = sample['eg_literal_choice_normalizer_maps'][literal_kind]
                normalizer_size = len(normalizer_map)
                # Must be computed before this sample's nodes are appended to 'eg_litproduction_nodes' below:
                num_choices_so_far = sum(choice_nodes.shape[0] for choice_nodes in batch_data['eg_litproduction_nodes'][literal_kind])
                num_choices_this_sample = literal_choices.shape[0]
                repeated_normalizer_map = np.tile(np.expand_dims(normalizer_map, axis=0),
                                                  reps=[num_choices_this_sample, 1])
                # Row i is shifted by (num_choices_so_far + i) * normalizer_size, so the entries
                # of different choice nodes fall into disjoint ranges of the flattened array:
                flattened_normalizer_offsets = np.repeat((np.arange(num_choices_this_sample) + num_choices_so_far) * normalizer_size,
                                                         repeats=normalizer_size)
                normalizer_offsets = np.reshape(flattened_normalizer_offsets, [-1, normalizer_size])
                batch_data['eg_litproduction_choice_normalizer'][literal_kind].append(
                    np.reshape(repeated_normalizer_map + normalizer_offsets, -1))

                batch_data['eg_litproduction_to_context_id'][literal_kind].append([this_sample_id] * num_choices_this_sample)

            batch_data['eg_litproduction_nodes'][literal_kind].append(literal_choices[:, 0] + node_offset)
            batch_data['eg_litproduction_node_choices'][literal_kind].append(literal_choices[:, 1])