in single-rai-job/src/run_rai.py [0:0]
def parse_args(argv=None) -> argparse.Namespace:
    """Build the command-line parser for the RAI job and parse the arguments.

    Options are grouped into constructor, causal, counterfactual,
    error-analysis and explanation sections. Several options convert their
    raw string with parser helpers defined elsewhere in this module
    (``boolean_parser``, ``float_or_json_parser``, ``str_or_list_parser``,
    ``int_or_none_parser``, ``str_or_int_parser``,
    ``json_empty_is_none_parser``) or with ``json.loads``.

    Args:
        argv: Optional sequence of argument strings to parse instead of the
            process command line. The default ``None`` keeps the original
            behaviour: argparse reads ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with one attribute per option. Options that
        were not supplied and have no explicit default are ``None``.

    Raises:
        SystemExit: raised by argparse when a required option is missing,
            a value fails its ``type`` conversion, or a value is outside
            ``choices``.
    """
    parser = argparse.ArgumentParser()

    # Constructor arguments
    parser.add_argument("--title", type=str, required=True)
    parser.add_argument(
        "--task_type",
        type=str,
        required=True,
        choices=["classification", "regression", "forecasting"],
    )
    parser.add_argument("--model_id", type=str, help="name:version", required=True)
    parser.add_argument("--train_dataset_id", type=str, required=True)
    parser.add_argument("--test_dataset_id", type=str, required=True)
    parser.add_argument("--target_column_name", type=str, required=True)
    parser.add_argument("--maximum_rows_for_test_dataset", type=int, default=5000)
    # Kept as raw strings here; the help text documents the expected shape.
    parser.add_argument(
        "--categorical_column_names", type=str, help="Optional[List[str]]"
    )
    parser.add_argument("--classes", type=str, help="Optional[List[str]]")

    # Causal arguments
    parser.add_argument("--enable_causal", type=boolean_parser, required=True)
    parser.add_argument(
        "--causal_treatment_features", type=json.loads, help="List[str]"
    )
    parser.add_argument(
        "--causal_heterogeneity_features",
        type=json.loads,
        help="Optional[List[str]] use 'null' to skip",
    )
    parser.add_argument(
        "--causal_nuisance_model", type=str, choices=["linear", "automl"]
    )
    parser.add_argument(
        "--causal_heterogeneity_model", type=str, choices=["linear", "forest"]
    )
    parser.add_argument("--causal_alpha", type=float)
    parser.add_argument("--causal_upper_bound_on_cat_expansion", type=int)
    parser.add_argument(
        "--causal_treatment_cost",
        type=float_or_json_parser,
        help="Union[float, List[Union[float, np.ndarray]]]",
    )
    parser.add_argument("--causal_min_tree_leaf_samples", type=int)
    parser.add_argument("--causal_max_tree_depth", type=int)
    parser.add_argument("--causal_skip_cat_limit_checks", type=boolean_parser)
    parser.add_argument("--causal_categories", type=str_or_list_parser)
    parser.add_argument("--causal_n_jobs", type=int)
    parser.add_argument("--causal_verbose", type=int)
    parser.add_argument("--causal_random_state", type=int_or_none_parser)

    # Counterfactual arguments
    parser.add_argument("--enable_counterfactual", type=boolean_parser, required=True)
    parser.add_argument("--counterfactual_total_CFs", type=int, required=True)
    parser.add_argument("--counterfactual_method", type=str)
    parser.add_argument("--counterfactual_desired_class", type=str_or_int_parser)
    parser.add_argument(
        "--counterfactual_desired_range", type=json_empty_is_none_parser, help="List"
    )
    parser.add_argument(
        "--counterfactual_permitted_range", type=json_empty_is_none_parser, help="Dict"
    )
    parser.add_argument("--counterfactual_features_to_vary", type=str_or_list_parser)
    parser.add_argument("--counterfactual_feature_importance", type=boolean_parser)

    # Error analysis arguments
    parser.add_argument("--enable_error_analysis", type=boolean_parser, required=True)
    parser.add_argument("--error_analysis_max_depth", type=int)
    parser.add_argument("--error_analysis_num_leaves", type=int)
    parser.add_argument("--error_analysis_min_child_samples", type=int)
    parser.add_argument(
        "--error_analysis_filter_features", type=json.loads, help="List"
    )

    # Explanation arguments
    parser.add_argument("--enable_explanation", type=boolean_parser, required=True)

    # NOTE: output arguments (--dashboard, --ux_json) are currently not
    # accepted by this parser.

    # argv=None makes argparse fall back to sys.argv[1:].
    return parser.parse_args(argv)