training/sagemaker-debugger/cnn_class_activation_maps.ipynb (1,441 lines of code) (raw):

{ "cells": [ { "cell_type": "markdown", "id": "bab609f9", "metadata": { "papermill": { "duration": 0.019155, "end_time": "2021-05-27T00:13:14.763381", "exception": false, "start_time": "2021-05-27T00:13:14.744226", "status": "completed" }, "tags": [] }, "source": [ "# Using SageMaker debugger to visualize class activation maps in CNNs\n" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook. \n", "\n", "![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-2/sagemaker-debugger|model_specific_realtime_analysis|cnn_class_activation_maps|cnn_class_activation_maps.ipynb)\n", "\n", "---" ] }, { "cell_type": "markdown", "id": "bab609f9", "metadata": { "papermill": { "duration": 0.019155, "end_time": "2021-05-27T00:13:14.763381", "exception": false, "start_time": "2021-05-27T00:13:14.744226", "status": "completed" }, "tags": [] }, "source": [ "\n", "This notebook demonstrates how to use SageMaker debugger to plot class activation maps for image classification models. A class activation map (saliency map) is a heatmap that highlights the regions in the image that lead the model to make a certain prediction. This is especially useful:\n", "\n", "1. to understand why the model misclassifies an image when the reason is not obvious;\n", "\n", "2. to determine whether the model takes all important features of an object into account.\n", 
"\n", "In this notebook we train a [ResNet](https://arxiv.org/abs/1512.03385) model on the [German Traffic Sign Dataset](http://benchmark.ini.rub.de/?section=gtsrb&subsection=news) and use SageMaker debugger to plot class activation maps in real time.\n", "\n", "The following animation shows the saliency map for a particular traffic sign as training progresses. Red highlights the regions with high activation that lead to the prediction; blue indicates regions with low activation that are less relevant for the prediction.\n", "\n", "At the beginning of training, the model makes many misclassifications because it focuses on the wrong image regions, e.g. the obstacle in the lower left corner. As training progresses, the focus shifts to the center of the image, and the model becomes increasingly confident in predicting class 3 (the correct class).\n", "\n", "![](images/example.gif)\n", "\n", "There are several methods to generate saliency maps, e.g. [CAM](http://cnnlocalization.csail.mit.edu/) and [GradCAM](https://arxiv.org/abs/1610.02391). The paper [Full-Gradient Representation for Neural Network Visualization [1]](https://arxiv.org/abs/1905.00780) proposes a method that produces state-of-the-art results. It requires the intermediate features and their biases, which we can easily retrieve with SageMaker debugger.\n", "\n", "[1] *Full-Gradient Representation for Neural Network Visualization*: Suraj Srinivas and Francois Fleuret, 2019, 1905.00780, arXiv" ] }, { "cell_type": "markdown", "id": "18bab552", "metadata": { "papermill": { "duration": 0.019147, "end_time": "2021-05-27T00:13:14.801813", "exception": false, "start_time": "2021-05-27T00:13:14.782666", "status": "completed" }, "tags": [] }, "source": [ "### Customize the smdebug hook\n", "\n", "To create saliency maps, the gradients of the prediction with respect to the intermediate features need to be computed. 
To obtain this information, we have to customize the [smdebug hook](https://github.com/awslabs/sagemaker-debugger/blob/master/smdebug/pytorch/hook.py). The custom hook is defined in [entry_point/custom_hook.py](entry_point/custom_hook.py). During the forward pass, we register a backward hook on the outputs. We also need the gradients of the input image, so we provide an additional function that registers a backward hook on the input tensor.\n", "\n", "The paper [Full-Gradient Representation for Neural Network Visualization [1]](https://arxiv.org/abs/1905.00780) distinguishes between implicit and explicit biases. Implicit biases include the running mean and variance of BatchNorm layers. By default, SageMaker debugger only records the explicit biases, which in the case of BatchNorm layers equal the beta parameter. We therefore extend the hook to also record the running means and variances of BatchNorm layers.\n" ] }, { "cell_type": "markdown", "id": "446c3ed6", "metadata": { "papermill": { "duration": 0.019026, "end_time": "2021-05-27T00:13:14.839884", "exception": false, "start_time": "2021-05-27T00:13:14.820858", "status": "completed" }, "tags": [] }, "source": [ "```python\n", "import torch\n", "import smdebug.pytorch as smd\n", "from smdebug.core.modes import ModeKeys\n", "\n", "\n", "class CustomHook(smd.Hook):\n", "\n", "    # register input image for backward pass, to get image gradients\n", "    def image_gradients(self, image):\n", "        image.register_hook(self.backward_hook(\"image\"))\n", "\n", "    def forward_hook(self, module, inputs, outputs):\n", "        module_name = module._module_name\n", "        self._write_inputs(module_name, inputs)\n", "\n", "        # register outputs for backward pass. This is expensive, so we only do it in EVAL mode\n", 
"        if self.mode == ModeKeys.EVAL:\n", "            outputs.register_hook(self.backward_hook(module_name + \"_output\"))\n", "\n", "        # record running mean and var of BatchNorm layers\n", "        if isinstance(module, torch.nn.BatchNorm2d):\n", "            self._write_outputs(module_name + \".running_mean\", module.running_mean)\n", "            self._write_outputs(module_name + \".running_var\", module.running_var)\n", "\n", "        self._write_outputs(module_name, outputs)\n", "        self.last_saved_step = self.step\n", "```" ] }, { "cell_type": "markdown", "id": "53a82b23", "metadata": { "papermill": { "duration": 0.01905, "end_time": "2021-05-27T00:13:14.877985", "exception": false, "start_time": "2021-05-27T00:13:14.858935", "status": "completed" }, "tags": [] }, "source": [ "### Replace in-place operations\n", "Additionally, we need to replace in-place operations, because they can overwrite values that are required to compute gradients. In the PyTorch pre-trained ResNet model, ReLU activations are executed in-place by default. 
The following code replaces every ReLU module with a ReLU that uses `inplace=False`:" ] }, { "cell_type": "code", "execution_count": null, "id": "f55135d6", "metadata": { "execution": { "iopub.execute_input": "2021-05-27T00:13:14.920576Z", "iopub.status.busy": "2021-05-27T00:13:14.920061Z", "iopub.status.idle": "2021-05-27T00:13:14.922038Z", "shell.execute_reply": "2021-05-27T00:13:14.922413Z" }, "papermill": { "duration": 0.025524, "end_time": "2021-05-27T00:13:14.922565", "exception": false, "start_time": "2021-05-27T00:13:14.897041", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "def relu_inplace(model):\n", "    # recursively replace every ReLU module with a non-in-place ReLU\n", "    for child_name, child in model.named_children():\n", "        if isinstance(child, torch.nn.ReLU):\n", "            setattr(model, child_name, torch.nn.ReLU(inplace=False))\n", "        else:\n", "            relu_inplace(child)" ] }, { "cell_type": "markdown", "id": "0b4de586", "metadata": { "papermill": { "duration": 0.0192, "end_time": "2021-05-27T00:13:14.961499", "exception": false, "start_time": "2021-05-27T00:13:14.942299", "status": "completed" }, "tags": [] }, "source": [ "### Download the dataset and upload it to Amazon S3\n", "\n", "Now we download the [German Traffic Sign Dataset](http://benchmark.ini.rub.de/?section=gtsrb&subsection=news) and upload it to Amazon S3. The training dataset consists of 43 image classes." ] }, { "cell_type": "code", "execution_count": null, "id": "b7418d4a", "metadata": { "execution": { "iopub.execute_input": "2021-05-27T00:13:15.005339Z", "iopub.status.busy": "2021-05-27T00:13:15.004564Z", "iopub.status.idle": "2021-05-27T00:13:36.967224Z", "shell.execute_reply": "2021-05-27T00:13:36.966784Z" }, "papermill": { "duration": 21.986356, "end_time": "2021-05-27T00:13:36.967340", "exception": false, "start_time": "2021-05-27T00:13:14.980984", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import zipfile\n", "\n", "! 
wget https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB-Training_fixed.zip\n", "with zipfile.ZipFile(\"GTSRB-Training_fixed.zip\", \"r\") as zip_ref:\n", " zip_ref.extractall(\"./\")" ] }, { "cell_type": "markdown", "id": "ee3dbc0c", "metadata": { "papermill": { "duration": 0.034419, "end_time": "2021-05-27T00:13:37.037208", "exception": false, "start_time": "2021-05-27T00:13:37.002789", "status": "completed" }, "tags": [] }, "source": [ "The test dataset:" ] }, { "cell_type": "code", "execution_count": null, "id": "ee678b41", "metadata": { "execution": { "iopub.execute_input": "2021-05-27T00:13:37.112365Z", "iopub.status.busy": "2021-05-27T00:13:37.111680Z", "iopub.status.idle": "2021-05-27T00:13:48.701498Z", "shell.execute_reply": "2021-05-27T00:13:48.701909Z" }, "papermill": { "duration": 11.629926, "end_time": "2021-05-27T00:13:48.702040", "exception": false, "start_time": "2021-05-27T00:13:37.072114", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import zipfile\n", "\n", "! wget https://sid.erda.dk/public/archives/daaeac0d7ce1152aea9b61d9f1e19370/GTSRB_Final_Test_Images.zip\n", "with zipfile.ZipFile(\"GTSRB_Final_Test_Images.zip\", \"r\") as zip_ref:\n", " zip_ref.extractall(\"./\")" ] }, { "cell_type": "markdown", "id": "eadd258c", "metadata": { "papermill": { "duration": 0.117074, "end_time": "2021-05-27T00:13:48.864146", "exception": false, "start_time": "2021-05-27T00:13:48.747072", "status": "completed" }, "tags": [] }, "source": [ "Now we upload the datasets to the SageMaker default bucket in Amazon S3." 
] }, { "cell_type": "code", "execution_count": null, "id": "0b417b93", "metadata": { "execution": { "iopub.execute_input": "2021-05-27T00:13:48.961919Z", "iopub.status.busy": "2021-05-27T00:13:48.961116Z", "iopub.status.idle": "2021-05-27T00:13:49.786431Z", "shell.execute_reply": "2021-05-27T00:13:49.786001Z" }, "papermill": { "duration": 0.87724, "end_time": "2021-05-27T00:13:49.786546", "exception": false, "start_time": "2021-05-27T00:13:48.909306", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import boto3\n", "import sagemaker\n", "import os\n", "\n", "\n", "def upload_to_s3(path, directory_name, bucket, counter=None):\n", "    # upload at most `counter` files per directory (counter=None uploads all files)\n", "    print(\"Upload files from \" + path + \" to \" + bucket)\n", "    client = boto3.client(\"s3\")\n", "\n", "    for dirpath, subdirs, files in os.walk(path):\n", "        dirpath = dirpath.replace(\"\\\\\", \"/\")\n", "        print(dirpath)\n", "        for file in files[:counter]:\n", "            client.upload_file(\n", "                os.path.join(dirpath, file),\n", "                bucket,\n", "                directory_name + \"/\" + dirpath.split(\"/\")[-1] + \"/\" + file,\n", "            )\n", "\n", "\n", "boto_session = boto3.Session()\n", "sagemaker_session = sagemaker.Session(boto_session=boto_session)\n", "bucket = sagemaker_session.default_bucket()\n", "\n", "upload_to_s3(\"GTSRB/Training\", directory_name=\"train\", bucket=bucket)\n", "\n", "# saliency maps are computed for every image in the test dataset, so we only upload 4 test images\n", "upload_to_s3(\"GTSRB/Final_Test\", directory_name=\"test\", bucket=bucket, counter=4)" ] }, { "cell_type": "markdown", "id": "5f7c84e2", "metadata": { "papermill": { "duration": 0.04502, "end_time": "2021-05-27T00:13:49.876684", "exception": false, "start_time": "2021-05-27T00:13:49.831664", "status": "completed" }, "tags": [] }, "source": [ "Before starting the SageMaker training job, we need to install some libraries. We will use the `smdebug` library to read, filter, and analyze raw tensors that are stored in Amazon S3. 
We will use the `opencv-python` library to plot saliency maps as heatmaps." ] }, { "cell_type": "code", "execution_count": null, "id": "fab25828", "metadata": {}, "outputs": [], "source": [ "!apt-get update && apt-get install -y python3-opencv" ] }, { "cell_type": "code", "execution_count": null, "id": "8a92d2c6", "metadata": { "execution": { "iopub.execute_input": "2021-05-27T00:13:49.970786Z", "iopub.status.busy": "2021-05-27T00:13:49.970318Z", "iopub.status.idle": "2021-05-27T00:13:49.973729Z", "shell.execute_reply": "2021-05-27T00:13:49.973349Z" }, "papermill": { "duration": 0.052181, "end_time": "2021-05-27T00:13:49.973831", "exception": false, "start_time": "2021-05-27T00:13:49.921650", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import subprocess\n", "import sys\n", "\n", "\n", "def import_or_install(package, import_name=None):\n", "    # the import name can differ from the pip package name, e.g. opencv-python is imported as cv2\n", "    try:\n", "        __import__(import_name or package)\n", "    except ImportError:\n", "        subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", package])" ] }, { "cell_type": "code", "execution_count": null, "id": "94e42997", "metadata": { "execution": { "iopub.execute_input": "2021-05-27T00:13:50.067962Z", "iopub.status.busy": "2021-05-27T00:13:50.067493Z", "iopub.status.idle": "2021-05-27T00:13:50.074685Z", "shell.execute_reply": "2021-05-27T00:13:50.074280Z" }, "papermill": { "duration": 0.055526, "end_time": "2021-05-27T00:13:50.074791", "exception": false, "start_time": "2021-05-27T00:13:50.019265", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import_or_install(\"smdebug\")" ] }, { "cell_type": "code", "execution_count": null, "id": "821ec5f3", "metadata": { "execution": { "iopub.execute_input": "2021-05-27T00:13:50.168978Z", "iopub.status.busy": "2021-05-27T00:13:50.168460Z", "iopub.status.idle": "2021-05-27T00:13:53.692594Z", "shell.execute_reply": "2021-05-27T00:13:53.692136Z" }, "papermill": { "duration": 3.572393, "end_time": "2021-05-27T00:13:53.692709", "exception": false, "start_time": "2021-05-27T00:13:50.120316", "status": "completed" }, "tags": [] }, "outputs": [], 
"source": [ "import_or_install(\"opencv-python\", \"cv2\")" ] }, { "cell_type": "markdown", "id": "35989595", "metadata": { "papermill": { "duration": 0.046074, "end_time": "2021-05-27T00:13:53.785302", "exception": false, "start_time": "2021-05-27T00:13:53.739228", "status": "completed" }, "tags": [] }, "source": [ "### SageMaker training\n", "\n", "The entry point script [train.py](entry_point/train.py) defines the model training. It downloads a pre-trained ResNet model and performs transfer learning on the German traffic sign dataset.\n", "\n", "#### Debugger hook configuration\n", "Next, we define a custom collection with a regular expression that selects the tensors to save. Tensors from the training phase are saved every 100 steps, while tensors from the validation phase are saved at every step. A step represents one forward and backward pass." ] }, { "cell_type": "code", "execution_count": null, "id": "a794fcd6", "metadata": { "execution": { "iopub.execute_input": "2021-05-27T00:13:53.884184Z", "iopub.status.busy": "2021-05-27T00:13:53.883706Z", "iopub.status.idle": "2021-05-27T00:13:53.885648Z", "shell.execute_reply": "2021-05-27T00:13:53.886018Z" }, "papermill": { "duration": 0.054477, "end_time": "2021-05-27T00:13:53.886135", "exception": false, "start_time": "2021-05-27T00:13:53.831658", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "from sagemaker.debugger import DebuggerHookConfig, CollectionConfig\n", "\n", "debugger_hook_config = DebuggerHookConfig(\n", "    collection_configs=[\n", "        CollectionConfig(\n", "            name=\"custom_collection\",\n", "            parameters={\n", "                # BatchNorm, bias, downsample, input image, image gradient, fc output and loss tensors\n", "                \"include_regex\": \".*bn|.*bias|.*downsample|.*ResNet_input|.*image|.*fc_output|.*CrossEntropyLoss\",\n", "                \"train.save_interval\": \"100\",\n", "                \"eval.save_interval\": \"1\",\n", "            },\n", "        )\n", "    ]\n", ")" ] }, { "cell_type": "markdown", "id": "cc46eceb", "metadata": { "papermill": { "duration": 0.046495, "end_time": 
"2021-05-27T00:13:53.979056", "exception": false, "start_time": "2021-05-27T00:13:53.932561", "status": "completed" }, "tags": [] }, "source": [ "#### Built-in rule\n", "In addition, we run the training job with a built-in rule. We select the class imbalance rule, which measures whether the training set is imbalanced and/or whether the model has lower accuracy for certain classes in the training dataset. The loss function `CrossEntropyLoss` receives the predictions and the labels as inputs; in our example these tensors are named `CrossEntropyLoss_input_0` (predictions) and `CrossEntropyLoss_input_1` (labels). The rule uses these tensors to compute the class imbalance." ] }, { "cell_type": "code", "execution_count": null, "id": "dae7efb8", "metadata": { "execution": { "iopub.execute_input": "2021-05-27T00:13:54.076979Z", "iopub.status.busy": "2021-05-27T00:13:54.076281Z", "iopub.status.idle": "2021-05-27T00:13:54.078453Z", "shell.execute_reply": "2021-05-27T00:13:54.078889Z" }, "papermill": { "duration": 0.053794, "end_time": "2021-05-27T00:13:54.079066", "exception": false, "start_time": "2021-05-27T00:13:54.025272", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "from sagemaker.debugger import Rule, rule_configs\n", "\n", "class_imbalance_rule = Rule.sagemaker(\n", "    base_config=rule_configs.class_imbalance(),\n", "    rule_parameters={\n", "        \"labels_regex\": \"CrossEntropyLoss_input_1\",\n", "        \"predictions_regex\": \"CrossEntropyLoss_input_0\",\n", "        # the predictions are logits, so the rule applies argmax to obtain the predicted class\n", "        \"argmax\": \"True\",\n", "    },\n", ")" ] }, { "cell_type": "markdown", "id": "66463982", "metadata": { "papermill": { "duration": 0.046268, "end_time": "2021-05-27T00:13:54.172681", "exception": false, "start_time": "2021-05-27T00:13:54.126413", "status": "completed" }, "tags": [] }, "source": [ "#### Define the estimator\n", "The following code defines the SageMaker PyTorch estimator, which runs [train.py](entry_point/train.py) with the debugger hook configuration and the class imbalance rule defined above." ] }, 
{ "cell_type": "code", "execution_count": null, "id": "bb5f7195", "metadata": { "execution": { "iopub.execute_input": "2021-05-27T00:13:54.272889Z", "iopub.status.busy": "2021-05-27T00:13:54.272329Z", "iopub.status.idle": "2021-05-27T00:13:54.782729Z", "shell.execute_reply": "2021-05-27T00:13:54.783119Z" }, "papermill": { "duration": 0.564467, "end_time": "2021-05-27T00:13:54.783251", "exception": false, "start_time": "2021-05-27T00:13:54.218784", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "from sagemaker.pytorch import PyTorch\n", "\n", "role = sagemaker.get_execution_role()\n", "\n", "# SageMaker Python SDK v2 names: instance_type/instance_count (v1: train_instance_type/train_instance_count)\n", "pytorch_estimator = PyTorch(\n", "    entry_point=\"train.py\",\n", "    source_dir=\"entry_point\",\n", "    role=role,\n", "    instance_type=\"ml.p3.2xlarge\",\n", "    instance_count=1,\n", "    framework_version=\"1.12.0\",\n", "    py_version=\"py38\",\n", "    hyperparameters={\n", "        \"epochs\": 5,\n", "        \"batch_size_train\": 64,\n", "        \"batch_size_val\": 4,\n", "        \"learning_rate\": 0.001,\n", "    },\n", "    volume_size=100,\n", "    debugger_hook_config=debugger_hook_config,\n", "    rules=[class_imbalance_rule],\n", ")" ] }, { "cell_type": "markdown", "id": "f325ea72", "metadata": { "papermill": { "duration": 0.047051, "end_time": "2021-05-27T00:13:54.877620", "exception": false, "start_time": "2021-05-27T00:13:54.830569", "status": "completed" }, "tags": [] }, "source": [ "Now that we have defined the estimator, we can call `fit`, which starts the training job on an `ml.p3.2xlarge` instance:" ] }, { "cell_type": "code", "execution_count": null, "id": "850935f1", "metadata": { "execution": { "iopub.execute_input": "2021-05-27T00:13:54.977013Z", "iopub.status.busy": "2021-05-27T00:13:54.976097Z", "iopub.status.idle": "2021-05-27T00:13:55.524590Z", "shell.execute_reply": "2021-05-27T00:13:55.524998Z" }, "papermill": { "duration": 0.600085, "end_time": 
"2021-05-27T00:13:55.525130", "exception": false, "start_time": "2021-05-27T00:13:54.925045", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "pytorch_estimator.fit(\n", " inputs={\"train\": \"s3://{}/train\".format(bucket), \"test\": \"s3://{}/test\".format(bucket)},\n", " wait=False,\n", ")" ] }, { "cell_type": "markdown", "id": "5fa609f0", "metadata": { "papermill": { "duration": 0.047711, "end_time": "2021-05-27T00:13:55.620890", "exception": false, "start_time": "2021-05-27T00:13:55.573179", "status": "completed" }, "tags": [] }, "source": [ "#### Check rule status" ] }, { "cell_type": "code", "execution_count": null, "id": "667543bf", "metadata": { "execution": { "iopub.execute_input": "2021-05-27T00:13:55.722556Z", "iopub.status.busy": "2021-05-27T00:13:55.722100Z", "iopub.status.idle": "2021-05-27T00:13:55.753817Z", "shell.execute_reply": "2021-05-27T00:13:55.754180Z" }, "papermill": { "duration": 0.085326, "end_time": "2021-05-27T00:13:55.754303", "exception": false, "start_time": "2021-05-27T00:13:55.668977", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "pytorch_estimator.latest_training_job.rule_job_summary()" ] }, { "cell_type": "markdown", "id": "80ae6f6a", "metadata": { "papermill": { "duration": 0.048474, "end_time": "2021-05-27T00:13:55.851374", "exception": false, "start_time": "2021-05-27T00:13:55.802900", "status": "completed" }, "tags": [] }, "source": [ "### Visualize saliency maps in real-time\n", "Once the training job has started, SageMaker debugger will upload the tensors of our model into S3. 
We can check the location in S3: \n" ] }, { "cell_type": "code", "execution_count": null, "id": "3df61990", "metadata": { "execution": { "iopub.execute_input": "2021-05-27T00:13:55.952168Z", "iopub.status.busy": "2021-05-27T00:13:55.951686Z", "iopub.status.idle": "2021-05-27T00:13:55.954064Z", "shell.execute_reply": "2021-05-27T00:13:55.954436Z" }, "papermill": { "duration": 0.054669, "end_time": "2021-05-27T00:13:55.954560", "exception": false, "start_time": "2021-05-27T00:13:55.899891", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "path = pytorch_estimator.latest_job_debugger_artifacts_path()\n", "print(\"Tensors are stored in: {}\".format(path))" ] }, { "cell_type": "markdown", "id": "5b8f740a", "metadata": { "papermill": { "duration": 0.049098, "end_time": "2021-05-27T00:13:56.053291", "exception": false, "start_time": "2021-05-27T00:13:56.004193", "status": "completed" }, "tags": [] }, "source": [ "We can check the status of our training job by calling `describe_training_job`:" ] }, { "cell_type": "code", "execution_count": null, "id": "8a1e6bd0", "metadata": { "execution": { "iopub.execute_input": "2021-05-27T00:13:56.157269Z", "iopub.status.busy": "2021-05-27T00:13:56.156788Z", "iopub.status.idle": "2021-05-27T00:13:56.180422Z", "shell.execute_reply": "2021-05-27T00:13:56.180842Z" }, "papermill": { "duration": 0.078446, "end_time": "2021-05-27T00:13:56.180973", "exception": false, "start_time": "2021-05-27T00:13:56.102527", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "job_name = pytorch_estimator.latest_training_job.name\n", "print(\"Training job name: {}\".format(job_name))\n", "\n", "client = pytorch_estimator.sagemaker_session.sagemaker_client\n", "\n", "description = client.describe_training_job(TrainingJobName=job_name)" ] }, { "cell_type": "markdown", "id": "e0dd2da3", "metadata": { "papermill": { "duration": 0.049753, "end_time": "2021-05-27T00:13:56.280271", "exception": false, "start_time": 
"2021-05-27T00:13:56.230518", "status": "completed" }, "tags": [] }, "source": [ "We can access the tensors from S3 once the training job is in status `Training` or `Completed`. In the following code cell we wait until the job reaches one of these states:" ] }, { "cell_type": "code", "execution_count": null, "id": "712ae8f1", "metadata": { "execution": { "iopub.execute_input": "2021-05-27T00:13:56.386072Z", "iopub.status.busy": "2021-05-27T00:13:56.385548Z", "iopub.status.idle": "2021-05-27T00:18:23.918605Z", "shell.execute_reply": "2021-05-27T00:18:23.917712Z" }, "papermill": { "duration": 267.588794, "end_time": "2021-05-27T00:18:23.918823", "exception": true, "start_time": "2021-05-27T00:13:56.330029", "status": "failed" }, "tags": [] }, "outputs": [], "source": [ "import time\n", "\n", "if description[\"TrainingJobStatus\"] != \"Completed\":\n", "    while description[\"SecondaryStatus\"] not in {\"Training\", \"Completed\"}:\n", "        description = client.describe_training_job(TrainingJobName=job_name)\n", "        primary_status = description[\"TrainingJobStatus\"]\n", "        secondary_status = description[\"SecondaryStatus\"]\n", "        print(\n", "            \"Current job status: [PrimaryStatus: {}, SecondaryStatus: {}]\".format(\n", "                primary_status, secondary_status\n", "            )\n", "        )\n", "        # stop waiting if the job failed or was stopped before reaching the training phase\n", "        if primary_status in {\"Failed\", \"Stopped\"}:\n", "            raise RuntimeError(\"Training job {} ended with status {}\".format(job_name, primary_status))\n", "        time.sleep(30)" ] }, { "cell_type": "markdown", "id": "13376238", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "Once the job is in status `Training` or `Completed`, we can create the trial: " ] }, { "cell_type": "code", "execution_count": null, "id": "c3b7e454", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "from smdebug.trials import create_trial\n", "\n", "trial = create_trial(path)" ] }, { "cell_type": "markdown", "id": "405f4333", "metadata": { "papermill": { "duration": null, "end_time": null, 
"exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "Now we can compute the saliency maps. The method described in [Full-Gradient Representation for Neural Network Visualization [1]](https://arxiv.org/abs/1905.00780) requires all intermediate features and their biases. The following cell retrieves the names of the gradient tensors for the outputs of the BatchNorm and downsampling layers, and of the corresponding bias tensors. If you use a model other than ResNet, you may need to adjust the regular expressions in the following cell:" ] }, { "cell_type": "code", "execution_count": null, "id": "61380946", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "biases, gradients = [], []\n", "\n", "# gradients of the BatchNorm outputs, including the BatchNorm layers in the downsample blocks\n", "for tname in trial.tensor_names(regex=\".*gradient.*bn.*output|.*gradient.*downsample.1.*output\"):\n", "    gradients.append(tname)\n", "\n", "# all bias tensors except the one of the final fully connected layer (fc)\n", "for tname in trial.tensor_names(regex=\"^(?=.*bias)(?:(?!fc).)*$\"):\n", "    biases.append(tname)" ] }, { "cell_type": "markdown", "id": "c78df08d", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "As mentioned at the beginning of the notebook, for BatchNorm layers we also need to compute the implicit biases. In the following code cell we retrieve the names of the necessary tensors:" ] }, 
{ "cell_type": "code", "execution_count": null, "id": "36f704b3", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "bn_weights, running_vars, running_means = [], [], []\n", "\n", "for tname in trial.tensor_names(regex=\".*running_mean\"):\n", "    running_means.append(tname)\n", "\n", "for tname in trial.tensor_names(regex=\".*running_var\"):\n", "    running_vars.append(tname)\n", "\n", "for tname in trial.tensor_names(regex=\".*bn.*weight|.*downsample.1.*weight\"):\n", "    bn_weights.append(tname)" ] }, { "cell_type": "markdown", "id": "1e782c4e", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "We need to ensure that the tensors in these lists are aligned, i.e. the entries at the same position must all belong to the same layer (for example, the bias vector and the gradients). 
Let's have a look at the tensors:" ] }, { "cell_type": "code", "execution_count": null, "id": "9658c59f", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "for bias, gradient, weight, running_var, running_mean in zip(\n", "    biases, gradients, bn_weights, running_vars, running_means\n", "):\n", "    print(bias, gradient, weight, running_var, running_mean)" ] }, { "cell_type": "markdown", "id": "b924b590", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "Here we define a helper function that is used later on to scale tensors to the range [0, 1]:" ] }, { "cell_type": "code", "execution_count": null, "id": "d8aeccaa", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "\n", "\n", "def normalize(tensor):\n", "    # min-max scaling; the small constant avoids a division by zero for constant tensors\n", "    tensor = tensor - np.min(tensor)\n", "    tensor = tensor / (np.max(tensor) + 1e-12)\n", "    return tensor" ] }, 
{ "cell_type": "markdown", "id": "b6443168", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "A helper function to plot saliency maps:" ] }, { "cell_type": "code", "execution_count": null, "id": "9fda3d75", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "import cv2\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "\n", "def plot(saliency_map, image, predicted_class, probability):\n", "\n", "    # clear matplotlib figure\n", "    plt.clf()\n", "\n", "    # revert normalization\n", "    mean = [[[0.485]], [[0.456]], [[0.406]]]\n", "    std = [[[0.229]], [[0.224]], [[0.225]]]\n", "    image = image * std + mean\n", "\n", "    # transpose image: color channel in last dimension\n", "    image = image.transpose(1, 2, 0)\n", "    image = (np.clip(image, 0, 1) * 255).astype(np.uint8)\n", "\n", "    # create heatmap: cv2.applyColorMap returns a BGR image,\n", "    # so convert it to RGB before plotting it with matplotlib\n", "    saliency_map = (saliency_map * 255).astype(np.uint8)\n", "    heatmap = cv2.applyColorMap(saliency_map, cv2.COLORMAP_JET)\n", "    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)\n", "\n", "    # overlay original image with heatmap\n", "    output_image = heatmap.astype(np.float32) + image.astype(np.float32)\n", "\n", "    # normalize\n", "    output_image = output_image / np.max(output_image)\n", "\n", "    # plot\n", "    fig, (ax0, ax1) = plt.subplots(ncols=2, figsize=(10, 5))\n", "    ax0.imshow(image)\n", "    ax1.imshow(output_image)\n", "    ax0.set_axis_off()\n", "    ax1.set_axis_off()\n", "    ax0.set_title(\"Input image\")\n", "    ax1.set_title(\"Predicted class \" + predicted_class + \" with probability \" + probability + \"%\")\n", "    plt.show()" ] }, 
{ "cell_type": "markdown", "id": "afd3a90c", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "A helper function to compute the implicit biases of BatchNorm layers, `-running_mean * weight / sqrt(running_var + eps)`:" ] }, { "cell_type": "code", "execution_count": null, "id": "863d21e9", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "from smdebug import modes\n", "\n", "\n", "def compute_implicit_biases(bn_weights, running_vars, running_means, step, eps=1e-5):\n", "    # eps defaults to 1e-5, the default value of torch.nn.BatchNorm2d\n", "    implicit_biases = []\n", "    for weight_name, running_var_name, running_mean_name in zip(\n", "        bn_weights, running_vars, running_means\n", "    ):\n", "        weight = trial.tensor(weight_name).value(step_num=step, mode=modes.EVAL)\n", "        running_var = trial.tensor(running_var_name).value(step_num=step, mode=modes.EVAL)\n", "        running_mean = trial.tensor(running_mean_name).value(step_num=step, mode=modes.EVAL)\n", "        implicit_biases.append(-running_mean / np.sqrt(running_var + eps) * weight)\n", "    return implicit_biases" ] }, 
{ "cell_type": "markdown", "id": "aae01d2e", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "Wait until the first steps are available:" ] }, { "cell_type": "code", "execution_count": null, "id": "de130631", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "import time\n", "\n", "steps = []\n", "while not steps:\n", "    # trial.steps() returns a list of step numbers, which is empty until the first tensors are written\n", "    steps = trial.steps()\n", "    if not steps:\n", "        print(\"Waiting for tensors to become available...\")\n", "        time.sleep(3)\n", "print(\"\\nDone\")\n", "\n", "print(\"Getting tensors...\")\n", "rendered_steps = []" ] }, { "cell_type": "markdown", "id": "e45aca97", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "source": [ "We iterate over the tensors from the validation steps and compute the saliency map for each item in the batch. To compute the saliency map, we perform the following steps:\n", "\n", "1. compute the implicit biases\n", "2. multiply the gradients by the bias (sum of explicit and implicit bias)\n", "3. normalize the result\n", "4. interpolate the tensor to the size of the original input image\n", "5. 
create heatmap and overlay it with the original input image" ] }, { "cell_type": "code", "execution_count": null, "id": "cfd66f2e", "metadata": { "papermill": { "duration": null, "end_time": null, "exception": null, "start_time": null, "status": "pending" }, "tags": [] }, "outputs": [], "source": [ "import numpy as np\n", "import cv2\n", "import scipy.ndimage\n", "import scipy.special\n", "from smdebug import modes\n", "from smdebug.exceptions import TensorUnavailableForStep\n", "\n", "image_size = 224\n", "\n", "loaded_all_steps = False\n", "\n", "while not loaded_all_steps and description[\"SecondaryStatus\"] != \"Completed\":\n", "\n", " # get available steps\n", " loaded_all_steps = trial.loaded_all_steps\n", " steps = trial.steps(mode=modes.EVAL)\n", "\n", " # steps that have not been rendered yet\n", " steps_to_render = list(set(steps) - set(rendered_steps))\n", "\n", " # iterate over available steps\n", " for step in sorted(steps_to_render):\n", " try:\n", "\n", " # get original input image\n", " image_batch = trial.tensor(\"ResNet_input_0\").value(step_num=step, mode=modes.EVAL)\n", "\n", " # compute implicit biases from batchnorm layers\n", " implicit_biases = compute_implicit_biases(bn_weights, running_vars, running_means, step)\n", "\n", " for item in range(image_batch.shape[0]):\n", "\n", " # input image\n", " image = image_batch[item, :, :, :]\n", "\n", " # input-gradient term: |gradient * input|, summed over the color channels\n", " image_gradient = trial.tensor(\"gradient/image\").value(\n", " step_num=step, mode=modes.EVAL\n", " )[item, :]\n", " image_gradient = np.sum(normalize(np.abs(image_gradient * image)), axis=0)\n", " saliency_map = image_gradient\n", "\n", " for gradient_name, bias_name, implicit_bias in zip(\n", " gradients, biases, implicit_biases\n", " ):\n", "\n", " # get gradients and bias vectors for corresponding step\n", " gradient = trial.tensor(gradient_name).value(step_num=step, 
mode=modes.EVAL)[\n", " item : item + 1, :, :, :\n", " ]\n", " bias = trial.tensor(bias_name).value(step_num=step, mode=modes.EVAL)\n", " bias = bias + implicit_bias\n", "\n", " # full-gradient bias term: |bias * gradient| for this layer\n", " bias = bias.reshape((1, bias.shape[0], 1, 1))\n", " bias = np.broadcast_to(bias, gradient.shape)\n", " bias_gradient = normalize(np.abs(bias * gradient))\n", "\n", " # interpolate each channel to the original image size and accumulate\n", " for channel in range(bias_gradient.shape[1]):\n", " interpolated = scipy.ndimage.zoom(\n", " bias_gradient[0, channel, :, :],\n", " image_size / bias_gradient.shape[2],\n", " order=1,\n", " )\n", " saliency_map += interpolated\n", "\n", " # normalize\n", " saliency_map = normalize(saliency_map)\n", "\n", " # predicted class and probability\n", " predicted_class = trial.tensor(\"fc_output_0\").value(step_num=step, mode=modes.EVAL)[\n", " item, :\n", " ]\n", " print(\"Predicted class:\", np.argmax(predicted_class))\n", " # convert the logits to class probabilities\n", " scores = scipy.special.softmax(np.asarray(predicted_class))\n", "\n", " # plot image and heatmap\n", " plot(\n", " saliency_map,\n", " image,\n", " str(np.argmax(predicted_class)),\n", " str(int(np.max(scores) * 100)),\n", " )\n", "\n", " except TensorUnavailableForStep:\n", " print(\"Tensor unavailable for step {}\".format(step))\n", "\n", " rendered_steps.extend(steps_to_render)\n", "\n", " time.sleep(5)\n", "\n", " description = client.describe_training_job(TrainingJobName=job_name)\n", "\n", "print(\"\\nDone\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## Notebook CI Test Results\n", "\n", "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.\n", "\n", "![This us-east-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-1/sagemaker-debugger|model_specific_realtime_analysis|cnn_class_activation_maps|cnn_class_activation_maps.ipynb)\n", "\n", "![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-east-2/sagemaker-debugger|model_specific_realtime_analysis|cnn_class_activation_maps|cnn_class_activation_maps.ipynb)\n", "\n", "![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/us-west-1/sagemaker-debugger|model_specific_realtime_analysis|cnn_class_activation_maps|cnn_class_activation_maps.ipynb)\n", "\n", "![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ca-central-1/sagemaker-debugger|model_specific_realtime_analysis|cnn_class_activation_maps|cnn_class_activation_maps.ipynb)\n", "\n", "![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/sa-east-1/sagemaker-debugger|model_specific_realtime_analysis|cnn_class_activation_maps|cnn_class_activation_maps.ipynb)\n", "\n", "![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-1/sagemaker-debugger|model_specific_realtime_analysis|cnn_class_activation_maps|cnn_class_activation_maps.ipynb)\n", "\n", "![This eu-west-2 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-2/sagemaker-debugger|model_specific_realtime_analysis|cnn_class_activation_maps|cnn_class_activation_maps.ipynb)\n", "\n", "![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-west-3/sagemaker-debugger|model_specific_realtime_analysis|cnn_class_activation_maps|cnn_class_activation_maps.ipynb)\n", "\n", "![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-central-1/sagemaker-debugger|model_specific_realtime_analysis|cnn_class_activation_maps|cnn_class_activation_maps.ipynb)\n", "\n", "![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/eu-north-1/sagemaker-debugger|model_specific_realtime_analysis|cnn_class_activation_maps|cnn_class_activation_maps.ipynb)\n", "\n", "![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-1/sagemaker-debugger|model_specific_realtime_analysis|cnn_class_activation_maps|cnn_class_activation_maps.ipynb)\n", "\n", "![This ap-southeast-2 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-southeast-2/sagemaker-debugger|model_specific_realtime_analysis|cnn_class_activation_maps|cnn_class_activation_maps.ipynb)\n", "\n", "![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-1/sagemaker-debugger|model_specific_realtime_analysis|cnn_class_activation_maps|cnn_class_activation_maps.ipynb)\n", "\n", "![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-northeast-2/sagemaker-debugger|model_specific_realtime_analysis|cnn_class_activation_maps|cnn_class_activation_maps.ipynb)\n", "\n", "![This ap-south-1 badge failed to load. 
Check your device's internet connectivity, otherwise the service is currently unavailable](https://h75twx4l60.execute-api.us-west-2.amazonaws.com/sagemaker-nb/ap-south-1/sagemaker-debugger|model_specific_realtime_analysis|cnn_class_activation_maps|cnn_class_activation_maps.ipynb)\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.11 64-bit ('3.8.11')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.11" }, "papermill": { "default_parameters": {}, "duration": 310.597031, "end_time": "2021-05-27T00:18:24.478429", "environment_variables": {}, "exception": true, "input_path": "cnn_class_activation_maps.ipynb", "output_path": "/opt/ml/processing/output/cnn_class_activation_maps-2021-05-27-00-08-39.ipynb", "parameters": { "kms_key": "arn:aws:kms:us-west-2:521695447989:key/6e9984db-50cf-4c7e-926c-877ec47a8b25" }, "start_time": "2021-05-27T00:13:13.881398", "version": "2.3.3" } }, "nbformat": 4, "nbformat_minor": 5 }