# main()
# Excerpt from assets/responsibleai/vision/components/src/rai_vision_insights.py

def main(args):
    """Run the RAI vision insights component end to end.

    Loads the evaluation dataset and the MLflow model referenced in
    ``args``, constructs a ``RAIVisionInsights`` object, computes the
    requested analyses and writes the dashboard artifacts to the output
    paths supplied in ``args``.

    :param args: Parsed command-line arguments for the component.
    :raises ValueError: If either the model info or the model input
        argument is missing.
    """
    _logger.info(f"Dataset type is {args.dataset_type}")

    # Load evaluation dataset
    _logger.info("Loading evaluation dataset")
    test_df, is_test_mltable_from_datastore = load_dataset(
        args.test_dataset,
        args.dataset_type
    )

    model_id, model_serializer, image_model = _load_model(args)
    target_column = _resolve_target_column(args)
    feature_metadata = _build_feature_metadata(args)

    test_data_path = None
    image_downloader = None
    if is_test_mltable_from_datastore:
        test_data_path = args.test_dataset
        image_downloader = DownloadManager
        # Need to drop extra "portable path" column added by DownloadManager
        # to exclude it from feature metadata
        if feature_metadata.dropped_features is None:
            feature_metadata.dropped_features = [PORTABLE_PATH]
        else:
            feature_metadata.dropped_features.append(PORTABLE_PATH)
    _logger.info("Creating RAI Vision Insights")

    device, is_gpu = _resolve_device(args.device)

    rai_vi = RAIVisionInsights(
        model=image_model,
        test=test_df,
        target_column=target_column,
        task_type=args.task_type,
        classes=get_from_args(
            args=args,
            arg_name="classes",
            custom_parser=json_empty_is_none_parser,
            allow_none=True
        ),
        maximum_rows_for_test=args.maximum_rows_for_test_dataset,
        serializer=model_serializer,
        test_data_path=test_data_path,
        image_downloader=image_downloader,
        feature_metadata=feature_metadata,
        image_width=args.image_width_in_inches,
        max_evals=args.max_evals,
        num_masks=args.num_masks,
        mask_res=args.mask_res,
        device=device
    )

    # Track which RAI tools were actually added, for run-property reporting.
    included_tools: Dict[str, bool] = {
        RAIToolType.CAUSAL: False,
        RAIToolType.COUNTERFACTUAL: False,
        RAIToolType.ERROR_ANALYSIS: False,
        RAIToolType.EXPLANATION: False,
    }

    if args.precompute_explanation:
        _logger.info("Adding explanation")
        rai_vi.explainer.add()
        included_tools[RAIToolType.EXPLANATION] = True

    if args.enable_error_analysis:
        _logger.info("Adding error analysis")
        rai_vi.error_analysis.add()
        included_tools[RAIToolType.ERROR_ANALYSIS] = True

    _logger.info("Starting computation")
    automl_xai_args = _collect_automl_xai_args(args)
    rai_vi.compute(**automl_xai_args)
    _logger.info("Computation complete")

    rai_vi.save(args.dashboard)
    _logger.info("Saved dashboard to output")

    _write_ux_json(rai_vi, args.ux_json)

    dashboard_info = {DashboardInfo.RAI_INSIGHTS_MODEL_ID_KEY: model_id}
    add_properties_to_gather_run(dashboard_info, included_tools, is_gpu)
    _logger.info("Processing completed")


def _load_model(args):
    """Validate the model arguments and load the model via ModelSerializer.

    :param args: Parsed command-line arguments for the component.
    :returns: Tuple of ``(model_id, model_serializer, image_model)``.
    :raises ValueError: If ``args.model_input`` or ``args.model_info``
        is ``None``.
    """
    if args.model_input is None or args.model_info is None:
        raise ValueError(
            "Both model info and model input need to be supplied."
        )

    model_id = args.model_info
    _logger.info(f"Loading model: {model_id}")

    tracking_uri = mlflow.get_tracking_uri()
    model_serializer = ModelSerializer(
        model_id,
        model_type=args.model_type,
        use_model_dependency=args.use_model_dependency,
        use_conda=args.use_conda,
        tracking_uri=tracking_uri)
    # The serializer resolves the model from the tracking server, so the
    # path argument it receives is unused.
    image_model = model_serializer.load("Ignored path")
    return model_id, model_serializer, image_model


def _resolve_target_column(args):
    """Return the target column name.

    Multilabel image classification passes the target column as JSON
    (possibly empty, meaning ``None``); other tasks pass it verbatim.
    """
    if args.task_type == TaskType.MULTILABEL_IMAGE_CLASSIFICATION:
        return get_from_args(
            args=args,
            arg_name="target_column_name",
            custom_parser=json_empty_is_none_parser,
            allow_none=True
        )
    return args.target_column_name


def _build_feature_metadata(args):
    """Build FeatureMetadata from the categorical/dropped feature args."""
    feature_metadata = FeatureMetadata()
    feature_metadata.categorical_features = get_from_args(
        args=args,
        arg_name="categorical_metadata_features",
        custom_parser=json_empty_is_none_parser,
        allow_none=True
    )
    feature_metadata.dropped_features = get_from_args(
        args=args,
        arg_name="dropped_metadata_features",
        custom_parser=json_empty_is_none_parser,
        allow_none=True
    )
    return feature_metadata


def _resolve_device(device_id):
    """Map the numeric device id to a ``(device, is_gpu)`` pair.

    A non-negative id selects CUDA; anything else falls back to CPU.
    """
    if device_id >= 0:
        is_available = str(torch.cuda.is_available())
        _logger.info("Using CUDA device. Is CUDA available: " + is_available)
        return Device.CUDA.value, True
    return Device.CPU.value, False


def _collect_automl_xai_args(args):
    """Collect the non-None AutoML XAI keyword arguments from ``args``."""
    automl_xai_args = dict()
    for xai_arg in ExplainabilityLiterals.XAI_ARGS_GROUP:
        xai_arg_value = getattr(args, xai_arg)
        if xai_arg_value is not None:
            automl_xai_args[xai_arg] = xai_arg_value
        _logger.info(f"xai_arg: {xai_arg}, value: {xai_arg_value}")

    _logger.info(f"automl_xai_args: {automl_xai_args}")
    return automl_xai_args


def _write_ux_json(rai_vi, ux_json_dir):
    """Serialize the insights data and write ``dashboard.json`` for the UX."""
    rai_data = rai_vi.get_data()
    rai_dict = serialize_json_safe(rai_data)
    json_filename = "dashboard.json"
    output_path = Path(ux_json_dir) / json_filename
    with open(output_path, "w") as json_file:
        json.dump(rai_dict, json_file)
    _logger.info("Dashboard JSON written")