in hugegraph-ml/src/hugegraph_ml/models/seal.py [0:0]
def __call__(self, split_type):
    """Assemble labelled positive and negative edges for one data split.

    For the ``"train"`` split, negative edges are produced by calling
    ``self.neg_sampler`` on a self-looped copy of ``self.g``, and both edge
    sets are passed through ``self.subsample`` with ``self.subsample_ratio``.
    For any other split, negatives are read from
    ``self.split_edge[split_type]["edge_neg"]`` and the ratio handed to
    ``self.subsample`` is ``1``.

    Args:
        split_type: Key into ``self.split_edge``; the value ``"train"``
            selects the negative-sampling path.

    Returns:
        A tuple ``(edges, labels)``. ``edges`` is the subsampled positive
        edges followed by the subsampled negative edges, both cast with
        ``.long()``. ``labels`` is a one-column tensor holding ones for the
        positive rows and zeros for the negative rows. When ``self.shuffle``
        is truthy, both tensors are reordered by the same random permutation.
    """
    is_train = split_type == "train"
    ratio = self.subsample_ratio if is_train else 1

    positives = self.split_edge[split_type]["edge"]
    if is_train:
        # Adding self loop in train avoids sampling the source node itself.
        looped_graph = add_self_loop(self.g)
        src, dst = positives[:, 0], positives[:, 1]
        positive_ids = looped_graph.edge_ids(src, dst)
        negatives = torch.stack(self.neg_sampler(looped_graph, positive_ids), dim=1)
    else:
        negatives = self.split_edge[split_type]["edge_neg"]

    # Positives are subsampled before negatives so random draws happen in the
    # same order on every call.
    positives = self.subsample(positives, ratio).long()
    negatives = self.subsample(negatives, ratio).long()

    edges = torch.cat([positives, negatives])
    positive_labels = torch.ones(positives.size(0), 1)
    negative_labels = torch.zeros(negatives.size(0), 1)
    labels = torch.cat([positive_labels, negative_labels])

    if not self.shuffle:
        return edges, labels

    # One shared permutation keeps every edge aligned with its label.
    order = torch.randperm(edges.size(0))
    return edges[order], labels[order]