def generate_suggestions_for_one_sample()

in Models/exprsynth/nagdecoder.py


    def generate_suggestions_for_one_sample(self,
                                            test_sample: Dict[str, Any],
                                            initial_eg_node_representations: tf.Tensor,
                                            beam_size: int=3,
                                            max_decoding_steps: int=100,
                                            context_tokens: Optional[List[str]]=None,
                                            context_token_representations: Optional[tf.Tensor]=None,
                                            context_token_mask: Optional[np.ndarray]=None,
                                            ) -> ModelTestResult:
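        """
        Beam-search decode an expression for a single test sample, starting from the
        precomputed expansion-graph node representations. Candidate partial trees are
        repeatedly expanded (by grammar productions, variable choices and literal
        choices) and pruned to the beam_size most probable ones; a beam is dropped
        once it exceeds max_decoding_steps. Returns a ModelTestResult pairing the
        ground-truth tokens with the predicted token sequences and their probabilities.
        """
        # Invert the production vocabulary so that sampled production ids can be
        # mapped back to (nonterminal, expansion) rules: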
        production_id_to_production = {}  # type: Dict[int, Tuple[str, Iterable[str]]]
        for (nonterminal, nonterminal_rules) in self.metadata['eg_production_vocab'].items():
            for (expansion, prod_id) in nonterminal_rules.items():
                production_id_to_production[prod_id] = (nonterminal, expansion)

        max_used_eg_node_id = initial_eg_node_representations.shape[0]

        def declare_new_node(expansion_info: ExpansionInformation, parent_node: int, node_type: str) -> int:
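            """
            Allocate ids for a new child of parent_node: one node carrying the
            inherited attributes and a paired node for its synthesised attributes.
            Records the parent/child links, type and labels in expansion_info and
            returns the new (inherited) node id.
            """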
            nonlocal max_used_eg_node_id
            new_node_id = max_used_eg_node_id
            new_synthesised_node_id = max_used_eg_node_id + 1
            max_used_eg_node_id += 2
            expansion_info.node_to_parent[new_node_id] = parent_node
            expansion_info.node_to_children[parent_node].append(new_node_id)
            expansion_info.node_to_type[new_node_id] = node_type
            expansion_info.node_to_label[new_node_id] = node_type
            expansion_info.node_to_label[new_synthesised_node_id] = node_type
            expansion_info.node_to_synthesised_attr_node[new_node_id] = new_synthesised_node_id
            expansion_info.node_to_inherited_attr_node[new_synthesised_node_id] = new_node_id
            return new_node_id

        def get_node_attributes(expansion_info: ExpansionInformation, node_id: int) -> tf.Tensor:
            """
            Return the attribute representation associated with node_id, either from
            the cache in expansion_info or by calling the model to compute it from the
            node's incoming edges, recursively computing source representations as needed.
            """
            node_attributes = expansion_info.node_to_representation.get(node_id)
            if node_attributes is None:
                node_label = expansion_info.node_to_label[node_id]
                if node_label == VARIABLE_NONTERMINAL:
                    node_label = ROOT_NONTERMINAL
                if node_id not in expansion_info.node_to_unlabeled_incoming_edges:
                    self.compute_incoming_edges(self.metadata['eg_production_vocab'].keys(), expansion_info, node_id)
                msg_prop_data = {self.placeholders['eg_msg_target_label_id']: self.metadata['eg_token_vocab'].get_id_or_unk(node_label)}
                for labeled_edge_typ in self.__expansion_labeled_edge_types.keys():
                    source_node_ids = [v for (v, _) in expansion_info.node_to_labeled_incoming_edges[node_id][labeled_edge_typ]]
                    edge_labels = [l for (_, l) in expansion_info.node_to_labeled_incoming_edges[node_id][labeled_edge_typ]]
                    if len(source_node_ids) == 0:
                        sender_repr = np.empty(shape=[0, self.hyperparameters['eg_hidden_size']])
                    else:
                        sender_repr = [get_node_attributes(expansion_info, source_node_id) for source_node_id in source_node_ids]
                    labeled_edge_typ_idx = self.__expansion_labeled_edge_types[labeled_edge_typ]
                    msg_prop_data[self.placeholders['eg_msg_source_representations'][labeled_edge_typ_idx]] = sender_repr
                    msg_prop_data[self.placeholders['eg_edge_label_ids'][0][labeled_edge_typ_idx]] = edge_labels
                for unlabeled_edge_typ in self.__expansion_unlabeled_edge_types.keys():
                    source_node_ids = expansion_info.node_to_unlabeled_incoming_edges[node_id][unlabeled_edge_typ]
                    if len(source_node_ids) == 0:
                        sender_repr = np.empty(shape=[0, self.hyperparameters['eg_hidden_size']])
                    else:
                        sender_repr = [get_node_attributes(expansion_info, source_node_id) for source_node_id in source_node_ids]
                    shifted_edge_type_id = len(self.__expansion_labeled_edge_types) + self.__expansion_unlabeled_edge_types[unlabeled_edge_typ]
                    msg_prop_data[self.placeholders['eg_msg_source_representations'][shifted_edge_type_id]] = sender_repr
                # print("Computing attributes for %i (label %s) with following edges:" % (node_id, node_label))
                # for labeled_edge_type in EXPANSION_LABELED_EDGE_TYPE_NAMES + EXPANSION_UNLABELED_EDGE_TYPE_NAMES:
                #     edges = expansion_info.node_to_labeled_incoming_edges[node_id][labeled_edge_type]
                #     if len(edges) > 0:
                #         print(" %s edges: [%s]" % (labeled_edge_type,
                #                                    ", ".join("(%s, %s)" % (w, node_id) for w in edges)))

                node_attributes = self.sess.run(self.ops['eg_step_propagation_result'], feed_dict=msg_prop_data)
                expansion_info.node_to_representation[node_id] = node_attributes
            return node_attributes

        def sample_productions(expansion_info: ExpansionInformation, node_to_expand: int) -> List[Tuple[Iterable[str], int, float]]:
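            """
            Query the model for production probabilities at node_to_expand and return
            up to beam_size candidates as (rhs symbols, production id, probability)
            triples, keeping only productions whose left-hand side matches the node's
            nonterminal type.
            """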
            prod_query_data = {}
            write_to_minibatch(prod_query_data,
                               self.placeholders['eg_production_node_representation'],
                               get_node_attributes(expansion_info, node_to_expand))
            if self.hyperparameters['eg_use_vars_for_production_choice']:
                vars_in_scope = list(expansion_info.variable_to_last_use_id.keys())
                vars_in_scope.remove(LAST_USED_TOKEN_NAME)
                vars_in_scope_representations = [get_node_attributes(expansion_info, expansion_info.variable_to_last_use_id[var])
                                                 for var in vars_in_scope]
                write_to_minibatch(prod_query_data,
                                   self.placeholders['eg_production_var_representations'],
                                   vars_in_scope_representations)
            if self.hyperparameters['eg_use_literal_copying'] or self.hyperparameters['eg_use_context_attention']:
                write_to_minibatch(prod_query_data,
                                   self.placeholders['context_token_representations'],
                                   expansion_info.context_token_representations)
                write_to_minibatch(prod_query_data,
                                   self.placeholders['context_token_mask'],
                                   expansion_info.context_token_mask)
            production_probs = self.sess.run(self.ops['eg_production_choice_probs'], feed_dict=prod_query_data)
            result = []
            # print("### Prod probs: %s" % (str(production_probs),))
            for picked_production_index in pick_indices_from_probs(production_probs, beam_size):
                prod_lhs, prod_rhs = production_id_to_production[picked_production_index]
                # TODO: This should be ensured by appropriate masking in the model
                if prod_lhs == expansion_info.node_to_type[node_to_expand]:
                    result.append((prod_rhs, picked_production_index, production_probs[picked_production_index]))

            return result

        def sample_variable(expansion_info: ExpansionInformation, node_id: int) -> List[Tuple[str, float]]:
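            """
            Score the variables currently in scope as choices for the given
            VARIABLE_NONTERMINAL node, conditioned on the parent node's representation,
            and return up to beam_size (variable name, probability) pairs.
            """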
            vars_in_scope = list(expansion_info.variable_to_last_use_id.keys())
            vars_in_scope.remove(LAST_USED_TOKEN_NAME)
            vars_in_scope_representations = [get_node_attributes(expansion_info, expansion_info.variable_to_last_use_id[var])
                                             for var in vars_in_scope]
            var_query_data = {self.placeholders['eg_num_variable_choices']: len(vars_in_scope)}
            write_to_minibatch(var_query_data, self.placeholders['eg_varproduction_options_representations'], vars_in_scope_representations)
            # We choose the variable name based on the information of the /parent/ node:
            parent_node = expansion_info.node_to_parent[node_id]
            write_to_minibatch(var_query_data, self.placeholders['eg_varproduction_node_representation'], get_node_attributes(expansion_info, parent_node))
            var_probs = self.sess.run(self.ops['eg_varproduction_choice_probs'], feed_dict=var_query_data)

            result = []
            # print("### Var probs: %s" % (str(var_probs),))
            for picked_var_index in pick_indices_from_probs(var_probs, beam_size):
                result.append((vars_in_scope[picked_var_index], var_probs[picked_var_index]))
            return result

        def sample_literal(expansion_info: ExpansionInformation, node_id: int) -> List[Tuple[str, float]]:
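            """
            Score the literal choices for the node's literal kind (vocabulary entries
            plus, when literal copying is enabled, copyable context tokens) and return
            up to beam_size (literal, probability) pairs.
            """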
            literal_kind_to_sample = expansion_info.node_to_type[node_id]
            lit_query_data = {}

            # We choose the literal based on the information of the /parent/ node:
            parent_node = expansion_info.node_to_parent[node_id]
            write_to_minibatch(lit_query_data, self.placeholders['eg_litproduction_node_representation'], get_node_attributes(expansion_info, parent_node))

            if self.hyperparameters["eg_use_literal_copying"]:
                write_to_minibatch(lit_query_data,
                                   self.placeholders['context_token_representations'],
                                   expansion_info.context_token_representations)
                write_to_minibatch(lit_query_data,
                                   self.placeholders['context_token_mask'],
                                   expansion_info.context_token_mask)
                write_to_minibatch(lit_query_data,
                                   self.placeholders['eg_litproduction_choice_normalizer'],
                                   expansion_info.literal_production_choice_normalizer[literal_kind_to_sample])
            lit_probs = self.sess.run(self.ops['eg_litproduction_choice_probs'][literal_kind_to_sample],
                                      feed_dict=lit_query_data)

            result = []
            # print("### Lit probs: %s" % (str(lit_probs),))
            literal_vocab = self.metadata["eg_literal_vocabs"][literal_kind_to_sample]
            literal_vocab_size = len(literal_vocab)
            for picked_lit_index in pick_indices_from_probs(lit_probs, beam_size):
                if picked_lit_index < literal_vocab_size:
                    result.append((literal_vocab.id_to_token[picked_lit_index], lit_probs[picked_lit_index]))
                else:
                    result.append((expansion_info.context_tokens[picked_lit_index - literal_vocab_size], lit_probs[picked_lit_index]))
            return result

        def expand_node(expansion_info: ExpansionInformation) -> List[ExpansionInformation]:
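            """
            Expand the next node in expansion_info's queue and return the resulting
            candidate expansions: one per sampled production, variable or literal; the
            unchanged input if the node is a terminal or nothing is left to expand;
            or an empty list once max_decoding_steps is exceeded.
            """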
            if len(expansion_info.nodes_to_expand) == 0:
                return [expansion_info]
            if expansion_info.num_expansions > max_decoding_steps:
                return []

            node_to_expand = expansion_info.nodes_to_expand.popleft()
            type_to_expand = expansion_info.node_to_type[node_to_expand]
            expansions = []

            if type_to_expand in self.metadata['eg_production_vocab']:
                # Case production from grammar
                for (prod_rhs, prod_id, prod_probability) in sample_productions(expansion_info, node_to_expand):
                    picked_rhs_expansion_info = clone_expansion_info(expansion_info, increment_expansion_counter=True)
                    picked_rhs_expansion_info.node_to_prod_id[node_to_expand] = prod_id
                    picked_rhs_expansion_info.expansion_logprob[0] = picked_rhs_expansion_info.expansion_logprob[0] + np.log(prod_probability)
                    # print("Expanding %i using rule %s -> %s with prob %.3f in %s (tree prob %.3f)."
                    #       % (node_to_expand, type_to_expand, prod_rhs, prod_probability,
                    #          " ".join(get_tokens_from_expansion(expansion_info, root_node)),
                    #          np.exp(expansion_info.expansion_logprob[0])))
                    # Declare all children
                    for child_node_type in prod_rhs:
                        child_node_id = declare_new_node(picked_rhs_expansion_info, node_to_expand, child_node_type)
                        # print("  Child %i (type %s)" % (child_node_id, child_node_type))
                    # Queue the children for expansion. We expand depth-first, so they go to
                    # the front of the queue; as extendleft reverses its argument and we want
                    # left-to-right order, pre-reverse the children:
                    picked_rhs_expansion_info.nodes_to_expand.extendleft(reversed(picked_rhs_expansion_info.node_to_children[node_to_expand]))
                    expansions.append(picked_rhs_expansion_info)
            elif type_to_expand == VARIABLE_NONTERMINAL:
                # Case choose variable name.
                if len(expansion_info.variable_to_last_use_id.keys()) > 1:  # Only continue if at least one var is in scope (not just LAST_USED_TOKEN_NAME)
                    for (child_label, var_probability) in sample_variable(expansion_info, node_to_expand):
                        # print("Expanding %i by using variable %s with prob %.3f in %s (tree prob %.3f)."
                        #       % (node_to_expand, child_label, var_probability,
                        #          " ".join(get_tokens_from_expansion(expansion_info, root_node)),
                        #          np.exp(expansion_info.expansion_logprob[0])))
                        child_expansion_info = clone_expansion_info(expansion_info)
                        child_expansion_info.node_to_synthesised_attr_node[node_to_expand] = node_to_expand  # synthesised and inherited are the same for leaves
                        child_expansion_info.node_to_label[node_to_expand] = child_label
                        self.compute_incoming_edges(self.metadata['eg_production_vocab'].keys(), child_expansion_info, node_to_expand)  # This needs to be done now before we update the variable-to-last-use info
                        child_expansion_info.expansion_logprob[0] = child_expansion_info.expansion_logprob[0] + np.log(var_probability)
                        if self.hyperparameters['eg_update_last_variable_use_representation']:
                            child_expansion_info.variable_to_last_use_id[child_label] = node_to_expand
                        child_expansion_info.variable_to_last_use_id[LAST_USED_TOKEN_NAME] = node_to_expand
                        expansions.append(child_expansion_info)
            elif type_to_expand in LITERAL_NONTERMINALS:
                for (picked_literal, literal_probability) in sample_literal(expansion_info, node_to_expand):
                    # print("Expanding %i by using literal %s with prob %.3f in %s (tree prob %.3f)."
                    #       % (node_to_expand, picked_literal, literal_probability,
                    #          " ".join(get_tokens_from_expansion(expansion_info, root_node)),
                    #          np.exp(expansion_info.expansion_logprob[0])))
                    picked_literal_expansion_info = clone_expansion_info(expansion_info)
                    picked_literal_expansion_info.node_to_synthesised_attr_node[node_to_expand] = node_to_expand  # synthesised and inherited are the same for leaves
                    picked_literal_expansion_info.node_to_label[node_to_expand] = picked_literal
                    self.compute_incoming_edges(self.metadata['eg_production_vocab'].keys(), picked_literal_expansion_info,
                                                node_to_expand)
                    picked_literal_expansion_info.expansion_logprob[0] = picked_literal_expansion_info.expansion_logprob[0] + np.log(literal_probability)
                    picked_literal_expansion_info.variable_to_last_use_id[LAST_USED_TOKEN_NAME] = node_to_expand
                    expansions.append(picked_literal_expansion_info)
            else:
                # Case node is a terminal: Do nothing
                # print("Handling leaf node %i (label %s)" % (node_to_expand, expansion_info.node_to_label[node_to_expand]))
                expansion_info.node_to_synthesised_attr_node[node_to_expand] = node_to_expand  # synthesised and inherited are the same for leaves
                self.compute_incoming_edges(self.metadata['eg_production_vocab'].keys(), expansion_info, node_to_expand)  # This needs to be done now before we update the variable-to-last-use info
                expansion_info.variable_to_last_use_id[LAST_USED_TOKEN_NAME] = node_to_expand
                expansions = [expansion_info]

            return expansions

        if self.hyperparameters['eg_use_literal_copying']:
            literal_production_choice_normalizer = {}
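            # A token may occur both in the literal vocabulary and among the copyable
            # context tokens; the normalizer maps each choice index to the first
            # occurrence of its token so the model can merge duplicates' probability mass.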
            for literal_kind in LITERAL_NONTERMINALS:
                # Collect all choices (vocab + things we can copy from):
                literal_vocab = self.metadata['eg_literal_vocabs'][literal_kind]
                literal_choices = literal_vocab.id_to_token + context_tokens
                first_tok_occurrences = {}
                num_choices = len(literal_vocab) + self.hyperparameters['eg_max_context_tokens']
                normalizer_map = np.arange(num_choices, dtype=np.int16)
                for (token_idx, token) in enumerate(literal_choices):
                    first_occ = first_tok_occurrences.get(token)
                    if first_occ is not None:
                        normalizer_map[token_idx] = first_occ
                    else:
                        first_tok_occurrences[token] = token_idx
                literal_production_choice_normalizer[literal_kind] = normalizer_map
        else:
            literal_production_choice_normalizer = None

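        # Seed the expansion with the sample's root node and the last-use node of each
        # variable in scope, reusing their precomputed representations.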
        root_node = test_sample['eg_root_node']
        initial_variable_to_last_use_id = test_sample['eg_variable_eg_node_ids']
        initial_variable_to_last_use_id[LAST_USED_TOKEN_NAME] = test_sample['eg_last_token_eg_node_id']
        initial_node_to_representation = {node_id: initial_eg_node_representations[node_id]
                                          for node_id in initial_variable_to_last_use_id.values()}
        initial_node_to_representation[root_node] = initial_eg_node_representations[root_node]

        initial_info = ExpansionInformation(node_to_type={root_node: ROOT_NONTERMINAL},
                                            node_to_label={root_node: ROOT_NONTERMINAL},
                                            node_to_prod_id={},
                                            node_to_children=defaultdict(list),
                                            node_to_parent={},
                                            node_to_synthesised_attr_node={node_id: node_id for node_id in initial_node_to_representation.keys()},
                                            node_to_inherited_attr_node={},
                                            variable_to_last_use_id=initial_variable_to_last_use_id,
                                            node_to_representation=initial_node_to_representation,
                                            node_to_labeled_incoming_edges={root_node: defaultdict(list)},
                                            node_to_unlabeled_incoming_edges={root_node: defaultdict(list)},
                                            context_token_representations=context_token_representations,
                                            context_token_mask=context_token_mask,
                                            context_tokens=context_tokens,
                                            literal_production_choice_normalizer=literal_production_choice_normalizer,
                                            nodes_to_expand=deque([root_node]),
                                            expansion_logprob=[0.0],
                                            num_expansions=0)

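        # Beam search: expand the frontier node of every candidate tree, then keep the
        # beam_size most probable partial trees, until no beam has pending nodes.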
        beams = [initial_info]
        while any(len(b.nodes_to_expand) > 0 for b in beams):
            new_beams = [new_beam
                         for beam in beams
                         for new_beam in expand_node(beam)]
            beams = sorted(new_beams, key=lambda b: -b.expansion_logprob[0])[:beam_size]  # Pick top K beams

        self.test_log("Groundtruth: %s" % (" ".join(test_sample['eg_tokens']),))

        all_predictions = []  # type: List[Tuple[List[str], float]]
        for (k, beam_info) in enumerate(beams):
            kth_result = get_tokens_from_expansion(beam_info, root_node)
            all_predictions.append((kth_result, np.exp(beam_info.expansion_logprob[0])))
            self.test_log("  @%i Prob. %.3f: %s" % (k+1, np.exp(beam_info.expansion_logprob[0]), " ".join(kth_result)))

        if len(beams) == 0:
            self.test_log("No beams finished!")

        return ModelTestResult(test_sample['eg_tokens'], all_predictions)
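
Two self-contained sketches of the trickier details above, runnable with plain
Python/NumPy on made-up data (the vocabularies and node ids below are illustrative,
not taken from the repository):

    import numpy as np

    # Sketch of the literal-copy normalizer: a token appearing both in the literal
    # vocabulary and among the copyable context tokens has its later choice index
    # redirected to the first occurrence, so duplicate probability mass can be merged.
    literal_vocab = ["0", "x", "None"]        # stand-in for an eg_literal_vocabs entry
    context_tokens = ["y", "x", "z"]          # stand-in for copyable context tokens
    choices = literal_vocab + context_tokens

    normalizer_map = np.arange(len(choices), dtype=np.int16)
    first_occurrences = {}
    for (token_idx, token) in enumerate(choices):
        first_occ = first_occurrences.get(token)
        if first_occ is not None:
            normalizer_map[token_idx] = first_occ  # duplicate -> first occurrence
        else:
            first_occurrences[token] = token_idx

    print(normalizer_map)  # [0 1 2 3 1 5]: "x" at index 4 maps back to index 1

The depth-first queueing in expand_node relies on deque.extendleft reversing its
argument; pre-reversing the children restores left-to-right order at the front of
the queue:

    from collections import deque

    nodes_to_expand = deque([7, 8])  # previously queued siblings
    children = [3, 4, 5]             # children of the node just expanded
    nodes_to_expand.extendleft(reversed(children))
    print(nodes_to_expand)           # deque([3, 4, 5, 7, 8])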