mico/evaluate/evaluate.py
import torch
from tqdm import tqdm
import numpy as np
import json
import logging
import os
from timeit import default_timer as timer
from datetime import timedelta
from mico.dataloader import QueryDocumentsPair
from mico.utils import monitor_metrics


def evaluate(model, eval_loader, num_batches=-1, device=None):
    """Evaluate the model during training.

    See the usage sketch in the comment block after this function.

    Parameters
    ----------
    model : MutualInfoCotrain object
        The MICO model being trained.
    eval_loader : PyTorch DataLoader object
        The dataloader used for evaluation; in practice this is val_dataloader or test_dataloader.
    num_batches : int
        If set to a non-negative number, evaluation only uses the first n batches
        of the dataloader, where n = num_batches.
    device : int (for multi-GPU) or string ('cpu' or 'cuda')
        The GPU index when using DistributedDataParallel, in which case the index
        of the GPU card is also the index of the sub-process.

    Returns
    -------
    perf : dictionary
        The performance metrics:
        {'AUC': area under the curve (search cost vs. top-1 coverage), in percent;
         'top1_cov': top-1 coverage in percent, i.e. the proportion of query-document pairs
            for which the top-1 cluster for query routing and the top-1 cluster for document
            assignment are the same cluster;
         'H__Z_X': entropy of the distribution of queries over their top-1 clusters on this evaluation data;
         'H__Z_Y': entropy of the distribution of documents over their top-1 clusters on this evaluation data;
        }
    """
    model.eval()
    with torch.no_grad():
        encoding_chunks_query = []
        encoding_chunks_document = []
        for batch_idx, (query, document) in enumerate(tqdm(eval_loader)):
            if batch_idx == num_batches:
                break
            query_prob = model.forward(query=query, forward_method="encode_query", device=device)
            document_prob = model.forward(document=document, forward_method="encode_doc", device=device)
            encoding_chunks_query.append(query_prob)
            encoding_chunks_document.append(document_prob)
        encodings_query = torch.cat(encoding_chunks_query, 0).float()
        encodings_document = torch.cat(encoding_chunks_document, 0).float()
        perf = monitor_metrics(encodings_query, encodings_document)
    model.train()
    return perf
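

# Usage sketch (illustrative only): how `evaluate` is typically called from a training loop.
# The names `model`, `val_loader`, and `rank` are assumptions about the caller and are not
# defined in this module.
#
#     perf = evaluate(model, val_loader, num_batches=100, device=rank)
#     logging.info("val AUC = %.2f, top-1 coverage = %.2f" % (perf['AUC'], perf['top1_cov']))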


def infer_on_test(model, device=None):
    """Assign the documents from all three datasets (train/val/test) to clusters, saved in
    `clustered_docs.json`, and compute the performance metrics on the test data, saved in
    `metrics.json`. The metrics are the coverage for impressions, clicks, and purchases;
    the latency cost and total search cost; and the cluster sizes.

    See the usage sketch in the comment block after this function.

    Parameters
    ----------
    model : MutualInfoCotrain object
        The MICO model we have loaded. The paths for the train/val/test datasets are in `model.hparams`.
    device : string ('cpu' or 'cuda')
        The device to run the model on.

    Returns
    -------
    metrics : dictionary
        Contains the coverage for impressions, clicks, and purchases, the latency cost and
        total search cost on the test dataset, and the cluster sizes.
    """
    start = timer()
    hparams = model.hparams
    data = QueryDocumentsPair(train_folder_path=hparams.train_folder_path, test_folder_path=hparams.test_folder_path,
                              is_csv_header=hparams.is_csv_header, val_ratio=hparams.val_ratio, is_get_all_info=True)
    train_loader, val_loader, test_loader = data.get_loaders(model.hparams.batch_size_test, model.hparams.num_workers,
                                                             is_shuffle_train=False, is_get_test=True)
    model.eval()
    cluster_docs_path = hparams.model_path + '/clustered_docs.json'
    metrics_save_path = hparams.model_path + '/metrics.json'
    cluster_docs = {}
    n_clusters = hparams.number_clusters
    for i in range(n_clusters):
        cluster_docs[i] = {}
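    # cluster_docs maps cluster index -> {ASIN: 1}; the inner dict acts as a set, so the
    # membership checks in the coverage computation below are O(1).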
    with torch.no_grad():
        if os.path.isfile(cluster_docs_path):
            # The documents have already been clustered for this model.
            logging.info("Loading the existing document clustering results.")
            logging.info("If you want to re-cluster the documents, please delete %s" % cluster_docs_path)
            with open(cluster_docs_path, "r+") as f:
                cluster_docs = json.load(f)
            # JSON stores the cluster indices as strings; convert the keys back to int.
            cluster_docs = {int(i): docs for i, docs in cluster_docs.items()}
        else:
            # Cluster all the documents: assign every document in train/val/test to its top-1 cluster.
            for loader in (train_loader, val_loader, test_loader):
                for (query, asin, document, click, purchase) in tqdm(loader):
                    document_prob = model.forward(document=document, forward_method="encode_doc", device=device)
                    for idx2, class_idx in enumerate(torch.argmax(document_prob, dim=1).cpu().numpy()):
                        cluster_docs[class_idx][asin[idx2]] = 1
            logging.info('Save the clusters of all the documents ...')
            with open(cluster_docs_path, 'w+') as f:
                json.dump(cluster_docs, f, indent='\t')
        logging.info('Calculate the metrics (coverage for impressions, clicks, and purchases; latency cost and total search cost) on the test dataset ...')
        cluster_sizes = [len(cluster_docs[i]) for i in range(n_clusters)]
        impression_coverage = np.zeros(n_clusters)
        click_coverage = np.zeros(n_clusters)
        purchase_coverage = np.zeros(n_clusters)
        latency_cost_curve = np.zeros(n_clusters)
        total_cost_curve = np.zeros(n_clusters)
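        # For each test query, walk its clusters from most to least likely. The coverage
        # curves record, at each search depth, whether the relevant ASIN (and its click /
        # purchase signal) has been found within the clusters visited so far; the total
        # cost accumulates the sizes of the clusters visited, while the latency cost keeps
        # the size of the largest single cluster visited (a running max, presumably
        # modelling clusters probed in parallel).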
        for (query_batch, asin_batch, document_batch, click_batch, purchase_batch) in tqdm(test_loader):
            query_batch_prob = model.forward(query=query_batch, forward_method="encode_query", device=device)
            query_batch_prob = -query_batch_prob.detach().cpu().numpy()
            # argsort of the negated scores ranks the clusters from most to least likely per query.
            query_batch_cluster_rank = query_batch_prob.argsort(axis=1)
            for query_idx in range(len(query_batch)):  # iterate over the queries in this batch
                impression_find = np.zeros(n_clusters)
                click_find = np.zeros(n_clusters)
                purchase_find = np.zeros(n_clusters)
                total_time_find = np.zeros(n_clusters)
                latency_find = np.zeros(n_clusters)
                query_one = query_batch[query_idx]
                asin_one = asin_batch[query_idx]
                click_one = click_batch[query_idx]
                purchase_one = purchase_batch[query_idx]
                found_asin = False
                for rank_idx, cluster_idx in enumerate(query_batch_cluster_rank[query_idx]):
                    total_time_find[rank_idx] = cluster_sizes[cluster_idx]
                    latency_find[rank_idx] = max(latency_find[rank_idx - 1], cluster_sizes[cluster_idx]) \
                        if rank_idx > 0 else cluster_sizes[cluster_idx]
                    if asin_one in cluster_docs[cluster_idx]:
                        impression_find[rank_idx] = 1
                        click_find[rank_idx] = click_one
                        purchase_find[rank_idx] = purchase_one
                        found_asin = True
                if not found_asin:
                    logging.error('ASIN %s not found in any cluster when considering query %s' % (asin_one, query_one))
                    raise ValueError('ASIN %s is missing from all clusters' % asin_one)
                impression_coverage += np.cumsum(impression_find)
                click_coverage += np.cumsum(click_find)
                purchase_coverage += np.cumsum(purchase_find)
                total_cost_curve += np.cumsum(total_time_find)
                latency_cost_curve += latency_find
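        # Normalize: each coverage curve becomes a fraction of its own total (its last entry),
        # and the cost curves become per-query averages over the test set.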
        impression_coverage /= impression_coverage[-1]
        click_coverage /= click_coverage[-1]
        purchase_coverage /= purchase_coverage[-1]
        total_cost_curve = total_cost_curve / len(data.test_dataset)
        latency_cost_curve = latency_cost_curve / len(data.test_dataset)
        metrics = {'impression_coverage': impression_coverage.tolist(),
                   'click_coverage': click_coverage.tolist(),
                   'purchase_coverage': purchase_coverage.tolist(),
                   'total_cost_curve': total_cost_curve.tolist(),
                   'latency_cost_curve': latency_cost_curve.tolist(),
                   'cluster_sizes': sorted(cluster_sizes),
                   }
        logging.info('Save the metrics ...')
        with open(metrics_save_path, 'w+') as f:
            json.dump(metrics, f, indent='\t')
    logging.info("Total testing time: %s" % str(timedelta(seconds=round(timer() - start))))
    return metrics
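

# Usage sketch (illustrative only): offline inference on the test split once training has
# finished. It assumes a trained MutualInfoCotrain model has already been restored into
# `model` by code elsewhere in the repo; the variable names here are not defined in this file.
#
#     metrics = infer_on_test(model, device='cuda' if torch.cuda.is_available() else 'cpu')
#     logging.info("largest cluster size: %d" % max(metrics['cluster_sizes']))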