# def main()
#
# in single-rai-job/src/run_rai.py [0:0]


def _log_and_add(tool, **tool_kwargs):
    """Log each configuration value, then register ``tool`` with them.

    ``tool`` is one of the RAIInsights tool managers (e.g. ``rai_i.causal``).
    Logging every keyword first makes the run's configuration auditable from
    the job log; the log key is exactly the keyword passed to ``tool.add``.
    """
    for name, value in tool_kwargs.items():
        _logger.info(f"{name}: {value}")
    tool.add(**tool_kwargs)


def _enable_requested_tools(rai_i, args) -> Dict[str, bool]:
    """Add every tool whose ``enable_*`` flag is set on ``args``.

    Returns a dict mapping each ``RAIToolType`` to whether it was enabled,
    for later inclusion in the gather-run properties.
    """
    included_tools: Dict[str, bool] = {
        RAIToolType.CAUSAL: False,
        RAIToolType.COUNTERFACTUAL: False,
        RAIToolType.ERROR_ANALYSIS: False,
        RAIToolType.EXPLANATION: False,
    }

    _logger.info("Checking individual tools")

    if args.enable_causal:
        _logger.info("Adding causal")
        _log_and_add(
            rai_i.causal,
            treatment_features=args.causal_treatment_features,
            heterogeneity_features=args.causal_heterogeneity_features,
            nuisance_model=args.causal_nuisance_model,
            heterogeneity_model=args.causal_heterogeneity_model,
            alpha=args.causal_alpha,
            upper_bound_on_cat_expansion=args.causal_upper_bound_on_cat_expansion,
            treatment_cost=args.causal_treatment_cost,
            min_tree_leaf_samples=args.causal_min_tree_leaf_samples,
            max_tree_depth=args.causal_max_tree_depth,
            skip_cat_limit_checks=args.causal_skip_cat_limit_checks,
            categories=args.causal_categories,
            n_jobs=args.causal_n_jobs,
            verbose=args.causal_verbose,
            random_state=args.causal_random_state,
        )
        included_tools[RAIToolType.CAUSAL] = True

    if args.enable_counterfactual:
        _logger.info("Adding counterfactuals")
        _log_and_add(
            rai_i.counterfactual,
            total_CFs=args.counterfactual_total_CFs,
            method=args.counterfactual_method,
            desired_class=args.counterfactual_desired_class,
            desired_range=args.counterfactual_desired_range,
            permitted_range=args.counterfactual_permitted_range,
            features_to_vary=args.counterfactual_features_to_vary,
            feature_importance=args.counterfactual_feature_importance,
        )
        included_tools[RAIToolType.COUNTERFACTUAL] = True

    if args.enable_error_analysis:
        _logger.info("Adding error analysis")
        _log_and_add(
            rai_i.error_analysis,
            max_depth=args.error_analysis_max_depth,
            num_leaves=args.error_analysis_num_leaves,
            min_child_samples=args.error_analysis_min_child_samples,
            filter_features=args.error_analysis_filter_features,
        )
        included_tools[RAIToolType.ERROR_ANALYSIS] = True

    if args.enable_explanation:
        _logger.info("Adding explanation")
        rai_i.explainer.add()
        included_tools[RAIToolType.EXPLANATION] = True

    return included_tools


def _upload_dashboard_artifacts(my_run, rai_i, dashboard_info):
    """Save the computed RAIInsights locally and upload both artifact sets.

    Uploads the binary dashboard folder (with ``dashboard_info`` serialized
    alongside it) and the serialized UX JSON to the given AzureML run.
    """
    _logger.info("Doing local save and upload")
    binary_dir = "./responsibleai"
    rai_i.save(binary_dir)
    # Write dashboard_info next to the saved insights so downstream gather
    # steps can reconstruct run/model linkage, then upload the whole folder.
    dashboardinfo_path = os.path.join(
        binary_dir, DashboardInfo.RAI_INSIGHTS_PARENT_FILENAME
    )
    with open(dashboardinfo_path, "w") as of:
        json.dump(dashboard_info, of)
    my_run.upload_folder("dashboard", binary_dir)

    _logger.info("Saving UX JSON")
    rai_data = rai_i.get_data()
    rai_dict = serialize_json_safe(rai_data)
    output_path = "dashboard.json"
    with open(output_path, "w") as json_file:
        json.dump(rai_dict, json_file)
    my_run.upload_file("ux_json", output_path)


def main(args):
    """Run the Responsible AI job end-to-end.

    Loads the train/test datasets and the MLflow model referenced by ``args``,
    builds an ``RAIInsights`` object, enables whichever tools the CLI flags
    request, computes them, uploads the dashboard artifacts to the current
    AzureML run, and tags the run with gather properties.

    Args:
        args: Parsed command-line namespace carrying dataset/model ids,
            per-tool ``enable_*`` flags, and per-tool configuration values.
    """
    _logger.info(f"responsibleai=={responsibleai_version}")
    my_run = Run.get_context()
    my_ws = my_run.experiment.workspace

    _logger.info("Dealing with initialization dataset")
    train_df = load_tabular_dataset(args.train_dataset_id, my_ws)

    _logger.info("Dealing with evaluation dataset")
    test_df = load_tabular_dataset(args.test_dataset_id, my_ws)

    _logger.info("Loading model: {0}".format(args.model_id))
    model_estimator = load_mlflow_model(my_ws, args.model_id)

    constructor_args = create_constructor_arg_dict(args)

    # Create the RAIInsights object
    _logger.info("Creating RAIInsights object")
    rai_i = RAIInsights(
        model=model_estimator, train=train_df, test=test_df, **constructor_args
    )

    included_tools = _enable_requested_tools(rai_i, args)

    _logger.info("Triggering computation")
    rai_i.compute()

    dashboard_info = {
        DashboardInfo.RAI_INSIGHTS_RUN_ID_KEY: str(my_run.id),
        DashboardInfo.RAI_INSIGHTS_MODEL_ID_KEY: args.model_id,
        DashboardInfo.RAI_INSIGHTS_CONSTRUCTOR_ARGS_KEY: constructor_args,
    }

    _upload_dashboard_artifacts(my_run, rai_i, dashboard_info)

    _logger.info("Adding properties to run")

    add_properties_to_gather_run(dashboard_info, included_tools, args.task_type)
    _logger.info("Processing completed")