def run()

in assets/training/model_evaluation/src/compute_metrics.py [0:0]


def run():
    """Entry point for compute metrics component."""
    parser = ArgumentParser()
    # Inputs
    parser.add_argument("--task", type=str, dest=ArgumentLiterals.TASK, choices=constants.ALL_TASKS)
    parser.add_argument("--ground_truths", type=str, dest=ArgumentLiterals.GROUND_TRUTHS, required=False)
    parser.add_argument("--ground_truths_column_name", type=lambda x: x.split(","),
                        dest=ArgumentLiterals.GROUND_TRUTHS_COLUMN_NAME, required=False, default=None)
    parser.add_argument("--predictions", type=str, dest=ArgumentLiterals.PREDICTIONS, required=True)
    parser.add_argument("--predictions_column_name", type=str,
                        dest=ArgumentLiterals.PREDICTIONS_COLUMN_NAME, required=False, default=None)
    parser.add_argument("--prediction_probabilities", type=str, dest=ArgumentLiterals.PREDICTION_PROBABILITIES,
                        required=False, default="")
    parser.add_argument("--config_str", type=str, dest=ArgumentLiterals.CONFIG_STR, required=False, default=None)
    parser.add_argument("--config-file-name", type=str, dest=ArgumentLiterals.CONFIG_FILE_NAME,
                        required=False, default="")
    parser.add_argument("--openai-config-params", type=str, dest=ArgumentLiterals.OPENAI_CONFIG_PARAMS,
                        required=False, default="{}")

    # Outputs
    parser.add_argument("--output", type=str, dest=ArgumentLiterals.OUTPUT)

    args, _ = parser.parse_known_args()
    args = vars(args)

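    # Validate the parsed arguments and resolve the effective compute-metrics
    # configuration before any data is read.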
    with log_activity(logger, constants.TelemetryConstants.VALIDATION_NAME,
                      custom_dimensions=custom_dims_dict):
        logger.info("Validating arguments: " + repr(args))
        validate_common_args(args)
        validate_compute_metrics_args(args)
        config = fetch_compute_metrics_args(args[ArgumentLiterals.CONFIG], args[ArgumentLiterals.TASK])

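        # The ground-truth column argument arrives as a comma-split list; separate
        # the label column name from any extra columns needed at test time.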
        ground_truths_column_name, extra_y_test_cols = parse_input_ground_truth_col(
            args[ArgumentLiterals.GROUND_TRUTHS_COLUMN_NAME]
        )

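        # Inputs may be passed as MLTable folders or plain URI data; record which
        # format each input uses so the runner can load it accordingly.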
        ground_truths = args[ArgumentLiterals.GROUND_TRUTHS]
        is_ground_truths_mltable = check_and_return_if_mltable(ground_truths)
        predictions = args[ArgumentLiterals.PREDICTIONS]
        is_predictions_mltable = check_and_return_if_mltable(predictions)
        prediction_probabilities = args[ArgumentLiterals.PREDICTION_PROBABILITIES]
        is_prediction_probabilities_mltable = check_and_return_if_mltable(prediction_probabilities)

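        # For forecasting, column names may be supplied through the config instead
        # of the CLI arguments; fall back to the config and fail if neither is set.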
        if args[ArgumentLiterals.TASK] == TASK.FORECASTING:
            if not ground_truths_column_name:
                # If the ground truth column name was not provided, try to take it from the config.
                ground_truths_column_name = config.get(ForecastingConfigContract.FORECAST_GROUND_TRUTH, '')
            if not args[ArgumentLiterals.PREDICTIONS_COLUMN_NAME]:
                args[ArgumentLiterals.PREDICTIONS_COLUMN_NAME] = config.get(
                    ForecastingConfigContract.FORECAST_PREDICTIONS, '')
            if not ground_truths_column_name or (
                    not is_ground_truths_mltable and not args[ArgumentLiterals.GROUND_TRUTHS]
            ):
                exception = get_azureml_exception(DataValidationException, BadForecastData, None)
                log_traceback(exception, logger)
                raise exception
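        # Optional OpenAI/LLM settings used for GPT-assisted metrics; if parsing
        # fails, those metrics are skipped rather than failing the component.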
        llm_config = {}
        try:
            llm_config = json.loads(args[ArgumentLiterals.OPENAI_CONFIG_PARAMS])
            if not isinstance(llm_config, dict):
                raise ValueError("OpenAI config is expected to be a JSON-encoded dictionary. "
                                 f"Got {type(llm_config)} instead.")
        except Exception as e:
            logger.warning(f"Loading OpenAI config params failed with exception {repr(e)}. "
                           "Skipping GPT-based metrics calculation.")

    logger.info(f"LLM Config {llm_config}")
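    # Instantiate the runner with the resolved inputs and configuration; data is
    # loaded in the next stage.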
    with log_activity(logger, constants.TelemetryConstants.INITIALISING_RUNNER,
                      custom_dimensions=custom_dims_dict):
        runner = ComputeMetricsRunner(
            task=args[ArgumentLiterals.TASK],
            ground_truth=ground_truths,
            predictions=predictions,
            prediction_probabilities=prediction_probabilities,
            output=args[ArgumentLiterals.OUTPUT],
            config=config,
            is_ground_truth_mltable=is_ground_truths_mltable,
            is_predictions_mltable=is_predictions_mltable,
            is_prediction_probabilities_mltable=is_prediction_probabilities_mltable,
            ground_truths_column_name=ground_truths_column_name,
            predictions_column_name=args[ArgumentLiterals.PREDICTIONS_COLUMN_NAME],
            extra_y_test_cols=extra_y_test_cols,
            llm_config=llm_config
        )

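    # Load ground truths, predictions, and (optionally) prediction probabilities.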
    with log_activity(logger, activity_name=constants.TelemetryConstants.DATA_LOADING,
                      custom_dimensions=custom_dims_dict):
        try:
            logger.info("Loading Data.")
            flush_logger(logger)
            runner.load_data()
        except Exception as e:
            exception = get_azureml_exception(DataLoaderException, BadInputData, e, error=repr(e))
            log_traceback(exception, logger)
            raise exception

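    # Compute the task-specific metrics on the loaded data.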
    with log_activity(logger, activity_name=constants.TelemetryConstants.COMPUTE_METRICS_NAME,
                      custom_dimensions=custom_dims_dict):
        logger.info("Computing metrics.")
        result = runner.compute_metrics()

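    # Log the computed metrics to the run and write them to the output location.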
    with log_activity(logger, activity_name=constants.TelemetryConstants.LOG_AND_SAVE_OUTPUT,
                      custom_dimensions=custom_dims_dict):
        logger.info("Logging and Saving outputs.")
        try:
            runner.log_and_write_outputs(result)
        except Exception as e:
            exception = get_azureml_exception(DataSavingException, SavingOutputError, e, error=repr(e))
            log_traceback(exception, logger)
            raise exception

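    # Tag the current run (and, when possible, the root pipeline run) with
    # evaluation properties, then mark the run as complete.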
    test_run.add_properties(properties=constants.RUN_PROPERTIES)
    try:
        root_run.add_properties(properties=constants.ROOT_RUN_PROPERTIES)
    except Exception:
        logger.info("PipelineType is already a property at Root Pipeline Run.")
    test_run.complete()
    return
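
For quick local experimentation, run() can be driven by populating sys.argv before calling it. The sketch below is illustrative only: the task name and paths are placeholders rather than values taken from the component specification, and valid --task values come from constants.ALL_TASKS.

import sys

# Illustrative placeholders only: substitute a real task from constants.ALL_TASKS
# and real data locations (URI folders/files or MLTable paths).
sys.argv = [
    "compute_metrics.py",
    "--task", "tabular-classification",
    "--ground_truths", "data/ground_truth",
    "--ground_truths_column_name", "label",
    "--predictions", "data/predictions",
    "--predictions_column_name", "prediction",
    "--output", "outputs",
]
run()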