in hugegraph-ml/src/hugegraph_ml/models/correct_and_smooth.py [0:0]
def forward(self, g, labels, mask=None, post_step=lambda y: y.clamp_(0.0, 1.0)):
    """Propagate ``labels`` over ``g`` for ``self.num_layers`` rounds.

    Every round computes ``post_step((1 - alpha) * y0 + alpha * agg(y))``,
    where ``y0`` is the (optionally masked) initial label matrix and ``agg``
    sums each node's incoming neighbour values. The aggregation is scaled by
    node degree according to ``self.adj``:

    * ``"DAD"`` -- multiply by ``deg ** -0.5`` before and after aggregating;
    * ``"AD"``  -- multiply by ``deg ** -1`` before aggregating;
    * ``"DA"``  -- multiply by ``deg ** -1`` after aggregating;
    * any other value -- plain sum aggregation, no degree scaling.

    Args:
        g: Graph exposing ``local_scope``, ``in_degrees``, ``ndata`` and
            ``update_all`` (presumably a ``dgl.DGLGraph`` -- confirm).
            Node data written here is discarded when the local scope exits.
        labels: Integer class ids (``torch.long``), which get one-hot encoded,
            or an already-dense float label/score matrix.
        mask: Optional index tensor or boolean mask selecting the rows of
            ``labels`` that seed propagation; every other row starts at zero.
        post_step: Callable applied to the combined result of each round;
            the default clamps values into ``[0, 1]`` in place.

    Returns:
        The label matrix produced by the last round (or the initial
        restart-free ``y`` if ``self.num_layers`` is 0).
    """
    with g.local_scope():
        if labels.dtype == torch.long:
            # Class ids -> float one-hot rows; F.one_hot infers the number of
            # classes from the largest id present in ``labels``.
            labels = F.one_hot(labels.view(-1)).to(torch.float32)

        if mask is None:
            y = labels
        else:
            # Only the selected rows seed propagation; the rest start at zero.
            y = torch.zeros_like(labels)
            y[mask] = labels[mask]

        # Restart term blended back in on every round.
        last = (1 - self.alpha) * y

        # Symmetric normalization uses deg^-1/2 on both sides, the one-sided
        # variants use deg^-1. Degrees are clamped to 1 so isolated nodes do
        # not cause a division by zero.
        exponent = -0.5 if self.adj == "DAD" else -1
        deg = g.in_degrees().float().clamp(min=1)
        norm = torch.pow(deg, exponent).to(labels.device).unsqueeze(1)

        # Graphs are assumed undirected, so in-degrees stand in for both the
        # source-side and destination-side scaling.
        scale_before = self.adj in ["DAD", "AD"]
        scale_after = self.adj in ["DAD", "DA"]

        for _ in range(self.num_layers):
            if scale_before:
                y = norm * y
            g.ndata["h"] = y
            g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
            aggregated = g.ndata.pop("h")
            y = self.alpha * aggregated
            if scale_after:
                y = y * norm
            y = post_step(last + y)

        return y