in assets/responsibleai/text/components/src/rai_text_insights.py [0:0]
def main(args):
    """Create, compute and persist RAI Text Insights for an MLflow model.

    The steps run in this order: load the evaluation dataset, check that the
    model arguments are present, load the model through ``ModelSerializer``,
    build a ``RAITextInsights`` object, add the requested tools, compute,
    save the insights to ``args.dashboard``, write ``dashboard.json`` under
    ``args.ux_json`` and finally record properties on the gather run.

    Args:
        args: Parsed component arguments. Attributes read here are
            ``test_dataset``, ``model_input``, ``model_info``,
            ``use_model_dependency``, ``use_conda``, ``task_type``,
            ``text_column_name``, ``target_column_name``,
            ``categorical_metadata_features``, ``dropped_metadata_features``,
            ``classes``, ``maximum_rows_for_test_dataset``,
            ``enable_explanation``, ``enable_error_analysis``,
            ``dashboard`` and ``ux_json``.

    Raises:
        ValueError: If ``args.model_input`` or ``args.model_info`` is None.
    """
    _logger.info("Loading evaluation dataset")
    evaluation_df = load_dataset(args.test_dataset)

    # Both the MLflow model directory and the model identifier are required.
    if not (args.model_input is not None and args.model_info is not None):
        raise ValueError(
            "Both model info and model input need to be supplied.")

    def _optional_json_arg(name):
        # Shared lookup for the arguments that are read via get_from_args
        # with json_empty_is_none_parser and are allowed to be None.
        return get_from_args(
            args=args,
            arg_name=name,
            custom_parser=json_empty_is_none_parser,
            allow_none=True,
        )

    # Inspect the MLmodel file to see whether the HF transformers flavor is
    # declared; any other model leaves the type unset.
    mlflow_metadata = mlflow.models.Model.load(
        os.path.join(args.model_input, "MLmodel"))
    model_type = (
        ModelTypes.HFTRANSFORMERS
        if ModelTypes.HFTRANSFORMERS in mlflow_metadata.flavors
        else None
    )

    model_id = args.model_info
    _logger.info(f"Loading model: {model_id}")
    tracking_uri = mlflow.get_tracking_uri()
    serializer = ModelSerializer(
        model_id,
        mlflow_model=args.model_input,
        model_type=model_type,
        use_model_dependency=args.use_model_dependency,
        use_conda=args.use_conda,
        task_type=args.task_type,
        tracking_uri=tracking_uri,
    )
    # NOTE(review): the path literal looks like a placeholder - presumably the
    # serializer locates the model from its constructor arguments; confirm.
    loaded_model = serializer.load("Ignored path")

    text_column = args.text_column_name
    # Multilabel tasks read the target column through the JSON parser; every
    # other task type uses the raw argument value.
    if args.task_type == TaskType.MULTILABEL_TEXT_CLASSIFICATION:
        target_column = _optional_json_arg("target_column_name")
    else:
        target_column = args.target_column_name

    metadata = FeatureMetadata()
    metadata.categorical_features = _optional_json_arg(
        "categorical_metadata_features")
    metadata.dropped_features = _optional_json_arg(
        "dropped_metadata_features")

    _logger.info("Creating RAI Text Insights")
    insights = RAITextInsights(
        model=loaded_model,
        test=evaluation_df,
        target_column=target_column,
        task_type=args.task_type,
        classes=_optional_json_arg("classes"),
        maximum_rows_for_test=args.maximum_rows_for_test_dataset,
        serializer=serializer,
        feature_metadata=metadata,
        text_column=text_column,
    )

    # Every tool starts out as "not included"; only explanation and error
    # analysis can be switched on in this function.
    included_tools: Dict[str, bool] = dict.fromkeys(
        (
            RAIToolType.CAUSAL,
            RAIToolType.COUNTERFACTUAL,
            RAIToolType.ERROR_ANALYSIS,
            RAIToolType.EXPLANATION,
        ),
        False,
    )

    if args.enable_explanation:
        _logger.info("Adding explanation")
        insights.explainer.add()
        included_tools[RAIToolType.EXPLANATION] = True

    if args.enable_error_analysis:
        _logger.info("Adding error analysis")
        insights.error_analysis.add()
        included_tools[RAIToolType.ERROR_ANALYSIS] = True

    _logger.info("Starting computation")
    insights.compute()
    _logger.info("Computation complete")

    insights.save(args.dashboard)
    _logger.info("Saved dashboard to output")

    # Serialize the insights data and write it as dashboard.json inside the
    # ux_json output location.
    dashboard_payload = serialize_json_safe(insights.get_data())
    json_target = Path(args.ux_json) / "dashboard.json"
    with json_target.open("w") as json_stream:
        json.dump(dashboard_payload, json_stream)
    _logger.info("Dashboard JSON written")

    # Record the model id together with the included-tool flags on the
    # gather run.
    dashboard_info = {DashboardInfo.RAI_INSIGHTS_MODEL_ID_KEY: model_id}
    add_properties_to_gather_run(dashboard_info, included_tools)
    _logger.info("Processing completed")