in seal_link_pred.py [0:0]
def _predict_all(models, loader):
    """Run every model in ``models`` over ``loader`` and collect outputs.

    Returns ``(preds, trues)``: two lists with one entry per model. Each entry
    is a 1-D CPU tensor of the concatenated, flattened logits (``preds``) or of
    the flattened labels cast to ``torch.float`` (``trues``), in loader order.

    Relies on module-level ``device``, ``args`` and ``emb`` (TODO confirm they
    are always defined before evaluation is called).
    """
    y_pred = [[] for _ in models]
    y_true = [[] for _ in models]
    # Pure inference: skip autograd bookkeeping so activations are not kept
    # alive across the whole loader and returned tensors do not require grad.
    with torch.no_grad():
        for data in tqdm(loader, ncols=70):
            data = data.to(device)
            x = data.x if args.use_feature else None
            edge_weight = data.edge_weight if args.use_edge_weight else None
            node_id = data.node_id if emb else None
            # Labels are identical for every model; compute them once per batch.
            labels = data.y.view(-1).cpu().to(torch.float)
            for i, m in enumerate(models):
                logits = m(data.z, data.edge_index, data.batch, x, edge_weight, node_id)
                y_pred[i].append(logits.view(-1).cpu())
                y_true[i].append(labels)
    preds = [torch.cat(chunks) for chunks in y_pred]
    trues = [torch.cat(chunks) for chunks in y_true]
    return preds, trues


def test_multiple_models(models):
    """Evaluate several models on the validation and test loaders.

    Each model is put in eval mode and run over ``val_loader`` and
    ``test_loader`` (module-level globals). Its scores are then passed to the
    evaluator selected by ``args.eval_metric``:

    * ``'hits'`` / ``'mrr'``: scores are split into positive (label == 1) and
      negative (label == 0) predictions for ``evaluate_hits`` / ``evaluate_mrr``.
    * ``'auc'``: the full score and label vectors go to ``evaluate_auc``.

    Returns a list with one evaluator result per model, in ``models`` order.
    For any other ``eval_metric`` value, no result is appended for that model.
    """
    for m in models:
        m.eval()

    val_pred, val_true = _predict_all(models, val_loader)
    test_pred, test_true = _predict_all(models, test_loader)

    results = []
    for vp, vt, tp, tt in zip(val_pred, val_true, test_pred, test_true):
        if args.eval_metric == 'hits':
            results.append(evaluate_hits(vp[vt == 1], vp[vt == 0],
                                         tp[tt == 1], tp[tt == 0]))
        elif args.eval_metric == 'mrr':
            results.append(evaluate_mrr(vp[vt == 1], vp[vt == 0],
                                        tp[tt == 1], tp[tt == 0]))
        elif args.eval_metric == 'auc':
            # Fix: the original passed test_pred twice, scoring predictions
            # against themselves instead of against the test labels.
            results.append(evaluate_auc(vp, vt, tp, tt))
    return results