in src/responsibleai/rai_analyse/gather_rai_insights.py [0:0]
def main(args):
dashboard_info = load_dashboard_info_file(args.constructor)
_logger.info("Constructor info: {0}".format(dashboard_info))
with tempfile.TemporaryDirectory() as incoming_temp_dir:
incoming_dir = Path(incoming_temp_dir)
my_run = Run.get_context()
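        # Recreate the RAIInsights object from the constructor
        # component's output port.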
rai_temp = create_rai_insights_from_port_path(my_run, args.constructor)
        # Workaround for an issue in RAIInsights: when using the served
        # model wrapper, predictions (currently forecasts only) are stored
        # as lists instead of numpy.ndarray, which makes saving the
        # RAIInsights object fail with:
        #     AttributeError: 'list' object has no attribute 'tolist'
        # The loop below converts the lists to numpy.ndarray; once the
        # issue is fixed in RAIInsights, this code can be removed.
        for method_name in ['forecast', 'forecast_quantiles']:
            field_name = f"_{method_name}_output"
            field_value = getattr(rai_temp, field_name, None)
            if isinstance(field_value, list):
                setattr(rai_temp, field_name, np.array(field_value))
rai_temp.save(incoming_temp_dir)
print("Saved rai_temp")
print_dir_tree(incoming_temp_dir)
print("=======")
create_rai_tool_directories(incoming_dir)
_logger.info("Saved empty RAI Insights input to temporary directory")
insight_paths = [
args.insight_1,
args.insight_2,
args.insight_3,
args.insight_4,
]
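        # Track which tool types have been gathered so that duplicates
        # can be rejected.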
included_tools: Dict[str, bool] = {
RAIToolType.CAUSAL: False,
RAIToolType.COUNTERFACTUAL: False,
RAIToolType.ERROR_ANALYSIS: False,
RAIToolType.EXPLANATION: False,
}
        for i, current_insight_arg in enumerate(insight_paths):
            if current_insight_arg is not None:
current_insight_path = Path(current_insight_arg)
_logger.info("Checking dashboard info")
insight_info = load_dashboard_info_file(current_insight_path)
_logger.info("Insight info: {0}".format(insight_info))
# Cross check against the constructor
if insight_info != dashboard_info:
err_string = _DASHBOARD_CONSTRUCTOR_MISMATCH.format(i + 1)
raise ValueError(err_string)
# Copy the data
_logger.info("Copying insight {0}".format(i + 1))
tool = copy_insight_to_raiinsights(incoming_dir, current_insight_path)
# Can only have one instance of each tool
if included_tools[tool]:
err_string = _DUPLICATE_TOOL.format(i + 1, tool)
raise ValueError(err_string)
included_tools[tool] = True
else:
_logger.info("Insight {0} is None".format(i + 1))
_logger.info("Tool summary: {0}".format(included_tools))
if rai_temp.task_type == "forecasting":
            # Forecasting uses the ServedModelWrapper, which cannot be
            # (de-)serialized like other models. Setting the model serving
            # port environment variable avoids loading the model; the port
            # number is arbitrary because the model is not actually served.
os.environ["RAI_MODEL_SERVING_PORT"] = str(MLFLOW_MODEL_SERVER_PORT)
rai_i = RAIInsights.load(incoming_dir)
_logger.info("Object loaded")
rai_i.save(args.dashboard)
# update dashboard info with gather job run id and write
dashboard_info[DashboardInfo.RAI_INSIGHTS_GATHER_RUN_ID_KEY] = str(my_run.id)
output_file = os.path.join(
args.dashboard, DashboardInfo.RAI_INSIGHTS_PARENT_FILENAME
)
with open(output_file, "w") as of:
json.dump(dashboard_info, of, default=default_json_handler)
_logger.info("Saved dashboard to output")
rai_data = rai_i.get_data()
rai_dict = serialize_json_safe(rai_data)
json_filename = "dashboard.json"
output_path = Path(args.ux_json) / json_filename
with open(output_path, "w") as json_file:
json.dump(rai_dict, json_file)
_logger.info("Dashboard JSON written")
add_properties_to_gather_run(dashboard_info, included_tools)
_logger.info("Processing completed")