in source/sagemaker/sagemaker_graph_entity_resolution/dgl_entity_resolution/train_dgl_pytorch_entity_resolution.py [0:0]
def evaluate(model, g, train_triplets, test_triplets, user_features, web_features, batch_size, n_neighbors,
             hits=None, device=None, filtered=False, mean_ap=False, eval_bz=10000):
    """Compute link-prediction metrics for ``model`` on ``test_triplets``.

    Node embeddings are obtained from ``model.inference`` and scored using a
    detached CPU copy of ``model.w_relation``. Exactly one metric is computed:
    mAP if ``mean_ap`` is True, otherwise filtered MRR if ``filtered`` is True,
    otherwise raw MRR.

    Args:
        model: Model exposing ``inference(g, user_features, web_features,
            batch_size, n_neighbors, device)`` and a ``w_relation`` tensor.
        g: Graph forwarded to ``model.inference``.
        train_triplets: Training triplets. Passed to the mAP and filtered-MRR
            helpers (presumably to filter known positives -- confirm in those
            helpers). Not used by the raw-MRR path.
        test_triplets: Triplets to evaluate.
        user_features: Features forwarded to ``model.inference``.
        web_features: Features forwarded to ``model.inference``.
        batch_size: Batch size forwarded to ``model.inference``.
        n_neighbors: Neighbor count forwarded to ``model.inference``.
        hits: Hits@k cut-offs forwarded to the MRR helpers. Defaults to
            ``[1, 3, 10]``. Ignored when ``mean_ap`` is True.
        device: Device forwarded to ``model.inference``.
        filtered: If True (and ``mean_ap`` is False), compute filtered MRR
            instead of raw MRR.
        mean_ap: If True, compute mAP. Takes precedence over ``filtered``.
        eval_bz: Batch size used only by the raw-MRR computation.

    Returns:
        The value returned by whichever helper was selected (``calc_mAP``,
        ``calc_filtered_mrr`` or ``calc_raw_mrr``).
    """
    # Build a fresh list per call rather than sharing one default list object
    # across calls, in case a downstream helper mutates it.
    if hits is None:
        hits = [1, 3, 10]

    logging.info("Performing model inference to get embeddings")
    embed = model.inference(g, user_features, web_features, batch_size, n_neighbors, device)
    logging.info("Got embeddings, computing metrics")

    # Detach from the autograd graph and copy, so metric computation can
    # neither track gradients nor alias the live parameter storage.
    w = model.w_relation.detach().clone().cpu()

    if mean_ap:
        return calc_mAP(embed, w, train_triplets, test_triplets)
    if filtered:
        return calc_filtered_mrr(embed, w, train_triplets, test_triplets, hits=hits)
    return calc_raw_mrr(embed, w, test_triplets, hits=hits, eval_bz=eval_bz)