def evaluate_agent()

in mm_action_prediction/eval_simmc_agent.py


# Module-level imports this function relies on; `progressbar` is assumed to be
# an alias for tqdm, as used elsewhere in this codebase.
import json
import math

import torch
from tqdm import tqdm as progressbar


def evaluate_agent(wizard, val_loader, args):
    """Evaluate a SIMMC agent given a dataloader.

    Args:
        wizard: SIMMC model
        dataloader: Dataloader to use to run the model on
        args: Arguments for evaluation
    """
    total_iters = int(val_loader.num_instances / args["batch_size"])
    # Turn autograd off for evaluation -- light-weight and faster.
    with torch.no_grad():
        wizard.eval()
        matches = []
        for batch in progressbar(val_loader.get_batch(), total=total_iters):
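            # Request response decoding (beam size 5) only when BLEU is
            # evaluated; otherwise no decoding mode is passed to the model.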
            if args["bleu_evaluation"]:
                mode = {"next_token": "ARGMAX", "beam_size": 5}
            else:
                mode = None
            batch_outputs = wizard(batch, mode)
            # Stringify model responses.
            if args["bleu_evaluation"]:
                batch_outputs["model_response"] = (
                    val_loader.stringify_beam_outputs(
                        batch_outputs["beam_output"], batch
                    )
                )
                # Remove beam output to avoid memory issues.
                del batch_outputs["beam_output"]
            matches.append(batch_outputs)
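    # Switch back to training mode once evaluation is complete.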
    wizard.train()

    # Compute the average per-token loss; perplexity is exp(avg_loss_eval).
    total_loss_sum = sum(ii["loss_sum"].item() for ii in matches)
    num_tokens = sum(ii["num_tokens"].item() for ii in matches)
    avg_loss_eval = total_loss_sum / num_tokens

    # Compute BLEU score (left at -1 when BLEU evaluation is disabled or when
    # raw responses are saved instead of being scored).
    model_responses = None
    bleu_score = -1.
    if args["bleu_evaluation"]:
        model_responses = [jj for ii in matches for jj in ii["model_response"]]
        # Save the JSON file.
        if args.get("save_model_output", False):
            save_path = args["checkpoint"].replace(".tar", "_response_gen.json")
            with open(save_path, "w") as file_id:
                json.dump(model_responses, file_id)
        else:
            bleu_score = val_loader.evaluate_response_generation(model_responses)

    # Evaluate retrieval score.
    retrieval_metrics = {}
    if args["retrieval_evaluation"]:
        candidate_scores = [jj for ii in matches for jj in ii["candidate_scores"]]
        # Save the JSON file.
        if args.get("save_model_output", False):
            save_path = args["checkpoint"].replace(".tar", "_response_ret.json")
            with open(save_path, "w") as file_id:
                json.dump(candidate_scores, file_id)
        else:
            retrieval_metrics = val_loader.evaluate_response_retrieval(
                candidate_scores
            )
            print(retrieval_metrics)

    # Evaluate action prediction.
    action_predictions = [jj for ii in matches for jj in ii["action_preds"]]
    # Save the JSON file.
    if args.get("save_model_output", False):
        save_path = args["checkpoint"].replace(".tar", "_action_gen.json")
        with open(save_path, "w") as file_id:
            json.dump(action_predictions, file_id)
    action_metrics = val_loader.evaluate_action_prediction(action_predictions)
    print(action_metrics["confusion_matrix"])
    print_str = (
        "\nEvaluation\n\tLoss: {:.2f}\n\t"
        "Perplexity: {:.2f}\n\tBLEU: {:.3f}\n\t"
        "Action: {:.2f}\n\t"
        "Action Perplexity: {:.2f}\n\t"
        "Action Attribute Accuracy: {:.2f}"
    )
    print(
        print_str.format(
            avg_loss_eval,
            math.exp(avg_loss_eval),
            bleu_score,
            100 * action_metrics["action_accuracy"],
            action_metrics["action_perplexity"],
            100 * action_metrics["attribute_accuracy"]
        )
    )
    # Collect the metrics into a dictionary to return to the caller.
    eval_dict = {
        "loss": avg_loss_eval,
        "perplexity": math.exp(avg_loss_eval),
        "bleu": bleu_score,
        "action_accuracy": action_metrics["action_accuracy"],
        "action_perplexity": action_metrics["action_perplexity"],
        "action_attribute": action_metrics["attribute_accuracy"]
    }
    eval_dict.update(retrieval_metrics)
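    # Raw model outputs (action predictions and generated responses) to return.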
    eval_outputs = {
        "model_actions": action_predictions,
        "model_responses": model_responses
    }
    return eval_dict, eval_outputs
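
A minimal usage sketch, assuming `wizard` (the SIMMC model) and `val_loader` have already been built and a checkpoint restored as in the surrounding training/evaluation scripts; only the argument keys that evaluate_agent actually reads are set, and the checkpoint path is hypothetical.

# Hypothetical invocation; construction of `wizard` and `val_loader` is assumed.
eval_args = {
    "batch_size": 32,
    "bleu_evaluation": True,
    "retrieval_evaluation": True,
    "save_model_output": False,
    "checkpoint": "checkpoints/simmc_model.tar",  # hypothetical path
}
eval_dict, eval_outputs = evaluate_agent(wizard, val_loader, eval_args)
print("Perplexity: {:.2f}  BLEU: {:.3f}".format(
    eval_dict["perplexity"], eval_dict["bleu"]
))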