# main()
# Excerpt from assets/responsibleai/text/components/src/rai_text_insights.py

def _build_model_serializer(args):
    """Create the ModelSerializer for the MLflow model referenced by *args*.

    Returns:
        tuple: ``(serializer, model_id)`` where *model_id* is taken from
        ``args.model_info``.
    """
    model_meta = mlflow.models.Model.load(
        os.path.join(args.model_input, "MLmodel"))
    # Only the HuggingFace transformers flavor gets special treatment;
    # every other flavor is loaded with model_type=None.
    model_type = None
    if ModelTypes.HFTRANSFORMERS in model_meta.flavors:
        model_type = ModelTypes.HFTRANSFORMERS

    model_id = args.model_info
    # Lazy %-style args: the message is only formatted if INFO is enabled.
    _logger.info("Loading model: %s", model_id)

    serializer = ModelSerializer(
        model_id,
        mlflow_model=args.model_input,
        model_type=model_type,
        use_model_dependency=args.use_model_dependency,
        use_conda=args.use_conda,
        task_type=args.task_type,
        tracking_uri=mlflow.get_tracking_uri())
    return serializer, model_id


def _resolve_target_column(args):
    """Return the target column for the task.

    Multilabel classification passes its target(s) as a JSON-encoded
    argument (empty string meaning None); all other task types use the
    plain string argument directly.
    """
    if args.task_type == TaskType.MULTILABEL_TEXT_CLASSIFICATION:
        return get_from_args(
            args=args,
            arg_name="target_column_name",
            custom_parser=json_empty_is_none_parser,
            allow_none=True
        )
    return args.target_column_name


def _build_feature_metadata(args):
    """Assemble FeatureMetadata from the optional JSON-encoded arg lists."""
    feature_metadata = FeatureMetadata()
    feature_metadata.categorical_features = get_from_args(
        args=args,
        arg_name="categorical_metadata_features",
        custom_parser=json_empty_is_none_parser,
        allow_none=True
    )
    feature_metadata.dropped_features = get_from_args(
        args=args,
        arg_name="dropped_metadata_features",
        custom_parser=json_empty_is_none_parser,
        allow_none=True
    )
    return feature_metadata


def main(args):
    """Compute RAI Text Insights for a text model and emit dashboard outputs.

    Loads the evaluation dataset and the MLflow model described by *args*,
    builds an :class:`RAITextInsights` object, runs the enabled analyses
    (explanation and/or error analysis), then writes:

    * the serialized insights to ``args.dashboard``;
    * a JSON-safe ``dashboard.json`` for the UX under ``args.ux_json``;
    * gather-run properties recording the model id and included tools.

    Args:
        args: parsed command-line namespace for this component.

    Raises:
        ValueError: if ``args.model_input`` or ``args.model_info`` is None.
    """
    # Fail fast: validate the model arguments before the (potentially
    # expensive) dataset load rather than after it.
    if args.model_input is None or args.model_info is None:
        raise ValueError(
            "Both model info and model input need to be supplied.")

    _logger.info("Loading evaluation dataset")
    test_df = load_dataset(args.test_dataset)

    model_serializer, model_id = _build_model_serializer(args)
    # The serializer ignores the path argument by design.
    text_model = model_serializer.load("Ignored path")

    _logger.info("Creating RAI Text Insights")
    rai_ti = RAITextInsights(
        model=text_model,
        test=test_df,
        target_column=_resolve_target_column(args),
        task_type=args.task_type,
        classes=get_from_args(
            args, "classes", custom_parser=json_empty_is_none_parser,
            allow_none=True
        ),
        maximum_rows_for_test=args.maximum_rows_for_test_dataset,
        serializer=model_serializer,
        feature_metadata=_build_feature_metadata(args),
        text_column=args.text_column_name
    )

    # Track which tools were actually added so the gather run can report them.
    included_tools: Dict[str, bool] = {
        RAIToolType.CAUSAL: False,
        RAIToolType.COUNTERFACTUAL: False,
        RAIToolType.ERROR_ANALYSIS: False,
        RAIToolType.EXPLANATION: False,
    }

    if args.enable_explanation:
        _logger.info("Adding explanation")
        rai_ti.explainer.add()
        included_tools[RAIToolType.EXPLANATION] = True

    if args.enable_error_analysis:
        _logger.info("Adding error analysis")
        rai_ti.error_analysis.add()
        included_tools[RAIToolType.ERROR_ANALYSIS] = True

    _logger.info("Starting computation")
    rai_ti.compute()
    _logger.info("Computation complete")

    rai_ti.save(args.dashboard)
    _logger.info("Saved dashboard to output")

    # Emit a JSON-safe copy of the insights data for the UX layer.
    rai_data = rai_ti.get_data()
    rai_dict = serialize_json_safe(rai_data)
    output_path = Path(args.ux_json) / "dashboard.json"
    with open(output_path, "w") as json_file:
        json.dump(rai_dict, json_file)
    _logger.info("Dashboard JSON written")

    dashboard_info = {DashboardInfo.RAI_INSIGHTS_MODEL_ID_KEY: model_id}
    add_properties_to_gather_run(dashboard_info, included_tools)
    _logger.info("Processing completed")