src/responsibleai/rai_analyse/create_counterfactual.py (67 lines of code) (raw):

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import argparse
import logging

from arg_helpers import (boolean_parser, json_empty_is_none_parser,
                         str_or_int_parser, str_or_list_parser)
from azureml.core import Run
from azureml.rai.utils.telemetry import LoggerFactory, track
from constants import COMPONENT_NAME, RAIToolType
from rai_component_utilities import (copy_dashboard_info_file,
                                     create_rai_insights_from_port_path,
                                     save_to_output_port)

from responsibleai import RAIInsights

_logger = logging.getLogger(__file__)
_ai_logger = None


def _get_logger():
    """Return the module-wide telemetry logger, building it on first use.

    The logger is created through ``LoggerFactory.get_logger`` using the
    ``azureml.moduleName`` and ``azureml.moduleid`` properties of the current
    run, and is then cached in the module-level ``_ai_logger``.
    """
    global _ai_logger
    if _ai_logger is not None:
        # Already built by an earlier call; reuse it.
        return _ai_logger

    run = Run.get_context()
    module_name = run.properties["azureml.moduleName"]
    module_version = run.properties["azureml.moduleid"]
    _ai_logger = LoggerFactory.get_logger(
        __file__,
        module_name,
        module_version,
        COMPONENT_NAME,
    )
    return _ai_logger


# Build the telemetry logger at import time, before ``main`` is decorated.
_get_logger()


def parse_args():
    """Parse the command-line arguments of the counterfactual component.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()

    # Required inputs.
    parser.add_argument("--rai_insights_dashboard", type=str, required=True)
    parser.add_argument("--total_CFs", type=int, required=True)

    # Optional counterfactual settings.
    parser.add_argument("--method", type=str)
    parser.add_argument("--desired_class", type=str_or_int_parser)
    parser.add_argument(
        "--desired_range", type=json_empty_is_none_parser, help="List"
    )
    parser.add_argument(
        "--permitted_range", type=json_empty_is_none_parser, help="Dict"
    )
    parser.add_argument("--features_to_vary", type=str_or_list_parser)
    parser.add_argument("--feature_importance", type=boolean_parser)

    # Output location.
    parser.add_argument("--counterfactual_path", type=str)

    return parser.parse_args()


@track(_get_logger)
def main(args):
    """Add a counterfactual to an RAI Insights object, compute and save it.

    Args:
        args: Parsed arguments as returned by ``parse_args``.
    """
    run_context = Run.get_context()

    # Rebuild the RAI Insights object from the dashboard input port.
    insights: RAIInsights = create_rai_insights_from_port_path(
        run_context, args.rai_insights_dashboard
    )

    # Collect the counterfactual settings, then register them in one call.
    counterfactual_settings = {
        "total_CFs": args.total_CFs,
        "method": args.method,
        "desired_class": args.desired_class,
        "desired_range": args.desired_range,
        "permitted_range": args.permitted_range,
        "features_to_vary": args.features_to_vary,
        "feature_importance": args.feature_importance,
    }
    insights.counterfactual.add(**counterfactual_settings)
    _logger.info("Added counterfactual")

    insights.compute()
    _logger.info("Computation complete")

    save_to_output_port(
        insights, args.counterfactual_path, RAIToolType.COUNTERFACTUAL
    )
    _logger.info("Saved to output port")

    # The dashboard info file travels alongside the saved output.
    copy_dashboard_info_file(
        args.rai_insights_dashboard, args.counterfactual_path
    )

    _logger.info("Completing")


# Script entry point.
if __name__ == "__main__":
    # Visual separator in the logs before the run.
    print("*" * 60)
    print("\n\n")

    args = parse_args()
    main(args)

    # Visual separator in the logs after the run.
    print("*" * 60)
    print("\n\n")