# Excerpt: forward() from ppuda/ghn/gatedgnn.py

    def forward(self, x, edges, node_graph_ind):
        r"""
        Updates node features by sequentially traversing the graph in the forward and backward directions.
        :param x: (N, C) node features, where N is the total number of nodes in a batch of B graphs, C is node feature dimensionality.
        :param edges: (M, 4) tensor of edges, where M is the total number of edges;
                       first column in edges is the row indices of edges,
                       second column in edges is the column indices of edges,
                       third column in edges is the shortest path distance between the nodes,
                       fourth column in edges is the graph indices (from 0 to B-1) within a batch for each edge.
        :param node_graph_ind: (N,) tensor of graph indices (from 0 to B-1) within a batch for each node.
        :return: updated (N, C) node features.

        NOTE(review): ``hx = x`` below is an alias, not a copy, so the in-place writes
        ``hx[ind] = ...`` in the main loop also mutate the caller's ``x`` tensor —
        confirm callers do not rely on ``x`` being unchanged after this call.
        """

        assert x.dim() == 2 and edges.dim() == 2 and edges.shape[1] == 4, (x.shape, edges.shape)
        # Per-graph node counts; torch.unique returns sorted graph ids, so counts
        # line up with graph indices 0..B-1 (assumes every graph id 0..B-1 occurs
        # in node_graph_ind — TODO confirm against callers).
        n_nodes = torch.unique(node_graph_ind, return_counts=True)[1]

        B, C = len(n_nodes), x.shape[1]  # batch size, features

        # ve: shortest-path distance per edge; edge_graph_ind: graph id per edge.
        ve, edge_graph_ind = edges[:, 2], edges[:, 3]

        assert n_nodes.sum() == len(x), (n_nodes.sum(), x.shape)

        # Real edges have distance 1; larger distances are "virtual" (multi-hop) edges.
        is_1hop = ve == 1
        if self.ve:
            ve = ve.view(-1, 1)   # according to Eq. 4 in the paper

        traversal_orders = [1, 0]  # forward, backward

        # Prefix sums of graph sizes give the row offset of each graph's first node in x;
        # gathering per edge converts that edge's local node indices to batch-global rows.
        edge_offset = torch.cumsum(F.pad(n_nodes[:-1], (1, 0)), 0)[edge_graph_ind]
        # Local index (0..n_i-1) of each node within its own graph, as a column vector
        # for the broadcasted comparison building mask2d below.
        node_inds = torch.cat([torch.arange(n) for n in n_nodes]).to(x.device).view(-1, 1)

        # Parallelize computation of indices and masks of one/all hops
        # This will slightly speed up the operations in the main loop
        # But indexing of the GPU tensors (used in the main loop) for some reason remains slow, see
        # https://github.com/pytorch/pytorch/issues/29973 for more info
        # Range of local node ids, 0..max column index over all edges.
        all_nodes = torch.arange(edges[:, 1].max() + 1, device=x.device)
        masks_1hop, masks_all = {}, {}
        for order in traversal_orders:
            # masks_all[order][node, e] is True iff edge e has local node id `node` in
            # column `order` (the receiving endpoint for this traversal direction:
            # forward reads col 1, backward reads col 0; `start` below uses the other column).
            masks_all[order] = edges[:, order].view(1, -1) == all_nodes.view(-1, 1)
            masks_1hop[order] = masks_all[order] & is_1hop.view(1, -1)  # restrict to 1-hop edges
        # mask2d[i, node] is True iff batch row i is local node `node` of its graph.
        mask2d = node_inds == all_nodes.view(1, -1)
        # Pre-expand to (M, C) so it can serve as the index tensor for scatter_add_ over dim 0.
        edge_graph_ind = edge_graph_ind.view(-1, 1).expand(-1, C)

        hx = x  # initial hidden node features (alias of x — see NOTE(review) in the docstring)

        # Main loop: T propagation rounds, each sweeping local node ids forward then backward.
        # Updates are sequential within a sweep, so later nodes read already-updated neighbors.
        for t in range(self.T):
            for order in traversal_orders:  # forward, backward
                start = edges[:, 1 - order] + edge_offset                           # node indices from which the message will be passed further
                for node in (all_nodes if order else torch.flipud(all_nodes)):

                    # Compute the message by aggregating features from neighbors
                    e_1hop = torch.nonzero(masks_1hop[order][node, :]).view(-1)
                    m = self.mlp(hx[start[e_1hop]])                                 # transform node features of all 1-hop neighbors
                    m = torch.zeros(B, C, dtype=m.dtype, device=m.device).scatter_add_(0, edge_graph_ind[e_1hop], m)     # sum the transformed features into a (B,C) tensor
                    if self.ve:
                        e = torch.nonzero(masks_all[order][node, :]).view(-1)       # virtual edges connected to node
                        m_ve = self.mlp_ve(hx[start[e]]) / ve[e].to(m)              # transform node features of all ve-hop neighbors, attenuated by path distance
                        m = m.scatter_add_(0, edge_graph_ind[e], m_ve)              # sum m and m_ve according to Eq. 4 in the paper

                    # Update node hidden states in parallel for a batch of graphs
                    ind = torch.nonzero(mask2d[:, node]).view(-1)                   # batch rows that are local node `node` in their graph
                    if B > 1:
                        m = m[node_graph_ind[ind]]                                  # gather each graph's aggregated message for its copy of `node`
                    hx[ind] = self.gru(m, hx[ind]).to(hx)  # 'to(hx)' is to make automatic mixed precision work

        return hx