def sparse_gnn_layer()

in python/dpu_utils/tfmodels/sparsegnn.py


    def sparse_gnn_layer(self,
                         dropout_keep_rate: tf.Tensor,
                         node_embeddings: tf.Tensor,
                         adjacency_lists: List[tf.Tensor],
                         num_incoming_edges_per_type: Optional[tf.Tensor],
                         num_outgoing_edges_per_type: Optional[tf.Tensor],
                         edge_features: Dict[int, tf.Tensor]) -> tf.Tensor:
        """
        Run through a GNN and return the representations of the nodes.
        :param dropout_keep_rate: scalar keep probability for the dropout applied to the per-edge-type weight matrices.
        :param node_embeddings: the initial embeddings of the nodes.
        :param adjacency_lists: a list of *sorted* adjacency lists, one [E, 2] tensor of
                                (source, target) node index pairs per edge type.
        :param num_incoming_edges_per_type: [V, num_edge_types] tensor indicating the number of incoming edges per type.
                                            Required if use_edge_bias or use_edge_msg_avg_aggregation is true.
        :param num_outgoing_edges_per_type: [V, num_edge_types] tensor indicating the number of outgoing edges per type.
                                            Required if add_backwards_edges and (use_edge_bias or use_edge_msg_avg_aggregation) are true.
        :param edge_features: a dictionary mapping edge type index -> [num_edges, feature_length] tensor,
                              for the edge types that have features.
        :return: the representations of the nodes
        """
        # Used shape abbreviations:
        #   V ~ number of nodes
        #   D ~ state dimension
        #   E ~ number of edges of current type
        #   M ~ number of messages (sum of all E)
        message_targets = []  # list of tensors of message targets of shape [E]
        message_edge_types = []  # list of tensors of edge type of shape [E]

        # Note that we optionally support adding (implicit) backwards edges. If turned on, we introduce additional
        # edge type indices [self.num_edge_types .. 2*self.num_edge_types - 1], with their own weights.
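        # For example, with num_edge_types == 3, forward edges use type indices 0..2 and the
        # implicit backward edges use indices 3..5.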

        for edge_type_idx, adjacency_list_for_edge_type in enumerate(adjacency_lists):
            edge_targets = adjacency_list_for_edge_type[:, 1]
            message_targets.append(edge_targets)
            message_edge_types.append(tf.ones_like(edge_targets, dtype=tf.int32) * edge_type_idx)
        if self.params['add_backwards_edges']:
            for edge_type_idx, adjacency_list_for_edge_type in enumerate(adjacency_lists):
                edge_targets = adjacency_list_for_edge_type[:, 0]
                message_targets.append(edge_targets)
                message_edge_types.append(tf.ones_like(edge_targets, dtype=tf.int32) * (self.num_edge_types + edge_type_idx))
        message_targets = tf.concat(message_targets, axis=0)  # Shape [M]
        message_edge_types = tf.concat(message_edge_types, axis=0)  # Shape [M]
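        # From here on, messages are indexed flatly over all M edges (forward edges first, then
        # the optional backward edges), matching the concatenation order of the messages below.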

        with tf.variable_scope('gnn_scope'):
            node_states_per_layer = []  # list of tensors of shape [V, D], one entry per layer (the final state of that layer)
            node_states_per_layer.append(node_embeddings)
            num_nodes = tf.shape(node_embeddings, out_type=tf.int32)[0]

            for (layer_idx, num_timesteps) in enumerate(self.params['layer_timesteps']):
                with tf.variable_scope('gnn_layer_%i' % layer_idx):
                    # Extract residual messages, if any:
                    layer_residual_connections = self.params['residual_connections'].get(str(layer_idx))
                    if layer_residual_connections is None:
                        layer_residual_states = []
                    else:
                        layer_residual_states = [node_states_per_layer[residual_layer_idx]
                                                 for residual_layer_idx in layer_residual_connections]
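                    # These residual states are concatenated onto the aggregated messages in every
                    # timestep below, so the RNN cell input grows by D per residual connection.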

                    if self.params['use_propagation_attention']:
                        message_edge_type_factors = tf.nn.embedding_lookup(params=self.__weights.edge_type_attention_weights[layer_idx],
                                                                           ids=message_edge_types)  # Shape [M]
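                        # One learned scalar per edge type, gathered per message; these factors
                        # rescale the attention logits in every timestep of this layer.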

                    # Record new states for this layer. Initialised to last state, but will be updated below:
                    node_states_per_layer.append(node_states_per_layer[-1])

                    for step in range(num_timesteps):
                        with tf.variable_scope('timestep_%i' % step):
                            messages = []  # list of tensors of messages of shape [E, D]
                            message_source_states = []  # list of tensors of edge source states of shape [E, D]

                            # Collect incoming messages per edge type
                            def compute_messages_for_edge_type(data_edge_type_idx: int, weights_edge_type_idx: int, edge_sources: tf.Tensor) -> None:
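                                # Transform the source state of every edge of one type with that type's
                                # (dropout-regularised) weight matrix, optionally gated by edge features: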
                                edge_source_states = tf.nn.embedding_lookup(params=node_states_per_layer[-1],
                                                                            ids=edge_sources)  # Shape [E, D]
                                edge_weights = tf.nn.dropout(self.__weights.edge_weights[layer_idx][weights_edge_type_idx],
                                                             rate=1 - dropout_keep_rate)
                                all_messages_for_edge_type = tf.matmul(edge_source_states, edge_weights)  # Shape [E, D]

                                if data_edge_type_idx in edge_features:
                                    edge_feature_augmented = tf.concat([edge_features[data_edge_type_idx],
                                                                        1 / (edge_features[data_edge_type_idx] + SMALL_NUMBER)],
                                                                       axis=-1)  # Shape [E, 2*edge_size]
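                                    # Compute a per-edge scalar gate from the augmented features;
                                    # the [E, 1] gate broadcasts across the D message dimensions: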
                                    all_messages_gate_value = \
                                        tf.sigmoid(self.__weights.edge_feature_gate_bias[layer_idx][weights_edge_type_idx]
                                                   + tf.matmul(edge_feature_augmented,
                                                               self.__weights.edge_feature_gate_weights[layer_idx][weights_edge_type_idx]))  # Shape [E, 1]
                                    all_messages_for_edge_type = all_messages_gate_value * all_messages_for_edge_type

                                messages.append(all_messages_for_edge_type)
                                message_source_states.append(edge_source_states)

                            for edge_type_idx, adjacency_list_for_edge_type in enumerate(adjacency_lists):
                                compute_messages_for_edge_type(edge_type_idx, edge_type_idx, adjacency_list_for_edge_type[:, 0])
                            if self.params['add_backwards_edges']:
                                for edge_type_idx, adjacency_list_for_edge_type in enumerate(adjacency_lists):
                                    compute_messages_for_edge_type(edge_type_idx, self.num_edge_types + edge_type_idx, adjacency_list_for_edge_type[:, 1])

                            messages = tf.concat(messages, axis=0)  # Shape [M, D]

                            if self.params['use_propagation_attention']:
                                message_source_states = tf.concat(message_source_states, axis=0)  # Shape [M, D]
                                message_target_states = tf.nn.embedding_lookup(params=node_states_per_layer[-1],
                                                                               ids=message_targets)  # Shape [M, D]
                                message_attention_scores = tf.einsum('mi,mi->m', message_source_states, message_target_states)  # Shape [M]
                                message_attention_scores = message_attention_scores * message_edge_type_factors

                                # Normalise the scores per target node, i.e. a log-softmax over
                                # each node's incoming messages:
                                message_log_attention = unsorted_segment_log_softmax(logits=message_attention_scores,
                                                                                     segment_ids=message_targets,
                                                                                     num_segments=num_nodes)

                                message_attention = tf.exp(message_log_attention)  # Shape [M]
                                # Weight the messages by their attention probabilities:
                                messages = messages * tf.expand_dims(message_attention, -1)

                            incoming_messages = self.unsorted_segment_aggregation_func(data=messages,
                                                                                       segment_ids=message_targets,
                                                                                       num_segments=num_nodes)  # Shape [V, D]

                            if self.params['use_edge_bias']:
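                                # Bias each node by learned per-edge-type bias vectors, weighted by
                                # the node's count of incoming edges of each type: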
                                incoming_messages += tf.matmul(num_incoming_edges_per_type,
                                                               self.__weights.edge_biases[layer_idx][0:self.num_edge_types])  # Shape [V, D]
                                if self.params['add_backwards_edges']:
                                    incoming_messages += tf.matmul(num_outgoing_edges_per_type,
                                                                   self.__weights.edge_biases[layer_idx][self.num_edge_types:])  # Shape [V, D]

                            if self.params['use_edge_msg_avg_aggregation']:
                                num_incoming_edges = tf.reduce_sum(num_incoming_edges_per_type,
                                                                   keepdims=True, axis=-1)  # Shape [V, 1]
                                if self.params['add_backwards_edges']:
                                    num_incoming_edges += tf.reduce_sum(num_outgoing_edges_per_type,
                                                                        keepdims=True, axis=-1)  # Shape [V, 1]
                                # Turn the message sum into a mean; SMALL_NUMBER guards against
                                # division by zero for nodes with no incoming edges:
                                incoming_messages /= num_incoming_edges + SMALL_NUMBER

                            incoming_information = tf.concat(layer_residual_states + [incoming_messages],
                                                             axis=-1)  # Shape [V, D*(1 + num of residual connections)]

                            # Pass the aggregated information into the RNN cell to update each node's
                            # state; the cell returns (output, new_state) and we keep the new state:
                            node_states_per_layer[-1] = self.__weights.rnn_cells[layer_idx](incoming_information,
                                                                                            node_states_per_layer[-1])[1]  # Shape [V, D]

        return node_states_per_layer[-1]
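
For orientation, here is a minimal usage sketch (TensorFlow 1.x graph mode). The `gnn` instance, the hyperparameter values, and the dimensions below are illustrative assumptions, not taken from this file:

    import tensorflow as tf

    NUM_EDGE_TYPES = 3  # hypothetical number of edge types in the graph
    STATE_DIM = 64      # hypothetical node state dimension D

    # Placeholders shaped as described in the docstring:
    node_embeddings = tf.placeholder(tf.float32, shape=[None, STATE_DIM])    # [V, D]
    adjacency_lists = [tf.placeholder(tf.int32, shape=[None, 2])             # [E, 2] per edge type
                       for _ in range(NUM_EDGE_TYPES)]
    num_incoming = tf.placeholder(tf.float32, shape=[None, NUM_EDGE_TYPES])  # [V, num_edge_types]
    num_outgoing = tf.placeholder(tf.float32, shape=[None, NUM_EDGE_TYPES])  # [V, num_edge_types]

    # `gnn` stands in for an instance of the class defining sparse_gnn_layer, constructed
    # with its hyperparameters ('layer_timesteps', 'use_edge_bias', 'residual_connections', ...):
    node_representations = gnn.sparse_gnn_layer(dropout_keep_rate=tf.constant(0.9),
                                                node_embeddings=node_embeddings,
                                                adjacency_lists=adjacency_lists,
                                                num_incoming_edges_per_type=num_incoming,
                                                num_outgoing_edges_per_type=num_outgoing,
                                                edge_features={})  # no edge type has extra features here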