in single-rai-job/src/run_rai.py [0:0]
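
# NOTE: sketch of the imports this excerpt relies on. In the full run_rai.py
# these, the module-level _logger, the helpers (load_tabular_dataset,
# load_mlflow_model, create_constructor_arg_dict, add_properties_to_gather_run)
# and the DashboardInfo / RAIToolType constants are expected to be defined or
# imported above this section.
import json
import os
from typing import Dict

from azureml.core import Run
from responsibleai import RAIInsights
from responsibleai import __version__ as responsibleai_version

# The module path for serialize_json_safe is an assumption here; it has moved
# between responsibleai releases (newer versions expose it via raiutils).
from responsibleai.serialization_utilities import serialize_json_safe
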
def main(args):
    _logger.info(f"responsibleai=={responsibleai_version}")

    my_run = Run.get_context()
    my_ws = my_run.experiment.workspace

    _logger.info("Loading initialization (train) dataset")
    train_df = load_tabular_dataset(args.train_dataset_id, my_ws)

    _logger.info("Loading evaluation (test) dataset")
    test_df = load_tabular_dataset(args.test_dataset_id, my_ws)

    _logger.info(f"Loading model: {args.model_id}")
    model_estimator = load_mlflow_model(my_ws, args.model_id)

    constructor_args = create_constructor_arg_dict(args)

    # Create the RAIInsights object
    _logger.info("Creating RAIInsights object")
    rai_i = RAIInsights(
        model=model_estimator, train=train_df, test=test_df, **constructor_args
    )
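
    # Record which tools get added so that add_properties_to_gather_run can
    # stamp them onto the run at the end.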
    included_tools: Dict[str, bool] = {
        RAIToolType.CAUSAL: False,
        RAIToolType.COUNTERFACTUAL: False,
        RAIToolType.ERROR_ANALYSIS: False,
        RAIToolType.EXPLANATION: False,
    }
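
    # Each tool is opt-in via its enable_* flag; the tool-specific arguments
    # are passed straight through to the corresponding RAIInsights manager.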
    _logger.info("Checking individual tools")
    if args.enable_causal:
        _logger.info("Adding causal")
        _logger.info(f"treatment_features: {args.causal_treatment_features}")
        _logger.info(f"heterogeneity_features: {args.causal_heterogeneity_features}")
        _logger.info(f"nuisance_model: {args.causal_nuisance_model}")
        _logger.info(f"heterogeneity_model: {args.causal_heterogeneity_model}")
        _logger.info(f"alpha: {args.causal_alpha}")
        _logger.info(
            f"upper_bound_on_cat_expansion: {args.causal_upper_bound_on_cat_expansion}"
        )
        _logger.info(f"treatment_cost: {args.causal_treatment_cost}")
        _logger.info(f"min_tree_leaf_samples: {args.causal_min_tree_leaf_samples}")
        _logger.info(f"max_tree_depth: {args.causal_max_tree_depth}")
        _logger.info(f"skip_cat_limit_checks: {args.causal_skip_cat_limit_checks}")
        _logger.info(f"categories: {args.causal_categories}")
        _logger.info(f"n_jobs: {args.causal_n_jobs}")
        _logger.info(f"verbose: {args.causal_verbose}")
        _logger.info(f"random_state: {args.causal_random_state}")
        rai_i.causal.add(
            treatment_features=args.causal_treatment_features,
            heterogeneity_features=args.causal_heterogeneity_features,
            nuisance_model=args.causal_nuisance_model,
            heterogeneity_model=args.causal_heterogeneity_model,
            alpha=args.causal_alpha,
            upper_bound_on_cat_expansion=args.causal_upper_bound_on_cat_expansion,
            treatment_cost=args.causal_treatment_cost,
            min_tree_leaf_samples=args.causal_min_tree_leaf_samples,
            max_tree_depth=args.causal_max_tree_depth,
            skip_cat_limit_checks=args.causal_skip_cat_limit_checks,
            categories=args.causal_categories,
            n_jobs=args.causal_n_jobs,
            verbose=args.causal_verbose,
            random_state=args.causal_random_state,
        )
        included_tools[RAIToolType.CAUSAL] = True
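
    # Counterfactual examples: total_CFs counterfactuals are generated for
    # each row of the test data by the DiCE-based counterfactual manager.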
    if args.enable_counterfactual:
        _logger.info("Adding counterfactuals")
        _logger.info(f"total_CFs: {args.counterfactual_total_CFs}")
        _logger.info(f"method: {args.counterfactual_method}")
        _logger.info(f"desired_class: {args.counterfactual_desired_class}")
        _logger.info(f"desired_range: {args.counterfactual_desired_range}")
        _logger.info(f"permitted_range: {args.counterfactual_permitted_range}")
        _logger.info(f"features_to_vary: {args.counterfactual_features_to_vary}")
        _logger.info(f"feature_importance: {args.counterfactual_feature_importance}")
        rai_i.counterfactual.add(
            total_CFs=args.counterfactual_total_CFs,
            method=args.counterfactual_method,
            desired_class=args.counterfactual_desired_class,
            desired_range=args.counterfactual_desired_range,
            permitted_range=args.counterfactual_permitted_range,
            features_to_vary=args.counterfactual_features_to_vary,
            feature_importance=args.counterfactual_feature_importance,
        )
        included_tools[RAIToolType.COUNTERFACTUAL] = True
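
    # Error analysis fits a surrogate decision tree over the test data to
    # surface cohorts where the model's errors are concentrated.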
    if args.enable_error_analysis:
        _logger.info("Adding error analysis")
        _logger.info(f"max_depth: {args.error_analysis_max_depth}")
        _logger.info(f"num_leaves: {args.error_analysis_num_leaves}")
        _logger.info(f"min_child_samples: {args.error_analysis_min_child_samples}")
        _logger.info(f"filter_features: {args.error_analysis_filter_features}")
        rai_i.error_analysis.add(
            max_depth=args.error_analysis_max_depth,
            num_leaves=args.error_analysis_num_leaves,
            min_child_samples=args.error_analysis_min_child_samples,
            filter_features=args.error_analysis_filter_features,
        )
        included_tools[RAIToolType.ERROR_ANALYSIS] = True
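
    # Explanations need no per-tool arguments here; the explainer manager
    # selects a suitable explainer for the model under the hood.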
    if args.enable_explanation:
        _logger.info("Adding explanation")
        rai_i.explainer.add()
        included_tools[RAIToolType.EXPLANATION] = True
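
    # compute() executes every analysis registered above; this is typically
    # the long-running step of the job.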
    _logger.info("Triggering computation")
    rai_i.compute()
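
    # Metadata read back by downstream consumers (e.g. the gather/dashboard
    # construction step) to tie these insights to this run and model.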
    dashboard_info = {
        DashboardInfo.RAI_INSIGHTS_RUN_ID_KEY: str(my_run.id),
        DashboardInfo.RAI_INSIGHTS_MODEL_ID_KEY: args.model_id,
        DashboardInfo.RAI_INSIGHTS_CONSTRUCTOR_ARGS_KEY: constructor_args,
    }

    _logger.info("Doing local save and upload")
    binary_dir = "./responsibleai"
    rai_i.save(binary_dir)

    # write dashboard_info to rai_insights.json and upload artifacts
    dashboardinfo_path = os.path.join(
        binary_dir, DashboardInfo.RAI_INSIGHTS_PARENT_FILENAME
    )
    with open(dashboardinfo_path, "w") as of:
        json.dump(dashboard_info, of)
    my_run.upload_folder("dashboard", binary_dir)
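
    # get_data() returns the insights in the shape the dashboard UX consumes;
    # serialize_json_safe converts values (e.g. numpy/pandas types) that the
    # stock json module cannot handle.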
    _logger.info("Saving UX JSON")
    rai_data = rai_i.get_data()
    rai_dict = serialize_json_safe(rai_data)
    output_path = "dashboard.json"
    with open(output_path, "w") as json_file:
        json.dump(rai_dict, json_file)
    my_run.upload_file("ux_json", output_path)
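
    # Run properties make the dashboard discoverable: downstream tooling can
    # locate this run by model id and see which tools were computed.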
    _logger.info("Adding properties to run")
    add_properties_to_gather_run(dashboard_info, included_tools, args.task_type)

    _logger.info("Processing completed")
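

# Entry-point sketch (assumption): argument parsing for main() is not part of
# this excerpt. The real run_rai.py would be expected to define an argparse
# parser whose flag names mirror the args.* attributes used above; the subset
# below is purely illustrative, not the script's actual parser.
def _sketch_arg_parser():
    import argparse  # local import to keep the sketch self-contained

    parser = argparse.ArgumentParser()
    parser.add_argument("--train_dataset_id", type=str, required=True)
    parser.add_argument("--test_dataset_id", type=str, required=True)
    parser.add_argument("--model_id", type=str, required=True)
    parser.add_argument("--task_type", type=str, required=True)
    parser.add_argument("--enable_causal", action="store_true")
    parser.add_argument("--enable_counterfactual", action="store_true")
    parser.add_argument("--enable_error_analysis", action="store_true")
    parser.add_argument("--enable_explanation", action="store_true")
    # The actual parser would also need the per-tool options (causal_*,
    # counterfactual_*, error_analysis_*) consumed by main() above.
    return parser


if __name__ == "__main__":
    main(_sketch_arg_parser().parse_args())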