def async_ggnn_layer()

in python/dpu_utils/tfmodels/asyncgnn.py [0:0]


    def async_ggnn_layer(self,
                         initial_node_representation: tf.Tensor,
                         initial_nodes: List[tf.Tensor],
                         sending_nodes: List[List[List[tf.Tensor]]],
                         edge_labels: List[List[List[tf.Tensor]]],
                         msg_targets: List[List[tf.Tensor]],
                         receiving_nodes: List[List[tf.Tensor]],
                         receiving_node_num: List[tf.Tensor]) -> tf.Tensor:
        """
        Run through an async GGNN and return the representations of all nodes.
        :param initial_node_representation: the initial embeddings of the nodes.
                                            Shape: [-1, h_dim]
        :param initial_nodes: List of node id tensors I_{r}: Node IDs that will have no incoming edges in round r.
                              Inner Tensor Shape: [-1]
        :param sending_nodes: List of lists of lists of sending nodes S_{r,s,e}: Source node ids of edges of type e
                              propagating in step s of round r. By convention, 0..self.num_labeled_edges are labeled
                              edge types, and self.num_labeled_edges.. are unlabeled edge types.
                              Restrictions: If v in S_{r,s,e}, then v in R_{r,s'} for s' < s or v in I_{r}.
                              Inner Tensor Shape: [-1]
        :param edge_labels: List of lists of lists of (embeddings of) labels of edges L_{r,s,e}: Labels of edges of type
                            e propagating in step s of round r.
                            Restrictions: len(L_{r,s,e}) = len(S_{r,s,e})
                            Inner Tensor Shape: [-1, e_dim]
        :param msg_targets: List of lists of normalised edge target nodes T_{r,s}: Targets of edges propagating in step
                            s of round r, normalised to a continuous range starting from 0.
                            This is used for aggregating messages from the sending nodes.
                            Inner Tensor Shape: [-1]
        :param receiving_nodes: List of lists of receiving nodes R_{r,s}: Target node ids of aggregated messages in
                                propagation step s of round r.
                                Restrictions: If v in R_{r,s}, v not in R_{r,s'} for all s' != s and v not in I_{r}.
                                Inner Tensor Shape: [-1]
        :param receiving_node_num: Number of receiving nodes N_{r,s}
                                   Restrictions: N_{r,s} = len(R_{r,s})
                                   Inner Tensor Shape: [|Substeps|]
        :return: representations of all nodes after propagation according to schedule. Shape: [-1, h_dim]
        """
        with tf.variable_scope('async_ggnn'):
            cur_node_states = initial_node_representation

            for prop_round in range(self.hyperparams['propagation_rounds']):
                with tf.variable_scope('prop_round%i' % (prop_round,)):
                    # ---- Declare and fill tensor arrays used in tf.while_loop:
                    sending_nodes_ta = tf.TensorArray(
                        tf.int32,
                        infer_shape=False,
                        element_shape=[None],
                        size=self.hyperparams['propagation_substeps'] * self.num_edge_types,
                        name='sending_nodes'
                    )
                    edge_labels_ta = tf.TensorArray(
                        tf.float32,
                        infer_shape=False,
                        element_shape=[None, self.hyperparams['edge_label_size']],
                        size=self.hyperparams['propagation_substeps'] * self.num_labeled_edge_types,
                        name='edge_labels'
                    )
                    msg_targets_ta = tf.TensorArray(tf.int32,
                                                    infer_shape=False,
                                                    element_shape=[None],
                                                    size=self.hyperparams['propagation_substeps'],
                                                    name='msg_targets')
                    receiving_nodes_ta = tf.TensorArray(tf.int32,
                                                        infer_shape=False,
                                                        element_shape=[None],
                                                        size=self.hyperparams['propagation_substeps'],
                                                        clear_after_read=False,
                                                        name='receiving_nodes')
                    receiving_node_num_ta = tf.TensorArray(tf.int32,
                                                           infer_shape=False,
                                                           element_shape=[],
                                                           size=self.hyperparams['propagation_substeps'],
                                                           name='receiving_nodes_num')

                    for step in range(self.hyperparams['propagation_substeps']):
                        for labeled_edge_typ in range(self.num_labeled_edge_types):
                            sending_nodes_ta = sending_nodes_ta.write(step * self.num_edge_types + labeled_edge_typ,
                                                                      sending_nodes[prop_round][step][labeled_edge_typ])
                            edge_labels_ta = edge_labels_ta.write(step * self.num_labeled_edge_types + labeled_edge_typ,
                                                                  edge_labels[prop_round][step][labeled_edge_typ])
                        for unlabeled_edge_typ in range(self.num_unlabeled_edge_types):
                            shifted_edge_typ = self.num_labeled_edge_types + unlabeled_edge_typ
                            sending_nodes_ta = sending_nodes_ta.write(step * self.num_edge_types + shifted_edge_typ,
                                                                      sending_nodes[prop_round][step][shifted_edge_typ])
                        msg_targets_ta = msg_targets_ta.write(step, msg_targets[prop_round][step])
                        receiving_nodes_ta = receiving_nodes_ta.write(step, receiving_nodes[prop_round][step])
                    receiving_node_num_ta = receiving_node_num_ta.unstack(receiving_node_num[prop_round])

                    new_node_states_ta = tf.TensorArray(tf.float32,
                                                        infer_shape=False,
                                                        element_shape=[self.hyperparams['hidden_size']],
                                                        size=tf.shape(cur_node_states)[0],
                                                        clear_after_read=False,
                                                        name='new_node_states')

                    # ---- Actual propagation schedule implementation:
                    # Initialize the initial nodes with their state from last round:
                    new_node_states_ta = new_node_states_ta.scatter(initial_nodes[prop_round],
                                                                    tf.gather(cur_node_states, initial_nodes[prop_round]))

                    def do_substep(substep_id, new_node_states_ta):
                        # For each edge active in this substep, pull source state and transform:
                        sending_states_per_edge_type = []
                        edge_labels_per_type = []
                        for labeled_edge_typ in range(self.num_labeled_edge_types):
                            sending_states_per_edge_type.append(
                                new_node_states_ta.gather(sending_nodes_ta.read(
                                    substep_id * self.num_edge_types + labeled_edge_typ
                                ))
                            )
                            edge_labels_per_type.append(edge_labels_ta.read(
                                substep_id * self.num_labeled_edge_types + labeled_edge_typ
                            ))
                        for unlabeled_edge_typ in range(self.num_unlabeled_edge_types):
                            shifted_edge_typ = self.num_labeled_edge_types + unlabeled_edge_typ
                            sending_states_per_edge_type.append(new_node_states_ta.gather(
                                sending_nodes_ta.read(substep_id * self.num_edge_types + shifted_edge_typ)
                            ))

                        # Collect old states for receiving nodes
                        substep_receiving_nodes = receiving_nodes_ta.read(substep_id)
                        old_receiving_node_states = tf.gather(cur_node_states, substep_receiving_nodes)
                        old_receiving_node_states.set_shape([None, self.hyperparams['hidden_size']])

                        msg_targets_this_step = msg_targets_ta.read(substep_id)
                        receiving_node_num_this_step = receiving_node_num_ta.read(substep_id)

                        substep_new_node_states = self.propagate_one_step(
                            sending_states_per_edge_type, edge_labels_per_type,
                            msg_targets_this_step, receiving_node_num_this_step,
                            old_receiving_node_states
                        )

                        # Write updated states back:
                        new_node_states_ta = new_node_states_ta.scatter(indices=substep_receiving_nodes,
                                                                        value=substep_new_node_states,
                                                                        name="state_scatter_round%i" % (prop_round,))
                        return substep_id + 1, new_node_states_ta

                    def is_done(substep_id, new_node_states_ta_unused):
                        return tf.logical_and(substep_id < self.hyperparams['propagation_substeps'],
                                              tf.greater(tf.shape(receiving_nodes_ta.read(substep_id))[0], 0))

                    _, new_node_states_ta = tf.while_loop(cond=is_done,
                                                          body=do_substep,
                                                          loop_vars=[tf.constant(0), new_node_states_ta]
                                                         )

                    cur_node_states = new_node_states_ta.stack(name="state_stack_round%i" % (prop_round,))

            return cur_node_states