in src/responsibleai/rai_analyse/create_rai_insights.py [0:0]
def main(args):
    """Check that the RAI inputs load together and record them for later components.

    Loads the train and test datasets, loads the model from whichever model
    source was supplied, and constructs a throwaway ``RAIInsights`` object
    purely to confirm that the combination can be instantiated. It then writes
    a JSON description of the run into ``args.output_path`` and copies both
    input datasets into sub-directories of that path.

    Args:
        args: Parsed component arguments. This function reads
            ``train_dataset``, ``test_dataset``, ``model_info_path``,
            ``model_input``, ``model_info``, ``task_type``,
            ``use_model_dependency``, ``title`` and ``output_path``. The whole
            object is also handed to ``create_constructor_arg_dict`` and is
            written to the output JSON via ``vars(args)``.

    Raises:
        UserConfigValidationException: If neither a usable
            ``model_info_path`` nor a usable ``model_input``/``model_info``
            pair is supplied.
    """
    my_run = Run.get_context()

    _logger.info("Dealing with initialization dataset")
    train_df = load_dataset(args.train_dataset)

    _logger.info("Dealing with evaluation dataset")
    test_df = load_dataset(args.test_dataset)

    # For now, the separate conda env will only be used for forecasting.
    # At a later point, we might enable this for all task types.
    use_separate_conda_env = args.task_type == "forecasting"

    # Validation is the ``else`` of the same chain that picks the model
    # source, so the two can never disagree. Checking only ``is None`` up
    # front would let empty values through and the run would carry on with
    # no model and no model id.
    if args.model_info_path:
        model_id = fetch_model_id(args.model_info_path)
        model_source = {"model_id": model_id}
    elif args.model_input and args.model_info:
        model_id = args.model_info
        model_source = {"model_path": args.model_input}
    else:
        raise UserConfigValidationException("Either model info or model input needs to be supplied.")

    _logger.info("Loading model: %s", model_id)
    model_estimator = load_mlflow_model(
        workspace=my_run.experiment.workspace,
        use_model_dependency=args.use_model_dependency,
        use_separate_conda_env=use_separate_conda_env,
        **model_source,
    )

    # unwrap the model if it's an sklearn wrapper
    if model_estimator.__class__.__name__ == "_SklearnModelWrapper":
        model_estimator = model_estimator.sklearn_model

    constructor_args = create_constructor_arg_dict(args)

    # Make sure that it actually loads; the instance itself is discarded.
    _logger.info("Creating RAIInsights object")
    _ = RAIInsights(
        model=model_estimator,
        train=train_df,
        test=test_df,
        forecasting_enabled=True,
        **constructor_args,
    )

    _logger.info("Saving JSON for tool components")
    output_dict = {
        DashboardInfo.RAI_INSIGHTS_RUN_ID_KEY: str(my_run.id),
        DashboardInfo.RAI_INSIGHTS_MODEL_ID_KEY: model_id,
        DashboardInfo.RAI_INSIGHTS_CONSTRUCTOR_ARGS_KEY: constructor_args,
        DashboardInfo.RAI_INSIGHTS_TRAIN_DATASET_ID_KEY: get_train_dataset_id(my_run),
        DashboardInfo.RAI_INSIGHTS_TEST_DATASET_ID_KEY: get_test_dataset_id(my_run),
        DashboardInfo.RAI_INSIGHTS_DASHBOARD_TITLE_KEY: args.title,
        DashboardInfo.RAI_INSIGHTS_INPUT_ARGS_KEY: vars(args),
    }
    output_file = os.path.join(
        args.output_path, DashboardInfo.RAI_INSIGHTS_PARENT_FILENAME
    )
    with open(output_file, "w") as of:
        json.dump(output_dict, of, default=default_json_handler)

    _logger.info("Copying train data files")
    copy_input_data(
        args.train_dataset,
        os.path.join(args.output_path, DashboardInfo.TRAIN_FILES_DIR),
    )

    _logger.info("Copying test data files")
    copy_input_data(
        args.test_dataset,
        os.path.join(args.output_path, DashboardInfo.TEST_FILES_DIR),
    )