in assets/responsibleai/vision/components/src/rai_vision_insights.py [0:0]
def _resolve_target_column(args):
    """Return the target column argument appropriate for ``args.task_type``.

    For multilabel image classification the ``target_column_name`` argument is
    parsed with ``json_empty_is_none_parser`` (an empty value yields ``None``);
    for every other task type the raw ``args.target_column_name`` is returned.
    """
    if args.task_type == TaskType.MULTILABEL_IMAGE_CLASSIFICATION:
        return get_from_args(
            args=args,
            arg_name="target_column_name",
            custom_parser=json_empty_is_none_parser,
            allow_none=True
        )
    return args.target_column_name


def _build_feature_metadata(args, is_test_mltable_from_datastore):
    """Build a FeatureMetadata from the categorical/dropped metadata args.

    When the test MLTable comes from a datastore, images are fetched through
    DownloadManager, which adds an extra PORTABLE_PATH column; that column is
    added to the dropped features so it is excluded from the feature metadata.
    """
    feature_metadata = FeatureMetadata()
    feature_metadata.categorical_features = get_from_args(
        args=args,
        arg_name="categorical_metadata_features",
        custom_parser=json_empty_is_none_parser,
        allow_none=True
    )
    feature_metadata.dropped_features = get_from_args(
        args=args,
        arg_name="dropped_metadata_features",
        custom_parser=json_empty_is_none_parser,
        allow_none=True
    )
    if is_test_mltable_from_datastore:
        dropped = feature_metadata.dropped_features
        if dropped is None:
            feature_metadata.dropped_features = [PORTABLE_PATH]
        elif PORTABLE_PATH not in dropped:
            # Avoid listing the column twice if the caller already dropped it.
            dropped.append(PORTABLE_PATH)
    return feature_metadata


def _resolve_device(device_id):
    """Map a device index to ``(device value, is_gpu)``.

    A negative index selects ``Device.CPU``; any non-negative index selects
    ``Device.CUDA`` and logs whether torch reports CUDA as available.
    """
    if device_id < 0:
        return Device.CPU.value, False
    is_available = str(torch.cuda.is_available())
    _logger.info("Using CUDA device. Is CUDA available: " + is_available)
    return Device.CUDA.value, True


def _collect_automl_xai_args(args):
    """Collect the non-None explainability arguments as compute() kwargs.

    Every name in ``ExplainabilityLiterals.XAI_ARGS_GROUP`` is read from
    ``args``; only the ones with a value other than ``None`` are kept.
    """
    automl_xai_args = {}
    for xai_arg in ExplainabilityLiterals.XAI_ARGS_GROUP:
        xai_arg_value = getattr(args, xai_arg)
        if xai_arg_value is not None:
            automl_xai_args[xai_arg] = xai_arg_value
            _logger.info(f"xai_arg: {xai_arg}, value: {xai_arg_value}")
    _logger.info(f"automl_xai_args: {automl_xai_args}")
    return automl_xai_args


def _write_dashboard_json(rai_vi, ux_json_dir):
    """Serialize ``rai_vi.get_data()`` to ``<ux_json_dir>/dashboard.json``."""
    rai_dict = serialize_json_safe(rai_vi.get_data())
    output_path = Path(ux_json_dir) / "dashboard.json"
    with open(output_path, "w", encoding="utf-8") as json_file:
        json.dump(rai_dict, json_file)
    _logger.info("Dashboard JSON written")


def main(args):
    """Compute RAI vision insights for a model and write the dashboard outputs.

    Loads the test dataset and the model, builds a RAIVisionInsights object,
    optionally adds explanation and error analysis, computes, saves the
    insights to ``args.dashboard``, writes ``dashboard.json`` under
    ``args.ux_json`` and records run properties.

    Args:
        args: parsed command-line arguments for this component (presumably an
            ``argparse.Namespace``; every attribute read below must exist).

    Raises:
        ValueError: if ``args.model_input`` or ``args.model_info`` is None.
    """
    _logger.info(f"Dataset type is {args.dataset_type}")

    # Load evaluation dataset
    _logger.info("Loading evaluation dataset")
    test_df, is_test_mltable_from_datastore = load_dataset(
        args.test_dataset,
        args.dataset_type
    )

    if args.model_input is None or args.model_info is None:
        raise ValueError(
            "Both model info and model input need to be supplied."
        )

    model_id = args.model_info
    _logger.info(f"Loading model: {model_id}")
    tracking_uri = mlflow.get_tracking_uri()
    model_serializer = ModelSerializer(
        model_id,
        model_type=args.model_type,
        use_model_dependency=args.use_model_dependency,
        use_conda=args.use_conda,
        tracking_uri=tracking_uri)
    # NOTE(review): the placeholder suggests ModelSerializer.load ignores its
    # path argument and loads by model_id instead -- confirm in ModelSerializer.
    image_model = model_serializer.load("Ignored path")

    target_column = _resolve_target_column(args)
    feature_metadata = _build_feature_metadata(
        args, is_test_mltable_from_datastore)

    # Datastore-backed MLTables need the dataset path and a downloader so
    # RAIVisionInsights can fetch the images itself.
    test_data_path = None
    image_downloader = None
    if is_test_mltable_from_datastore:
        test_data_path = args.test_dataset
        image_downloader = DownloadManager

    _logger.info("Creating RAI Vision Insights")
    device, is_gpu = _resolve_device(args.device)
    rai_vi = RAIVisionInsights(
        model=image_model,
        test=test_df,
        target_column=target_column,
        task_type=args.task_type,
        classes=get_from_args(
            args=args,
            arg_name="classes",
            custom_parser=json_empty_is_none_parser,
            allow_none=True
        ),
        maximum_rows_for_test=args.maximum_rows_for_test_dataset,
        serializer=model_serializer,
        test_data_path=test_data_path,
        image_downloader=image_downloader,
        feature_metadata=feature_metadata,
        image_width=args.image_width_in_inches,
        max_evals=args.max_evals,
        num_masks=args.num_masks,
        mask_res=args.mask_res,
        device=device
    )

    # Only explanation and error analysis can be enabled by this component;
    # causal and counterfactual are always reported as not included.
    included_tools: Dict[str, bool] = {
        RAIToolType.CAUSAL: False,
        RAIToolType.COUNTERFACTUAL: False,
        RAIToolType.ERROR_ANALYSIS: False,
        RAIToolType.EXPLANATION: False,
    }
    if args.precompute_explanation:
        _logger.info("Adding explanation")
        rai_vi.explainer.add()
        included_tools[RAIToolType.EXPLANATION] = True

    if args.enable_error_analysis:
        _logger.info("Adding error analysis")
        rai_vi.error_analysis.add()
        included_tools[RAIToolType.ERROR_ANALYSIS] = True

    _logger.info("Starting computation")
    rai_vi.compute(**_collect_automl_xai_args(args))
    _logger.info("Computation complete")

    rai_vi.save(args.dashboard)
    _logger.info("Saved dashboard to output")

    _write_dashboard_json(rai_vi, args.ux_json)

    dashboard_info = {DashboardInfo.RAI_INSIGHTS_MODEL_ID_KEY: model_id}
    add_properties_to_gather_run(dashboard_info, included_tools, is_gpu)
    _logger.info("Processing completed")