def __load_expansiongraph_training_data_from_sample()

in Models/exprsynth/nagdecoder.py [0:0]


    def __load_expansiongraph_training_data_from_sample(
            hyperparameters: Dict[str, Any], metadata: Dict[str, Any],
            raw_sample: Dict[str, Any], prod_root_node: int, node_to_inherited_id: Dict[int, int],
            node_to_synthesised_id: Dict[int, int], last_used_node_id: Dict[str, int],
            result_holder: Dict[str, Any]) -> None:
        """Build expansion-graph training data for the subtree rooted at `prod_root_node`.

        Walks the productions of `raw_sample` depth-first, left to right. Along the way it:
        - allocates graph nodes: one inherited (downwards) node per child symbol, plus one
          synthesised (upwards) node per symbol that has a production; leaves reuse their
          inherited node as their synthesised node;
        - records the production, variable and literal choices to train on;
        - builds the per-step propagation schedule.

        The walk is recursive, so very deep trees can hit Python's recursion limit.

        Args:
            hyperparameters: Reads 'eg_use_literal_copying', 'eg_max_context_tokens',
                'eg_use_vars_for_production_choice' and
                'eg_update_last_variable_use_representation'. Also passed to
                `get_restricted_edge_types`.
            metadata: Reads 'eg_literal_vocabs', 'eg_token_vocab', 'eg_production_vocab'
                and 'eg_edge_label_vocab'.
            raw_sample: Reads 'SymbolKinds', 'Productions', 'SymbolLabels' and
                'LastUseOfVariablesInScope'. Symbol ids are looked up as string keys.
            prod_root_node: Symbol id of the subtree root. It must already have an entry
                in `node_to_inherited_id`.
            node_to_inherited_id: Symbol id -> graph node id of its inherited node.
                Updated in place.
            node_to_synthesised_id: Symbol id -> graph node id of its synthesised node.
                Updated in place.
            last_used_node_id: Maps each variable name, and LAST_USED_TOKEN_NAME, to the
                symbol id of its most recent use. Updated in place as tokens and
                variables are generated.
            result_holder: Must contain 'eg_node_labels', which is appended to. When
                literal copying is enabled it must also contain
                'context_nonkeyword_tokens'. Results are written to:
                - 'eg_node_id_to_prod_id'
                - 'eg_node_id_to_varchoice'
                - 'eg_node_id_to_literal_choice'
                - 'eg_schedule'
                - 'eg_edge_label_ids'
                - depending on hyperparameters, 'eg_literal_choice_normalizer_maps'
                  and 'eg_production_node_id_to_var_last_use_node_ids'.

        Raises:
            MissingProductionException: If a production in the sample is not in
                metadata['eg_production_vocab'] for its left-hand-side kind.
        """
        # Shortcuts and temp data we use during construction:
        symbol_to_kind = raw_sample['SymbolKinds']  # type: Dict[str, str]
        symbol_to_prod = raw_sample['Productions']  # type: Dict[str, List[int]]
        symbol_to_label = raw_sample['SymbolLabels']  # type: Dict[str, str]
        # Sorted so that a variable's position in this list, which is used as its choice
        # index, is deterministic:
        variables_in_scope = sorted(raw_sample['LastUseOfVariablesInScope'].keys())  # type: List[str]

        # These are the things we'll use in the end:
        eg_node_id_to_prod_id = []  # type: List[Tuple[int, int]]  # Pairs of (node id, chosen production id)
        eg_node_id_to_varchoice = []  # type: List[Tuple[int, List[int], int]]  # Triples of (node id, [id of last var use], index of correct var)
        eg_node_id_to_literal_choice = defaultdict(list)  # type: Dict[str, List[Tuple[int, int]]]  # Maps literal kind to pairs of (node id, chosen literal id)
        eg_prod_node_to_var_last_uses = {}  # type: Dict[int, np.ndarray]  # Dict from production node id to [id of last var use]
        eg_schedule = []  # type: List[Dict[str, List[Tuple[int, int, Optional[int]]]]]  # edge type name to edges of that type in each expansion step.
        # NOTE(review): node ids, production ids and choice indices are stored as (u)int16
        # below. This assumes every sample and vocabulary stays within that range; confirm
        # for large graphs or vocabularies.

        # We will use these to pick literals (literal kind -> token -> choice index):
        eg_literal_tok_to_idx = {}
        if hyperparameters['eg_use_literal_copying']:
            eg_literal_choice_normalizer_maps = {}
            # For each literal kind, we compute a "normalizer map", which we use to identify
            # choices that correspond to the same literal (e.g., when using a literal several
            # times in context)
            for literal_kind in LITERAL_NONTERMINALS:
                # Collect all choices (vocab + things we can copy from):
                literal_vocab = metadata['eg_literal_vocabs'][literal_kind]
                literal_choices = \
                    literal_vocab.id_to_token \
                    + result_holder['context_nonkeyword_tokens'][:hyperparameters['eg_max_context_tokens']]
                first_tok_occurrences = {}
                # Sized for a full context window, independent of how many context tokens
                # this sample actually has:
                num_choices = hyperparameters['eg_max_context_tokens'] + len(literal_vocab)
                normalizer_map = np.arange(num_choices, dtype=np.int16)
                for (token_idx, token) in enumerate(literal_choices):
                    first_occ = first_tok_occurrences.get(token)
                    if first_occ is not None:
                        # Duplicate choice: redirect it to the first index holding the same token.
                        normalizer_map[token_idx] = first_occ
                    else:
                        first_tok_occurrences[token] = token_idx
                eg_literal_tok_to_idx[literal_kind] = first_tok_occurrences
                eg_literal_choice_normalizer_maps[literal_kind] = normalizer_map
            result_holder['eg_literal_choice_normalizer_maps'] = eg_literal_choice_normalizer_maps
        else:
            for literal_kind in LITERAL_NONTERMINALS:
                eg_literal_tok_to_idx[literal_kind] = metadata['eg_literal_vocabs'][literal_kind].token_to_id

        # Map prods onto internal numbering and compute propagation schedule:
        def declare_new_node(sym_exp_node_id: int, is_synthesised: bool) -> int:
            """Allocate the next graph node id for a symbol, record its label and return the id."""
            new_node_id = len(result_holder['eg_node_labels'])
            # Fall back to the symbol's kind when it has no (or an empty) label:
            node_label = symbol_to_label.get(str(sym_exp_node_id)) or symbol_to_kind[str(sym_exp_node_id)]
            result_holder['eg_node_labels'].append(metadata['eg_token_vocab'].get_id_or_unk(node_label))
            if is_synthesised:
                node_to_synthesised_id[sym_exp_node_id] = new_node_id
            else:
                node_to_inherited_id[sym_exp_node_id] = new_node_id

            return new_node_id

        def expand_node(node_id: int) -> None:
            """Expand symbol `node_id` recursively.

            Records its production choice and allocates its children. For each child, this
            appends the child's incoming-edge step followed by the steps of the child's
            subtree, and finally appends the step that computes this node's synthesised node.
            """
            rhs_node_ids = symbol_to_prod.get(str(node_id))
            if rhs_node_ids is None:
                # In the case that we have no children, the downwards and upwards version of our node are the same:
                node_to_synthesised_id[node_id] = node_to_inherited_id[node_id]
                return

            declare_new_node(node_id, is_synthesised=True)

            # Figure out which production id this one is and store it away:
            rhs = raw_rhs_to_tuple(symbol_to_kind, symbol_to_label, rhs_node_ids)
            known_lhs_productions = metadata['eg_production_vocab'][symbol_to_kind[str(node_id)]]
            if rhs not in known_lhs_productions:
                raise MissingProductionException("%s -> %s" % (symbol_to_kind[str(node_id)], rhs))
            prod_id = known_lhs_productions[rhs]
            eg_node_id_to_prod_id.append((node_to_inherited_id[node_id], prod_id))
            if hyperparameters['eg_use_vars_for_production_choice']:
                eg_prod_node_to_var_last_uses[node_to_inherited_id[node_id]] = \
                    np.array([node_to_synthesised_id[last_used_node_id[varchoice]] for varchoice in variables_in_scope], dtype=np.int16)

            # Visit all children, in left-to-right order, descending into them if needed
            last_sibling = None
            parent_inwards_edges = defaultdict(list)  # type: Dict[str, List[Tuple[int, int, Optional[int]]]]
            parent_inwards_edges['InheritedToSynthesised'].append((node_to_inherited_id[node_id],
                                                                   node_to_synthesised_id[node_id],
                                                                   None))
            for (rhs_symbol_idx, child_id) in enumerate(rhs_node_ids):
                child_inherited_id = declare_new_node(child_id, is_synthesised=False)
                child_inwards_edges = defaultdict(list)  # type: Dict[str, List[Tuple[int, int, Optional[int]]]]
                child_edge_label_id = metadata['eg_edge_label_vocab'][(prod_id, rhs_symbol_idx)]
                child_inwards_edges['Child'].append((node_to_inherited_id[node_id], child_inherited_id, child_edge_label_id))

                # Connection from left sibling, and prepare to be connected to the right sibling:
                if last_sibling is not None:
                    child_inwards_edges['NextSibling'].append((node_to_synthesised_id[last_sibling], child_inherited_id, None))
                last_sibling = child_id

                # Connection from the last generated leaf ("next action" in "A syntactic neural model for general-purpose code generation", Yin & Neubig '17):
                child_inwards_edges['NextSubtree'].append((node_to_synthesised_id[last_used_node_id[LAST_USED_TOKEN_NAME]],
                                                           child_inherited_id,
                                                           None))

                # Check if we are terminal (token) node and add appropriate edges if that's the case:
                if str(child_id) not in symbol_to_prod:
                    child_inwards_edges['NextToken'].append((node_to_synthesised_id[last_used_node_id[LAST_USED_TOKEN_NAME]],
                                                             child_inherited_id,
                                                             None))
                    last_used_node_id[LAST_USED_TOKEN_NAME] = child_id

                # If we are a variable or literal, we also need to store information to train to make the right choice:
                child_kind = symbol_to_kind[str(child_id)]
                if child_kind == VARIABLE_NONTERMINAL:
                    var_name = symbol_to_label[str(child_id)]
                    last_var_use_id = last_used_node_id[var_name]
                    cur_variable_to_last_use_ids = [node_to_synthesised_id[last_used_node_id[varchoice]] for varchoice in variables_in_scope]
                    varchoice_id = variables_in_scope.index(var_name)
                    eg_node_id_to_varchoice.append((node_to_inherited_id[node_id], cur_variable_to_last_use_ids, varchoice_id))
                    child_inwards_edges['NextUse'].append((node_to_synthesised_id[last_var_use_id], child_inherited_id, None))
                    if hyperparameters['eg_update_last_variable_use_representation']:
                        last_used_node_id[var_name] = child_id
                elif child_kind in LITERAL_NONTERMINALS:
                    literal = symbol_to_label[str(child_id)]
                    literal_id = eg_literal_tok_to_idx[child_kind].get(literal)
                    # In the case that a literal is not in the vocab and not in the context, the above will return None,
                    # so map that explicitly to the id for UNK:
                    if literal_id is None:
                        literal_id = metadata['eg_literal_vocabs'][child_kind].get_id_or_unk(literal)
                    eg_node_id_to_literal_choice[child_kind].append((node_to_inherited_id[node_id], literal_id))

                # Store the edges leading to new node, recurse into it, and mark its upwards connection for later:
                eg_schedule.append(child_inwards_edges)
                expand_node(child_id)
                parent_inwards_edges['Parent'].append((node_to_synthesised_id[child_id], node_to_synthesised_id[node_id], None))
            eg_schedule.append(parent_inwards_edges)

        expand_node(prod_root_node)

        expansion_labeled_edge_types, expansion_unlabeled_edge_types = get_restricted_edge_types(hyperparameters)

        def split_schedule_step(step: Dict[str, List[Tuple[int, int, Optional[int]]]]) -> List[List[Tuple[int, int]]]:
            """Turn one step's edges into a list indexed by edge-type id, without edge labels.

            Labelled types come first, then unlabelled ones. Edge types not in either map
            are dropped.
            """
            total_edge_types = len(expansion_labeled_edge_types) + len(expansion_unlabeled_edge_types)
            step_by_edge = [[] for _ in range(total_edge_types)]  # type: List[List[Tuple[int, int]]]
            for (label, edges) in step.items():
                edges = [(v, w) for (v, w, _) in edges]  # Strip off (optional) label:
                if label in expansion_labeled_edge_types:
                    step_by_edge[expansion_labeled_edge_types[label]] = edges
                elif label in expansion_unlabeled_edge_types:
                    step_by_edge[len(expansion_labeled_edge_types) + expansion_unlabeled_edge_types[label]] = edges
            return step_by_edge

        def edge_labels_from_schedule_step(step: Dict[str, List[Tuple[int, int, Optional[int]]]]) -> List[List[int]]:
            """Return, per labelled edge type, the label ids of that type's edges in one step."""
            labels_by_edge = [[] for _ in range(len(expansion_labeled_edge_types))]  # type: List[List[int]]
            for (label, edges) in step.items():
                if label in expansion_labeled_edge_types:
                    label_ids = [l for (_, _, l) in edges]  # Keep only edge label
                    labels_by_edge[expansion_labeled_edge_types[label]] = label_ids
            return labels_by_edge

        # Intended invariant of the schedule (not checked here): every edge in a step reads
        # from a node computed in an earlier step (or supplied by the caller) and writes to a
        # node that is first computed in that step.
        if hyperparameters['eg_use_vars_for_production_choice']:
            result_holder['eg_production_node_id_to_var_last_use_node_ids'] = eg_prod_node_to_var_last_uses
        # reshape(-1, 2) keeps the (num_pairs, 2) shape even when nothing was expanded
        # (e.g. the root has no production). Without it, an empty list would produce a
        # 1-D array of shape (0,). This mirrors the literal-choice arrays below.
        result_holder['eg_node_id_to_prod_id'] = np.array(eg_node_id_to_prod_id, dtype=np.int16).reshape(-1, 2)
        result_holder['eg_node_id_to_varchoice'] = eg_node_id_to_varchoice
        # One (num_choices, 2) array per literal kind; kinds without choices get an empty (0, 2) array.
        # .get() is used so the defaultdict is not populated as a side effect.
        result_holder['eg_node_id_to_literal_choice'] = {
            literal_kind: np.array(eg_node_id_to_literal_choice.get(literal_kind, []), dtype=np.uint16).reshape(-1, 2)
            for literal_kind in LITERAL_NONTERMINALS
        }
        result_holder['eg_schedule'] = [split_schedule_step(step) for step in eg_schedule]
        result_holder['eg_edge_label_ids'] = [edge_labels_from_schedule_step(step) for step in eg_schedule]