def __call__()

in hugegraph-ml/src/hugegraph_ml/models/seal.py [0:0]


    def __call__(self, split_type):
        """Assemble the labelled positive/negative edge set for one split.

        Args:
            split_type: Key into ``self.split_edge``. For ``"train"`` the
                negative edges are drawn with ``self.neg_sampler`` and both
                edge sets are subsampled with ``self.subsample_ratio``; for
                any other value the stored ``"edge_neg"`` entry is used and
                the subsample ratio is fixed to 1.

        Returns:
            A ``(edges, labels)`` tuple. ``edges`` is the positive edges
            followed by the negative edges, each passed through
            ``self.subsample`` and cast with ``.long()``. ``labels`` has one
            row per edge: ones for positives, zeros for negatives. When
            ``self.shuffle`` is truthy both are reordered by the same random
            permutation.
        """
        training = split_type == "train"
        ratio = self.subsample_ratio if training else 1

        # Positive edges are indexed by column below, so they are expected
        # to carry the two endpoints in columns 0 and 1.
        positives = self.split_edge[split_type]["edge"]
        if training:
            # Adding self loop in train avoids sampling the source node itself.
            graph = add_self_loop(self.g)
            edge_ids = graph.edge_ids(positives[:, 0], positives[:, 1])
            # NOTE(review): presumably the sampler returns a (src, dst) pair
            # of 1-D tensors, stacked here into two columns — confirm.
            sampled = self.neg_sampler(graph, edge_ids)
            negatives = torch.stack(sampled, dim=1)
        else:
            negatives = self.split_edge[split_type]["edge_neg"]

        positives = self.subsample(positives, ratio).long()
        negatives = self.subsample(negatives, ratio).long()

        pos_labels = torch.ones(positives.size(0), 1)
        neg_labels = torch.zeros(negatives.size(0), 1)
        edges = torch.cat([positives, negatives])
        labels = torch.cat([pos_labels, neg_labels])

        if not self.shuffle:
            return edges, labels

        # One shared permutation keeps every edge paired with its label.
        order = torch.randperm(edges.size(0))
        return edges[order], labels[order]