def forward()

in hugegraph-ml/src/hugegraph_ml/models/correct_and_smooth.py [0:0]


    def forward(self, g, labels, mask=None, post_step=lambda y: y.clamp_(0.0, 1.0)):
        """Propagate ``labels`` over graph ``g`` for ``self.num_layers`` steps.

        Each step computes ``y = post_step((1 - alpha) * y0 + alpha * A_norm @ y)``,
        where ``y0`` is the (masked) initial signal, ``alpha`` is ``self.alpha``
        and ``A_norm`` is the adjacency normalized according to ``self.adj``:

        * ``"DAD"`` -- symmetric normalization ``D^-1/2 A D^-1/2``;
        * ``"AD"``  -- ``A D^-1`` (degree scaling applied before aggregation);
        * ``"DA"``  -- ``D^-1 A`` (degree scaling applied after aggregation);
        * any other value -- no degree normalization is applied.

        ``D`` holds the in-degrees of ``g`` clamped to at least 1, so nodes with
        no incoming edges do not cause a division by zero.

        Args:
            g: DGL graph to propagate over. Its node data is only modified
                inside ``g.local_scope()``, so the temporary ``"h"`` feature
                does not leak to the caller's graph.
            labels: Node signal whose first dimension indexes the nodes of ``g``.
                A ``torch.long`` tensor is treated as class indices and converted
                to a one-hot ``float32`` matrix (``F.one_hot`` infers the number
                of classes as ``labels.max() + 1``). Any other tensor is
                propagated as-is and may have any number of trailing dimensions,
                including none (a 1-D per-node signal).
            mask: Optional indexer (e.g. a boolean mask or an index tensor)
                selecting the nodes whose labels seed the propagation. Rows
                outside it start at zero, and the ``(1 - alpha) * y0`` reset
                term uses this masked signal.
            post_step: Callable applied to the result of every step. The default
                clamps values into ``[0, 1]`` in place.

        Returns:
            The propagated tensor. With the default ``post_step`` it has the
            same shape as ``labels`` (after one-hot encoding, if applied).
        """
        with g.local_scope():
            if labels.dtype == torch.long:
                labels = F.one_hot(labels.view(-1)).to(torch.float32)

            y = labels
            if mask is not None:
                y = torch.zeros_like(labels)
                y[mask] = labels[mask]

            # Reset term mixed back in at every step; fixed for the whole run.
            last = (1 - self.alpha) * y
            degs = g.in_degrees().float().clamp(min=1)
            norm = torch.pow(degs, -0.5 if self.adj == "DAD" else -1).to(labels.device)
            # Shape the per-node factor as (N, 1, ..., 1) so it scales each node's
            # row for labels of any rank. A fixed ``unsqueeze(1)`` only lines up
            # with 2-D labels: a 1-D signal would silently broadcast to (N, N).
            norm = norm.view(-1, *([1] * (labels.dim() - 1)))

            for _ in range(self.num_layers):
                # Assume the graphs to be undirected: in-degrees are used as the
                # normalizer on both the source and destination side.
                if self.adj in ["DAD", "AD"]:
                    y = norm * y

                # Sum each node's incoming neighbor features.
                g.ndata["h"] = y
                g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
                y = self.alpha * g.ndata.pop("h")

                if self.adj in ["DAD", "DA"]:
                    y = y * norm

                y = post_step(last + y)

            return y