def evaluate()

in source/sagemaker/sagemaker_graph_entity_resolution/dgl_entity_resolution/train_dgl_pytorch_entity_resolution.py [0:0]


def evaluate(model, g, train_triplets, test_triplets, user_features, web_features, batch_size, n_neighbors,
             hits=None, device=None, filtered=False, mean_ap=False, eval_bz=10000):
    """Compute link-prediction metrics for ``model`` on ``test_triplets``.

    Node embeddings are produced with ``model.inference`` on graph ``g``.
    They are then scored together with a detached CPU copy of
    ``model.w_relation`` by exactly one of three metric helpers:

    * ``mean_ap=True``  -> ``calc_mAP`` (``filtered``, ``hits`` and ``eval_bz``
      are ignored)
    * ``filtered=True`` -> ``calc_filtered_mrr`` (also receives
      ``train_triplets``; presumably to filter known-true triplets -- confirm
      against the helper)
    * otherwise         -> ``calc_raw_mrr`` (batched with ``eval_bz``)

    Args:
        model: Trained model exposing ``inference(...)`` and a ``w_relation``
            tensor.
        g: Graph forwarded to ``model.inference``.
        train_triplets: Training triplets. Only used by the mAP and
            filtered-MRR paths.
        test_triplets: Triplets to evaluate.
        user_features: Node features forwarded to ``model.inference``.
        web_features: Node features forwarded to ``model.inference``.
        batch_size: Inference batch size forwarded to ``model.inference``.
        n_neighbors: Neighbor count forwarded to ``model.inference``.
        hits: Hits@k cut-offs for the MRR paths. Defaults to ``[1, 3, 10]``.
        device: Device forwarded to ``model.inference``.
        filtered: Use the filtered MRR variant. Ignored when ``mean_ap`` is
            True.
        mean_ap: Compute mean average precision instead of MRR.
        eval_bz: Evaluation batch size for the raw-MRR path. The default of
            10000 matches the previously hard-coded value.

    Returns:
        Whatever the selected metric helper returns.
    """
    # Build a fresh list on every call instead of sharing one mutable default
    # object across calls. The list is handed to helpers that might mutate it.
    if hits is None:
        hits = [1, 3, 10]

    logging.info("Performing model inference to get embeddings")
    embed = model.inference(g, user_features, web_features, batch_size, n_neighbors, device)
    logging.info("Got embeddings, computing metrics")

    # detach(): stop gradient tracking.
    # clone(): own the memory. If the weights are already on CPU, .cpu() would
    # otherwise return the live parameter itself.
    w = model.w_relation.detach().clone().cpu()

    if mean_ap:
        return calc_mAP(embed, w, train_triplets, test_triplets)
    if filtered:
        return calc_filtered_mrr(embed, w, train_triplets, test_triplets, hits=hits)
    return calc_raw_mrr(embed, w, test_triplets, hits=hits, eval_bz=eval_bz)