in src/responsibleai/rai_analyse/create_causal.py [0:0]
def main(args):
    """Attach a causal analysis to an existing RAI Insights object and persist it.

    Steps, in order:
      1. Obtain the current AzureML run context and use it, together with
         ``args.rai_insights_dashboard``, to reconstruct the RAIInsights object.
      2. Register a causal analysis on it; every configuration value is read
         from the attribute of ``args`` with the same name as the keyword.
      3. Call ``compute()`` on the insights object.
      4. Hand the object to ``save_to_output_port`` targeting
         ``args.causal_path`` with tool type ``RAIToolType.CAUSAL``.
      5. Copy the dashboard info file from ``args.rai_insights_dashboard``
         to ``args.causal_path``.
    A progress message is logged after each of steps 2-4 and a final
    "Completing" message is logged at the end.

    :param args: Parsed command-line arguments. Presumably an
        ``argparse.Namespace`` — TODO confirm against the caller. It must
        expose every attribute read below.
    """
    run_context = Run.get_context()
    insights: RAIInsights = create_rai_insights_from_port_path(
        run_context, args.rai_insights_dashboard
    )

    # ``args`` attributes passed straight through, under the same keyword
    # name, to ``causal.add``. The tuple order matches the original call.
    forwarded_names = (
        "treatment_features",
        "heterogeneity_features",
        "nuisance_model",
        "heterogeneity_model",
        "alpha",
        "upper_bound_on_cat_expansion",
        "treatment_cost",
        "min_tree_leaf_samples",
        "max_tree_depth",
        "skip_cat_limit_checks",
        "categories",
        "n_jobs",
        "verbose",
        "random_state",
    )
    causal_config = {name: getattr(args, name) for name in forwarded_names}
    insights.causal.add(**causal_config)
    _logger.info("Added causal")

    insights.compute()
    _logger.info("Computation complete")

    save_to_output_port(insights, args.causal_path, RAIToolType.CAUSAL)
    _logger.info("Saved computation to output port")

    copy_dashboard_info_file(args.rai_insights_dashboard, args.causal_path)
    _logger.info("Completing")