in seal_link_pred.py [0:0]
def _collect_predictions(loader):
    """Run the global ``model`` over every batch of *loader*; return flat (pred, true) tensors.

    Relies on module-level ``model``, ``device``, ``args`` and ``emb``.
    Both returned tensors are 1-D and live on the CPU; labels are cast to float.
    """
    y_pred, y_true = [], []
    for data in tqdm(loader, ncols=70):
        data = data.to(device)
        # Optional model inputs are only passed when the corresponding feature is enabled.
        x = data.x if args.use_feature else None
        edge_weight = data.edge_weight if args.use_edge_weight else None
        node_id = data.node_id if emb else None
        logits = model(data.z, data.edge_index, data.batch, x, edge_weight, node_id)
        y_pred.append(logits.view(-1).cpu())
        y_true.append(data.y.view(-1).cpu().to(torch.float))
    return torch.cat(y_pred), torch.cat(y_true)


def test():
    """Evaluate the global ``model`` on the validation and test loaders.

    Collects flattened predictions and labels for ``val_loader`` and
    ``test_loader``. It then scores them with the helper selected by
    ``args.eval_metric``.

    Returns:
        Whatever the selected ``evaluate_hits`` / ``evaluate_mrr`` /
        ``evaluate_auc`` helper returns (its structure is defined elsewhere).

    Raises:
        ValueError: if ``args.eval_metric`` is not 'hits', 'mrr' or 'auc'
            (previously this surfaced as an UnboundLocalError).
    """
    model.eval()
    # Evaluation never backpropagates, so skip building autograd graphs to
    # save memory and time. Harmless if an outer no_grad is already active.
    with torch.no_grad():
        val_pred, val_true = _collect_predictions(val_loader)
        test_pred, test_true = _collect_predictions(test_loader)

    metric = args.eval_metric
    if metric in ('hits', 'mrr'):
        # evaluate_hits / evaluate_mrr take positive and negative scores as
        # separate tensors, split here by the (float) 1/0 labels.
        pos_val_pred = val_pred[val_true == 1]
        neg_val_pred = val_pred[val_true == 0]
        pos_test_pred = test_pred[test_true == 1]
        neg_test_pred = test_pred[test_true == 0]
        evaluate = evaluate_hits if metric == 'hits' else evaluate_mrr
        return evaluate(pos_val_pred, neg_val_pred, pos_test_pred, neg_test_pred)
    if metric == 'auc':
        # evaluate_auc takes the raw predictions together with their labels.
        return evaluate_auc(val_pred, val_true, test_pred, test_true)
    raise ValueError(f"Unsupported eval_metric: {metric!r}")