def get_loss()

in hugegraph-ml/src/hugegraph_ml/models/pgnn.py [0:0]


def get_loss(p, data, out, loss_func, device, get_auc=True):
    """Compute the link-prediction loss (and optionally ROC-AUC) for one edge split.

    Each candidate edge is scored by the dot product of its two endpoint
    embeddings. Positive edges are labelled 1 and negative edges 0, and the
    raw scores are compared against those labels with ``loss_func``.

    Args:
        p: Split suffix used to look up ``data[f"positive_edges_{p}"]`` and
            ``data[f"negative_edges_{p}"]`` (presumably "train"/"val"/"test";
            confirm against callers).
        data: Mapping holding the edge arrays. Each array is indexed as
            ``[0, :]`` / ``[1, :]`` and concatenated along its last axis, so
            it is expected to be a 2 x num_edges NumPy array of node ids.
        out: Node embedding tensor; row ``i`` is used as the embedding of
            node ``i``.
        loss_func: Callable invoked as ``loss_func(pred, label)``. It receives
            raw dot-product scores (no sigmoid), so presumably a logits-based
            loss such as ``BCEWithLogitsLoss`` — confirm against callers.
        device: Retained for backward compatibility. Labels are created on
            the same device as ``out`` so they always match ``pred``.
        get_auc: If True, also compute ROC-AUC on ``sigmoid(pred)``.

    Returns:
        ``(loss, auc)`` when ``get_auc`` is True, otherwise ``loss``.

    Raises:
        ValueError: Propagated from ``roc_auc_score`` when ``get_auc`` is True
            and the split contains only one class (e.g. no negative edges).
    """
    pos_edges = data[f"positive_edges_{p}"]
    neg_edges = data[f"negative_edges_{p}"]
    num_pos = pos_edges.shape[1]
    num_neg = neg_edges.shape[1]

    # Positive edges come first, negatives second; the label order below
    # relies on this.
    edge_mask = np.concatenate((pos_edges, neg_edges), axis=-1)
    # Convert once and place the indices on the embeddings' device.
    edge_index = torch.as_tensor(edge_mask, dtype=torch.long, device=out.device)

    nodes_first = torch.index_select(out, 0, edge_index[0])
    nodes_second = torch.index_select(out, 0, edge_index[1])
    # One raw score (logit) per edge.
    pred = torch.sum(nodes_first * nodes_second, dim=-1)

    # Build labels directly on pred's device/dtype so loss_func never sees a
    # device mismatch, even if ``device`` differs from ``out.device``.
    label = torch.cat(
        (
            torch.ones(num_pos, dtype=pred.dtype, device=pred.device),
            torch.zeros(num_neg, dtype=pred.dtype, device=pred.device),
        )
    )
    loss = loss_func(pred, label)

    if not get_auc:
        return loss

    auc = roc_auc_score(
        label.flatten().cpu().numpy(),
        torch.sigmoid(pred).detach().flatten().cpu().numpy(),
    )
    return loss, auc