def main()

in src/responsibleai/rai_analyse/create_causal.py [0:0]


def main(args):
    """Run the causal analysis step and write its result to the output port.

    Loads the RAI Insights object referenced by
    ``args.rai_insights_dashboard``, registers a causal analysis configured
    from ``args``, runs the computation, saves the result under
    ``args.causal_path`` and finally copies the dashboard info file from the
    input location to the output location.

    Args:
        args: Object exposing the attributes read below
            (``rai_insights_dashboard``, ``causal_path`` and every causal
            option). Presumably the parsed command-line namespace; confirm
            against the caller.
    """
    current_run = Run.get_context()

    # Load the RAI Insights object
    insights: RAIInsights = create_rai_insights_from_port_path(
        current_run, args.rai_insights_dashboard
    )

    # Add the causal analysis. Every option is forwarded under the same
    # name it carries on ``args``, so the keyword arguments are built from
    # this ordered list of names.
    causal_option_names = (
        "treatment_features",
        "heterogeneity_features",
        "nuisance_model",
        "heterogeneity_model",
        "alpha",
        "upper_bound_on_cat_expansion",
        "treatment_cost",
        "min_tree_leaf_samples",
        "max_tree_depth",
        "skip_cat_limit_checks",
        "categories",
        "n_jobs",
        "verbose",
        "random_state",
    )
    # Resolve the bound method before reading the options from ``args``.
    add_causal = insights.causal.add
    causal_options = {
        option: getattr(args, option) for option in causal_option_names
    }
    add_causal(**causal_options)
    _logger.info("Added causal")

    # Compute
    insights.compute()
    _logger.info("Computation complete")

    # Save
    save_to_output_port(insights, args.causal_path, RAIToolType.CAUSAL)
    _logger.info("Saved computation to output port")

    # Copy the dashboard info file
    copy_dashboard_info_file(args.rai_insights_dashboard, args.causal_path)

    _logger.info("Completing")