Excerpt: method `__load_contextgraph_data_from_sample` from Models/exprsynth/contextgraphmodel.py

    def __load_contextgraph_data_from_sample(hyperparameters: Dict[str, Any], metadata: Dict[str, Any],
                                             raw_sample: Dict[str, Any], result_holder: Dict[str, Any],
                                             is_train: bool=True) \
            -> bool:
        """Tensorise one raw context-graph sample into `result_holder`.

        Args:
            hyperparameters: Model hyperparameters; reads 'cg_add_subtoken_nodes',
                'max_num_cg_nodes_in_batch', 'cg_node_label_embedding_style',
                'cg_node_label_char_length' and 'cg_node_type_max_num'.
            metadata: Preprocessing metadata; reads the node label/type vocabularies,
                'cg_edge_type_dict', 'cg_edge_value_sizes', 'type_lattice' and
                'nag_reserved_names'.
            raw_sample: Sample dict with a 'ContextGraph' entry holding 'NodeLabels',
                'NodeTypes', 'Edges' and (for valued edge types) 'EdgeValues'.
                NOTE(review): node ids appear to be string-encoded ints — keys of
                'NodeLabels'/'NodeTypes' are passed through int(); confirm upstream format.
            result_holder: Output dict, populated in place with the 'cg_*' tensors.
            is_train: If True, shuffle a node's type hierarchy before truncating it so
                training sees a mixture of super-types rather than always the same prefix.

        Returns:
            False if the sample was dropped (too many nodes), True otherwise.
        """
        if hyperparameters.get('cg_add_subtoken_nodes', False):
            _add_per_subtoken_nodes(metadata['nag_reserved_names'], raw_sample)
        graph_node_labels = raw_sample['ContextGraph']['NodeLabels']
        graph_node_types = raw_sample['ContextGraph']['NodeTypes']
        num_nodes = len(graph_node_labels)
        if num_nodes >= hyperparameters['max_num_cg_nodes_in_batch']:
            print("Dropping example using %i nodes in context graph" % (num_nodes,))
            return False

        # Translate node label, either using the token vocab or into a character representation:
        if hyperparameters['cg_node_label_embedding_style'].lower() == 'token':
            # Translate node labels using the token vocabulary:
            node_labels = np.zeros((num_nodes,), dtype=np.uint16)
            for (node, label) in graph_node_labels.items():
                node_labels[int(node)] = metadata['cg_node_label_vocab'].get_id_or_unk(label)
            result_holder['cg_node_label_token_ids'] = node_labels
        elif hyperparameters['cg_node_label_embedding_style'].lower() == 'charcnn':
            # Translate node labels into character-based representation, and make unique per context graph:
            node_label_chars = np.zeros(shape=(num_nodes,
                                               hyperparameters['cg_node_label_char_length']),
                                        dtype=np.uint8)
            for (node, label) in graph_node_labels.items():
                # Unknown characters map to id 1 (0 is left as padding for short labels):
                for (char_idx, label_char) in enumerate(label[:hyperparameters['cg_node_label_char_length']].lower()):
                    node_label_chars[int(node), char_idx] = ALPHABET_DICT.get(label_char, 1)
            # Deduplicate label char-tensors so the CharCNN only encodes each distinct label once:
            unique_label_chars, node_label_unique_indices = np.unique(node_label_chars,
                                                                      axis=0,
                                                                      return_inverse=True)
            result_holder['cg_unique_label_chars'] = unique_label_chars
            result_holder['cg_node_label_unique_indices'] = node_label_unique_indices
        else:
            raise Exception("Unknown node label embedding style '%s'!"
                            % hyperparameters['cg_node_label_embedding_style'])

        # Translate node types, include supertypes:
        max_num_types = hyperparameters['cg_node_type_max_num']
        # Default every slot to the NO_TYPE id; the mask marks which slots hold real types.
        node_type_labels = np.full((num_nodes, max_num_types),
                                   metadata['cg_node_type_vocab'].get_id_or_unk(NO_TYPE)[0], dtype=np.uint16)
        # BUGFIX: np.bool was removed in NumPy 1.24 -- use the builtin bool dtype.
        node_type_labels_mask = np.zeros((num_nodes, max_num_types), dtype=bool)
        # Nodes without an entry in NodeTypes still attend to their NO_TYPE slot:
        node_type_labels_mask[:, 0] = True
        for node_id, token_type in graph_node_types.items():
            node_id = int(node_id)
            node_types = metadata['cg_node_type_vocab'].get_id_or_unk('type:' + token_type, metadata['type_lattice'])
            if is_train and len(node_types) > max_num_types:
                # Shuffle the types so that we get a mixture of the type hierarchy and
                # not always the same ones. BUGFIX: the `random=` keyword of
                # random.shuffle was removed in Python 3.11; the default generator is
                # what was being passed anyway.
                random.shuffle(node_types)
            node_types = node_types[:max_num_types]
            num_types = len(node_types)
            node_type_labels[node_id, :num_types] = node_types
            node_type_labels_mask[node_id, :num_types] = True

        result_holder['cg_node_type_labels'] = node_type_labels
        result_holder['cg_node_type_labels_mask'] = node_type_labels_mask

        # Split edges according to edge_type and count their numbers:
        result_holder['cg_edges'] = [[] for _ in metadata['cg_edge_type_dict']]
        result_holder['cg_edge_values'] = {}
        num_edge_types = len(metadata['cg_edge_type_dict'])
        num_incoming_edges_per_type = np.zeros((num_nodes, num_edge_types), dtype=np.uint16)
        num_outgoing_edges_per_type = np.zeros((num_nodes, num_edge_types), dtype=np.uint16)
        for (e_type, e_type_idx) in metadata['cg_edge_type_dict'].items():
            if e_type in raw_sample['ContextGraph']['Edges']:
                edges = np.array(raw_sample['ContextGraph']['Edges'][e_type], dtype=np.int32)
                result_holder['cg_edges'][e_type_idx] = edges

                if e_type_idx in metadata['cg_edge_value_sizes']:
                    # Clip edge values into [-MAX_EDGE_VALUE, MAX_EDGE_VALUE] and
                    # normalise to [-1, 1]. (Removed a redundant duplicate np.array wrap.)
                    edge_values = np.array(raw_sample['ContextGraph']['EdgeValues'][e_type], dtype=np.float32)
                    result_holder['cg_edge_values'][e_type_idx] = \
                        np.clip(edge_values, -MAX_EDGE_VALUE, MAX_EDGE_VALUE) / MAX_EDGE_VALUE
            else:
                # Edge type absent from this sample: store empty (0, 2) edge list and,
                # if the type carries values, an empty value tensor of the right width.
                result_holder['cg_edges'][e_type_idx] = np.zeros((0, 2), dtype=np.int32)
                if e_type_idx in metadata['cg_edge_value_sizes']:
                    result_holder['cg_edge_values'][e_type_idx] = \
                        np.zeros((0, metadata['cg_edge_value_sizes'][e_type_idx]), dtype=np.float32)
            # Per-node in/out degree per edge type (edges are [source, target] pairs):
            num_incoming_edges_per_type[:, e_type_idx] = np.bincount(result_holder['cg_edges'][e_type_idx][:, 1],
                                                                     minlength=num_nodes)
            num_outgoing_edges_per_type[:, e_type_idx] = np.bincount(result_holder['cg_edges'][e_type_idx][:, 0],
                                                                     minlength=num_nodes)
        result_holder['cg_num_incoming_edges_per_type'] = num_incoming_edges_per_type
        result_holder['cg_num_outgoing_edges_per_type'] = num_outgoing_edges_per_type

        return True