def forward()

in hugegraph-ml/src/hugegraph_ml/models/arma.py [0:0]


    def forward(self, g, feats):
        """Apply the ARMA graph convolution to ``feats`` over graph ``g``.

        ``self.K`` parallel stacks are evaluated. Each stack starts from the
        input features and runs ``self.T`` propagation steps. One step does the
        following:

        1. Scale node features by ``deg^-1/2``.
        2. Aggregate them with ``g.update_all(copy_u, sum)``.
        3. Scale by ``deg^-1/2`` again.
        4. Apply a linear map: ``self.w_0`` on the first step and ``self.w``
           on later steps.
        5. Add skip terms computed from the input features via ``self.v``.
        6. Apply the optional bias and activation.

        Args:
            g: graph exposing ``local_scope``/``in_degrees``/``ndata``/
                ``update_all`` (presumably a ``dgl.DGLGraph`` -- confirm).
                Node data written here is discarded on exit from
                ``local_scope``.
            feats: node feature tensor. The per-node scale is shaped
                ``(num_nodes, 1)``, so this assumes a 2-D
                ``(num_nodes, dim)`` input -- TODO confirm.

        Returns:
            Element-wise mean over the ``self.K`` stack outputs.
        """
        with g.local_scope():
            x0 = feats
            # In-degrees drive both the pre- and post-aggregation scaling. This
            # is the symmetric D^-1/2 A D^-1/2 normalisation only when
            # in_degrees() == out_degrees(), i.e. the graph is undirected.
            # clamp(min=1) keeps zero-degree nodes from producing inf.
            deg_scale = torch.pow(g.in_degrees().float().clamp(min=1), -0.5)
            deg_scale = deg_scale.to(feats.device).unsqueeze(1)

            def run_stack(k):
                # Run the T propagation steps of the k-th ARMA stack.
                key = str(k)
                h = x0
                for t in range(self.T):
                    h = h * deg_scale
                    g.ndata["h"] = h
                    g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))  # pylint: disable=E1101
                    h = g.ndata.pop("h") * deg_scale

                    # First step has its own weight; all later steps reuse
                    # the same self.w[key] module.
                    linear = self.w_0[key] if t == 0 else self.w[key]
                    h = linear(h)

                    # Skip terms from the stack input. Keep this call order:
                    # the two dropout calls draw their random masks in sequence.
                    # NOTE(review): V(X) is added twice (dropout after V, then
                    # dropout before V). With dropout inactive it therefore
                    # contributes 2 * V(X). This appears to differ from the
                    # single X*V skip term in the ARMA formulation. Confirm it
                    # is intended before changing, since it affects trained
                    # weights.
                    root = self.v[key]
                    h += self.dropout(root(x0))
                    h += root(self.dropout(x0))

                    if self.bias is not None:
                        h += self.bias[k][t]
                    if self.activation is not None:
                        h = self.activation(h)
                return h

            stack_outputs = [run_stack(k) for k in range(self.K)]
            return torch.stack(stack_outputs).mean(dim=0)