def main()

in src/responsibleai/rai_analyse/create_rai_insights.py


def main(args):
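    # Get the current AzureML run; its workspace is used below to load the
    # MLflow model and its run id is recorded in the output metadata.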
    my_run = Run.get_context()

    _logger.info("Loading train (initialization) dataset")
    train_df = load_dataset(args.train_dataset)

    _logger.info("Loading test (evaluation) dataset")
    test_df = load_dataset(args.test_dataset)

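    # The model can be supplied either via a model info file (from which the
    # model id is fetched) or via a direct model input plus model id; fail fast
    # if neither combination is complete.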
    if args.model_info_path is None and (
        args.model_input is None or args.model_info is None
    ):
        raise UserConfigValidationException(
            "Either model info or model input needs to be supplied."
        )

    model_estimator = None
    model_id = None
    # For now, the separate conda env will only be used for forecasting.
    # At a later point, we might enable this for all task types.
    use_separate_conda_env = args.task_type == "forecasting"
    if args.model_info_path:
        model_id = fetch_model_id(args.model_info_path)
        _logger.info("Loading model: {0}".format(model_id))
        model_estimator = load_mlflow_model(
            workspace=my_run.experiment.workspace,
            use_model_dependency=args.use_model_dependency,
            use_separate_conda_env=use_separate_conda_env,
            model_id=model_id,
        )
    elif args.model_input and args.model_info:
        model_id = args.model_info
        _logger.info("Loading model: {0}".format(model_id))
        model_estimator = load_mlflow_model(
            workspace=my_run.experiment.workspace,
            use_model_dependency=args.use_model_dependency,
            use_separate_conda_env=use_separate_conda_env,
            model_path=args.model_input,
        )

    # unwrap the model if it's an sklearn wrapper
    if model_estimator.__class__.__name__ == "_SklearnModelWrapper":
        model_estimator = model_estimator.sklearn_model

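    # Translate the remaining command-line arguments into keyword arguments for
    # the RAIInsights constructor.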
    constructor_args = create_constructor_arg_dict(args)

    # Construct the RAIInsights object once to validate that the model, datasets
    # and constructor arguments load together; the instance is not used further.
    _logger.info("Creating RAIInsights object")
    _ = RAIInsights(
        model=model_estimator,
        train=train_df,
        test=test_df,
        forecasting_enabled=True,
        **constructor_args,
    )

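    # Assemble metadata (run id, model id, dataset ids, constructor and input
    # arguments) into a JSON file for downstream RAI dashboard components to read.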
    _logger.info("Saving JSON for tool components")
    output_dict = {
        DashboardInfo.RAI_INSIGHTS_RUN_ID_KEY: str(my_run.id),
        DashboardInfo.RAI_INSIGHTS_MODEL_ID_KEY: model_id,
        DashboardInfo.RAI_INSIGHTS_CONSTRUCTOR_ARGS_KEY: constructor_args,
        DashboardInfo.RAI_INSIGHTS_TRAIN_DATASET_ID_KEY: get_train_dataset_id(my_run),
        DashboardInfo.RAI_INSIGHTS_TEST_DATASET_ID_KEY: get_test_dataset_id(my_run),
        DashboardInfo.RAI_INSIGHTS_DASHBOARD_TITLE_KEY: args.title,
        DashboardInfo.RAI_INSIGHTS_INPUT_ARGS_KEY: vars(args),
    }
    output_file = os.path.join(
        args.output_path, DashboardInfo.RAI_INSIGHTS_PARENT_FILENAME
    )
    with open(output_file, "w") as of:
        json.dump(output_dict, of, default=default_json_handler)

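    # Copy the raw train/test data files into the output so they travel with the
    # metadata written above.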
    _logger.info("Copying train data files")
    copy_input_data(
        args.train_dataset,
        os.path.join(args.output_path, DashboardInfo.TRAIN_FILES_DIR),
    )
    _logger.info("Copying test data files")
    copy_input_data(
        args.test_dataset, os.path.join(args.output_path, DashboardInfo.TEST_FILES_DIR)
    )
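
For orientation, here is a minimal sketch of the arguments main() reads directly. The
attribute names are taken from the function above; the values are placeholders, and
actually running it requires an AzureML run context plus this module's helpers.

# Hypothetical illustration only: main() also passes args to
# create_constructor_arg_dict, which presumably reads further attributes
# (such as the target column) that are not shown here.
from types import SimpleNamespace

example_args = SimpleNamespace(
    train_dataset="<train data path>",
    test_dataset="<test data path>",
    model_info_path="<model info path>",  # or None when model_input/model_info are set
    model_input=None,
    model_info=None,
    use_model_dependency=False,
    task_type="classification",
    title="<dashboard title>",
    output_path="<output directory>",
)
# main(example_args)  # only meaningful inside an AzureML run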