in hugegraph-ml/src/hugegraph_ml/models/arma.py [0:0]
def forward(self, g, feats):
    """Apply the ARMA graph convolution to node features.

    Runs ``self.K`` independent stacks of ``self.T`` propagation steps each
    and returns the element-wise mean of the stack outputs.  For stack ``k``
    and step ``t`` the update is::

        h = W(D^-1/2 A D^-1/2 h) + V(dropout(x))   [+ self.bias[k][t]]
        h = self.activation(h)                     (if an activation is set)

    where ``x`` is the input ``feats``.  ``W`` is ``self.w_0[str(k)]`` on the
    first step and ``self.w[str(k)]`` on later steps.  ``V`` is
    ``self.v[str(k)]``.

    Args:
        g: Graph object exposing ``local_scope``, ``in_degrees``, ``ndata`` and
            ``update_all`` (DGL-style).  Assumed undirected, so in-degrees are
            used as the (symmetric) degree for normalisation.
        feats: Node feature tensor; presumably ``(num_nodes, in_feats)`` --
            TODO confirm against the layer's constructor.

    Returns:
        The mean over the ``self.K`` stack outputs (``torch.stack(...).mean(0)``).
    """
    with g.local_scope():
        init_feats = feats
        # assume that the graphs are undirected and graph.in_degrees() is the same as graph.out_degrees()
        # clamp(min=1) stops isolated nodes from producing inf via 0 ** -0.5.
        degs = g.in_degrees().float().clamp(min=1)
        norm = torch.pow(degs, -0.5).to(feats.device).unsqueeze(1)
        output = []
        for k in range(self.K):
            key = str(k)
            feats = init_feats
            for t in range(self.T):
                # Normalise, sum-aggregate over neighbours, normalise again:
                # D^-1/2 A D^-1/2 applied to the current features.
                feats = feats * norm
                g.ndata["h"] = feats
                g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))  # pylint: disable=E1101
                feats = g.ndata.pop("h")
                feats = feats * norm
                if t == 0:
                    feats = self.w_0[key](feats)
                else:
                    feats = self.w[key](feats)
                # Skip connection back to the layer input, added exactly once
                # per step (previously added twice, doubling its weight).
                # self.dropout is called on every step.
                feats += self.v[key](self.dropout(init_feats))
                if self.bias is not None:
                    feats += self.bias[k][t]
                if self.activation is not None:
                    feats = self.activation(feats)
            output.append(feats)
        return torch.stack(output).mean(dim=0)