in mm_action_prediction/eval_simmc_agent.py [0:0]
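# Imports assumed by this excerpt (the full module defines them at top level);
# `progressbar` is taken here to be tqdm's progress bar under an alias.
import json
import math

import torch
from tqdm import tqdm as progressbar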
def evaluate_agent(wizard, val_loader, args):
    """Evaluate a SIMMC agent given a dataloader.

    Args:
        wizard: SIMMC model to evaluate
        val_loader: Dataloader to run the model on
        args: Arguments for evaluation

    Returns:
        eval_dict: Dictionary of evaluation metrics
        eval_outputs: Model outputs (action predictions and generated responses)
    """
    total_iters = int(val_loader.num_instances / args["batch_size"])
    # Turn autograd off for evaluation -- light-weight and faster.
    with torch.no_grad():
        wizard.eval()
        matches = []
        for batch in progressbar(val_loader.get_batch(), total=total_iters):
            if args["bleu_evaluation"]:
                mode = {"next_token": "ARGMAX", "beam_size": 5}
            else:
                mode = None
            batch_outputs = wizard(batch, mode)
            # Stringify model responses.
            if args["bleu_evaluation"]:
                batch_outputs["model_response"] = (
                    val_loader.stringify_beam_outputs(
                        batch_outputs["beam_output"], batch
                    )
                )
                # Remove beam output to avoid memory issues.
                del batch_outputs["beam_output"]
            matches.append(batch_outputs)
    wizard.train()
    # Average per-token loss over the split; perplexity is exp() of this value.
    total_loss_sum = sum(ii["loss_sum"].item() for ii in matches)
    num_tokens = sum(ii["num_tokens"].item() for ii in matches)
    avg_loss_eval = total_loss_sum / num_tokens
    # Compute BLEU score.
    model_responses = None
    bleu_score = -1.
    if args["bleu_evaluation"]:
        model_responses = [jj for ii in matches for jj in ii["model_response"]]
        # Save the JSON file.
        if args.get("save_model_output", False):
            save_path = args["checkpoint"].replace(".tar", "_response_gen.json")
            with open(save_path, "w") as file_id:
                json.dump(model_responses, file_id)
        else:
            bleu_score = val_loader.evaluate_response_generation(model_responses)
    # Evaluate retrieval score.
    retrieval_metrics = {}
    if args["retrieval_evaluation"]:
        candidate_scores = [jj for ii in matches for jj in ii["candidate_scores"]]
        # Save the JSON file.
        if args.get("save_model_output", False):
            save_path = args["checkpoint"].replace(".tar", "_response_ret.json")
            with open(save_path, "w") as file_id:
                json.dump(candidate_scores, file_id)
        else:
            retrieval_metrics = val_loader.evaluate_response_retrieval(
                candidate_scores
            )
            print(retrieval_metrics)
    # Evaluate action prediction.
    action_predictions = [jj for ii in matches for jj in ii["action_preds"]]
    # Save the JSON file.
    if args.get("save_model_output", False):
        save_path = args["checkpoint"].replace(".tar", "_action_gen.json")
        with open(save_path, "w") as file_id:
            json.dump(action_predictions, file_id)
    action_metrics = val_loader.evaluate_action_prediction(action_predictions)
    print(action_metrics["confusion_matrix"])
    print_str = (
        "\nEvaluation\n\tLoss: {:.2f}\n\t"
        "Perplexity: {:.2f}\n\tBLEU: {:.3f}\n\t"
        "Action: {:.2f}\n\t"
        "Action Perplexity: {:.2f}\n\t"
        "Action Attribute Accuracy: {:.2f}"
    )
    print(
        print_str.format(
            avg_loss_eval,
            math.exp(avg_loss_eval),
            bleu_score,
            100 * action_metrics["action_accuracy"],
            action_metrics["action_perplexity"],
            100 * action_metrics["attribute_accuracy"],
        )
    )
    # Collect metrics and model outputs to return to the caller.
    eval_dict = {
        "loss": avg_loss_eval,
        "perplexity": math.exp(avg_loss_eval),
        "bleu": bleu_score,
        "action_accuracy": action_metrics["action_accuracy"],
        "action_perplexity": action_metrics["action_perplexity"],
        "action_attribute": action_metrics["attribute_accuracy"],
    }
    eval_dict.update(retrieval_metrics)
    eval_outputs = {
        "model_actions": action_predictions,
        "model_responses": model_responses,
    }
    return eval_dict, eval_outputs
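

# Minimal usage sketch (not part of the original file): the flag names mirror the
# keys read by evaluate_agent above; model and dataloader construction is
# project-specific and only indicated with placeholders here.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Evaluate a SIMMC agent")
    parser.add_argument("--checkpoint", required=True, help="Path to .tar checkpoint")
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--bleu_evaluation", action="store_true")
    parser.add_argument("--retrieval_evaluation", action="store_true")
    parser.add_argument("--save_model_output", action="store_true")
    eval_args = vars(parser.parse_args())

    # Placeholders: restore the trained model and build the validation loader
    # from the checkpoint / dataset paths before calling evaluate_agent.
    wizard = ...      # SIMMC model restored from eval_args["checkpoint"]
    val_loader = ...  # dataloader exposing get_batch(), num_instances, etc.
    eval_dict, eval_outputs = evaluate_agent(wizard, val_loader, eval_args)
    print(eval_dict)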