in hugegraph-ml/src/hugegraph_ml/models/pgnn.py [0:0]
def get_loss(p, data, out, loss_func, device, get_auc=True):
    """Score the positive and negative edges of split ``p`` and return the loss.

    Each edge ``(u, v)`` is scored as the dot product of rows ``u`` and ``v``
    of ``out``. Positive edges are labelled 1 and negative edges 0, and
    ``loss_func(pred, label)`` is evaluated on those scores.

    Args:
        p: Suffix selecting ``data[f"positive_edges_{p}"]`` and
            ``data[f"negative_edges_{p}"]`` (presumably a split name such as
            ``"train"`` -- confirm against callers).
        data: Mapping holding the edge arrays as NumPy integer arrays of shape
            ``(2, num_edges)``: row 0 holds the first endpoint of each edge,
            row 1 the second.
        out: Node representation tensor; row ``i`` belongs to node ``i``.
        loss_func: Callable invoked as ``loss_func(pred, label)`` with the raw
            dot-product scores and float 0/1 labels (presumably a logits-based
            loss such as ``torch.nn.BCEWithLogitsLoss`` -- confirm).
        device: Device on which the label tensor is created; it should match
            ``out.device`` so that ``loss_func`` receives tensors on one device.
        get_auc: If True, also compute the ROC AUC of the edge scores.

    Returns:
        ``loss`` when ``get_auc`` is False, otherwise the tuple ``(loss, auc)``.
    """
    pos_edges = data[f"positive_edges_{p}"]
    neg_edges = data[f"negative_edges_{p}"]

    # All edges (positives first, then negatives) as one (2, E) index tensor
    # living next to the node representations.
    edge_index = (
        torch.from_numpy(np.concatenate((pos_edges, neg_edges), axis=-1))
        .long()
        .to(out.device)
    )
    nodes_first = torch.index_select(out, 0, edge_index[0])
    nodes_second = torch.index_select(out, 0, edge_index[1])
    pred = torch.sum(nodes_first * nodes_second, dim=-1)

    # Labels follow the same positives-then-negatives order as ``pred`` and
    # are allocated directly on ``device`` (no CPU allocation + copy).
    label = torch.cat(
        (
            torch.ones(pos_edges.shape[1], dtype=pred.dtype, device=device),
            torch.zeros(neg_edges.shape[1], dtype=pred.dtype, device=device),
        )
    )
    loss = loss_func(pred, label)
    if not get_auc:
        return loss

    # ROC AUC depends only on the ranking of the scores, so the raw scores are
    # used directly. Passing them through sigmoid first saturates large
    # scores to exactly 1.0 in float32, creating artificial ties that distort
    # the AUC.
    auc = roc_auc_score(
        label.flatten().cpu().numpy(),
        pred.detach().flatten().cpu().numpy(),
    )
    return loss, auc