def predict_fn()

in sagemaker_studio/containers/model/src/explaining.py [0:0]


def predict_fn(request, model_assets):
    data = request['data']
    entities = request['entities']
    response = {}
    if 'data' in entities:
        response['data'] = data
    features = preprocess_fn(data, model_assets)
    if 'features' in entities:
        feature_names = model_assets["features_schema"].item_titles
        feature_values = features[0].tolist()
        response['features'] = {k: v for k, v in zip(feature_names, feature_values)}
    if 'descriptions' in entities:
        response['descriptions'] = model_assets["features_schema"].item_descriptions_dict
    if 'prediction' in entities:
        prediction = model_assets["classifier"].predict_proba(features)
        # take first sample (idx=0)
        # and second probability (idx=1) corresponding to the positive class
        response['prediction'] = prediction[0][1].tolist()
    if ('explanation_shap_values' in entities) or ('explanation_shap_interaction_values' in entities):
        explanation = {}
        expected_value = model_assets["explainer"].expected_value
        # see https://github.com/slundberg/shap/issues/729: handle both cases
        if expected_value.shape == (1,):
            explanation['expected_value'] = expected_value[0].tolist()
        else:
            explanation['expected_value'] = expected_value[1].tolist()
        if 'explanation_shap_values' in entities:
            # second probability (idx=1) corresponding to the positive class
            # and take first sample (idx=0)
            feature_names = model_assets["features_schema"].item_titles
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                shap_values = model_assets["explainer"].shap_values(features)[1][0]
            explanation['shap_values'] = {k: v for k, v in zip(feature_names, shap_values.tolist())}
        if 'explanation_shap_interaction_values' in entities:
            labels = model_assets["features_schema"].item_titles
            # take first sample (idx=0)
            values = model_assets["explainer"].shap_interaction_values(features)[0].tolist()
            explanation['shap_interaction_values'] = {
                'labels': labels,
                'values': values
            }
        # see https://github.com/slundberg/shap/issues/729: setting back to original
        model_assets["explainer"].expected_value = expected_value
        response['explanation'] = explanation
    return response