def show_explanation()

in src/infer_location.py [0:0]


    def show_explanation(self, query):
        """Compute per-token, per-class SHAP values for ``query``.

        A ``shap.KernelExplainer`` is built around ``self.shap_predict_wrapper``
        using a single fixed background sentence. The SHAP values it returns
        for ``query`` are reshaped to ``(num_tokens, num_classes)`` and the
        rows whose attention mask is 0 (padding) are dropped.

        Args:
            query: The text to explain. It is wrapped in a one-element numpy
                array for the explainer and passed to ``self.tokenizer``.

        Returns:
            A ``pandas.DataFrame`` with one row per attended token (indexed by
            the token strings from ``self.tokenizer``) and one column per
            value of ``self.label_map``.

        Raises:
            ValueError: If the SHAP output cannot be reshaped to
                ``(num_tokens, len(self.label_map))``.
        """
        # Adjust if your input length is different.
        # NOTE(review): presumably this must equal the sequence length used
        # inside shap_predict_wrapper -- confirm before changing.
        num_tokens = 64
        # The DataFrame columns come from label_map, so the class count must
        # agree with it; derive it instead of hard-coding a second copy.
        num_classes = len(self.label_map)

        # Convert the background text to a 2D numpy array
        background_text = np.array([["This is a sample background text for SHAP."]])

        # Initialize KernelExplainer with the 2D numpy array background
        explainer = shap.KernelExplainer(self.shap_predict_wrapper, background_text)
        text_to_explain = np.array([query])

        # Generate SHAP values for the tokens
        shap_values = explainer.shap_values(text_to_explain)
        reshaped_shap_values = shap_values[0].reshape(num_tokens, num_classes)

        # Tokenize once and reuse the encoding for both the attention mask and
        # the token strings.
        encoding = self.tokenizer(
            text_to_explain[0],
            truncation=True,
            padding="max_length",
            max_length=num_tokens,
        )
        attended_positions = [
            idx for idx, mask in enumerate(encoding['attention_mask']) if mask
        ]
        all_tokens = self.tokenizer.convert_ids_to_tokens(encoding['input_ids'])

        # Select by position rather than by count so the rows stay aligned with
        # their tokens regardless of which side the tokenizer pads on.
        attended_shap_values = reshaped_shap_values[attended_positions, :]
        attended_tokens = [all_tokens[idx] for idx in attended_positions]

        shap_values_by_class = pd.DataFrame(
            attended_shap_values,
            index=attended_tokens,
            columns=list(self.label_map.values()),
        )
        return shap_values_by_class