{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Distributed Computation of Attributions using Captum" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this tutorial, we provide some examples of using Captum with the torch.distributed package and DataParallel, which allow attributions to be computed in a distributed manner across processors, machines, or GPUs.\n", "\n", "In the first part of this tutorial, we demonstrate dividing a single batch of inputs and computing attributions for each part of the batch in a separate process (or on a separate GPU, if available) using torch.distributed and DataParallel. In the second part of this tutorial, we demonstrate computing attributions over the Titanic dataset in a distributed manner, dividing the dataset among processes and computing the global average attribution." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 1: Distributing computation of Integrated Gradients for an input batch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this part, our goal is to distribute a small batch of input examples across processes, compute the attributions independently on each process, and collect the resulting attributions. This approach can be very helpful for algorithms such as IntegratedGradients, which internally expand the input, since they can be run with a larger number of steps when inputs are distributed across devices." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will first demonstrate this with torch.distributed and then demonstrate the same computation with DataParallel, which is designed specifically for distribution across GPUs."
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Initial imports:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.distributed as dist\n", "from torch import Tensor\n", "from torch.multiprocessing import Process\n", "\n", "from captum.attr import IntegratedGradients" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now define a small toy model for this example." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "class ToyModel(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.linear1 = nn.Linear(3, 4)\n", "        self.linear1.weight = nn.Parameter(torch.ones(4, 3))\n", "        self.linear1.bias = nn.Parameter(torch.tensor([-10.0, 1.0, 1.0, 1.0]))\n", "        self.relu = nn.ReLU()\n", "\n", "    def forward(self, x: Tensor):\n", "        return self.relu(self.linear1(x))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### torch.distributed Example" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the following cell, we set the parameters USE_CUDA and WORLD_SIZE. WORLD_SIZE corresponds to the number of processes initialized and should be set to 1, 2, or 4 for this example, so that the batch of 4 examples divides evenly among the processes. USE_CUDA should be set to True only if at least WORLD_SIZE GPUs are available." ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "USE_CUDA = True\n", "WORLD_SIZE = 4" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now define the function that runs on each process, which takes the rank (identifier for the current process), size (total number of processes), and inp_batch, which corresponds to the input portion for the current process. Integrated Gradients attributions are computed on the given input and gathered on the process with rank 0, where they are concatenated with the attributions from the other processes.
The model can also be wrapped in DistributedDataParallel, which synchronizes parameter updates across processes, by uncommenting the corresponding lines. This is not necessary for this example, since no parameter updates or training are performed." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Uncomment the following import and the corresponding line in run to test with DistributedDataParallel\n", "# from torch.nn.parallel import DistributedDataParallel as DDP\n", "\n", "def run(rank, size, inp_batch):\n", "    # Initialize model\n", "    model = ToyModel()\n", "\n", "    # Move model and input to device with ID rank if USE_CUDA is True\n", "    if USE_CUDA:\n", "        inp_batch = inp_batch.cuda(rank)\n", "        model = model.cuda(rank)\n", "        # Uncomment the line below to wrap with DistributedDataParallel\n", "        # model = DDP(model, device_ids=[rank])\n", "\n", "    # Create IG object and compute attributions.\n", "    ig = IntegratedGradients(model)\n", "    attr = ig.attribute(inp_batch, target=0)\n", "\n", "    # Combine attributions from each device using distributed.gather\n", "    # Rank 0 process gathers all attributions, each other process\n", "    # sends its corresponding attribution.\n", "    if rank == 0:\n", "        output_list = [torch.zeros_like(attr) for _ in range(size)]\n", "        torch.distributed.gather(attr, gather_list=output_list)\n", "        combined_attr = torch.cat(output_list)\n", "        # Rank 0 prints the combined attribution tensor after gathering\n", "        print(combined_attr)\n", "    else:\n", "        torch.distributed.gather(attr)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This function performs the required setup and cleanup steps on each process and executes the chosen function (run)." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "def init_process(rank, size, fn, inp_batch, backend='gloo'):\n", "    \"\"\" Initialize the distributed environment. \"\"\"\n", "    os.environ['MASTER_ADDR'] = '127.0.0.1'\n", "    os.environ['MASTER_PORT'] = '29500'\n", "    dist.init_process_group(backend, rank=rank, world_size=size)\n", "    fn(rank, size, inp_batch)\n", "    dist.destroy_process_group()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We are now ready to initialize and run the processes. The gathered output attributions are printed by the rank 0 process upon completion." ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0.0000, 0.0000, 0.0000],\n", "        [0.4813, 0.6417, 0.8021],\n", "        [3.1865, 3.7176, 4.2487],\n", "        [5.8774, 6.5305, 7.1835]], device='cuda:0')\n" ] } ], "source": [ "size = WORLD_SIZE\n", "processes = []\n", "batch = 1.0 * torch.arange(12).reshape(4,3)\n", "batch_chunks = batch.chunk(size)\n", "for rank in range(size):\n", "    p = Process(target=init_process, args=(rank, size, run, batch_chunks[rank]))\n", "    p.start()\n", "    processes.append(p)\n", "\n", "for p in processes:\n", "    p.join()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To verify the attributions, we can compute them in a single process and confirm that the results match." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0.0000, 0.0000, 0.0000],\n", "        [0.4813, 0.6417, 0.8021],\n", "        [3.1865, 3.7176, 4.2487],\n", "        [5.8774, 6.5305, 7.1835]])\n" ] } ], "source": [ "model = ToyModel()\n", "ig = IntegratedGradients(model)\n", "\n", "batch = 1.0 * torch.arange(12).reshape(4,3)\n", "print(ig.attribute(batch, target=0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### DataParallel Example" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If GPUs are available, we can also distribute computation using torch.nn.DataParallel instead.
DataParallel is a wrapper around a module that internally splits each input batch across the available CUDA devices, parallelizing computation. Note that DistributedDataParallel is expected to be faster than DataParallel, but DataParallel can be simpler to set up, since it only requires wrapping the module. More information comparing the two approaches can be found [here](https://pytorch.org/tutorials/intermediate/ddp_tutorial.html)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The same attributions can be computed using DataParallel as follows. Note that this can only be run when CUDA is available." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor([[0.0000, 0.0000, 0.0000],\n", "        [0.4813, 0.6417, 0.8021],\n", "        [3.1865, 3.7176, 4.2487],\n", "        [5.8774, 6.5305, 7.1835]], device='cuda:0')\n" ] } ], "source": [ "dp_model = nn.DataParallel(model.cuda())\n", "ig = IntegratedGradients(dp_model)\n", "\n", "print(ig.attribute(batch.cuda(), target=0))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Part 2: Distributing computation of Titanic Dataset Attribution" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In this part, our goal is to divide the Titanic dataset among processes, compute the attributions independently on each process, and combine them into a global average attribution.
For this part, make sure that pandas is installed and available.\n", "\n", "NOTE: Please restart your kernel before executing this portion, due to issues with multiprocessing from Jupyter notebooks.\n", "\n", "Initial imports:" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import os\n", "import pandas as pd\n", "import numpy as np\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.distributed as dist\n", "from torch.utils.data import TensorDataset, DataLoader\n", "from torch.utils.data.distributed import DistributedSampler\n", "from torch.multiprocessing import Process\n", "\n", "from captum.attr import IntegratedGradients" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Download the Titanic dataset from: http://biostat.mc.vanderbilt.edu/wiki/pub/Main/DataSets/titanic3.csv \n", "Update the path to the dataset in the following cell." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "dataset_path = \"titanic3.csv\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We define a simple neural network architecture, which is trained in the Titanic_Basic_Interpret tutorial."
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "class TitanicSimpleNNModel(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.linear1 = nn.Linear(12, 12)\n", "        self.sigmoid1 = nn.Sigmoid()\n", "        self.linear2 = nn.Linear(12, 8)\n", "        self.sigmoid2 = nn.Sigmoid()\n", "        self.linear3 = nn.Linear(8, 2)\n", "        self.softmax = nn.Softmax(dim=1)\n", "\n", "    def forward(self, x):\n", "        lin1_out = self.linear1(x)\n", "        sigmoid_out1 = self.sigmoid1(lin1_out)\n", "        sigmoid_out2 = self.sigmoid2(self.linear2(sigmoid_out1))\n", "        return self.softmax(self.linear3(sigmoid_out2))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now define a helper method to read the CSV and generate a TensorDataset object corresponding to the test set of the Titanic dataset. For more details on the pre-processing, refer to the Titanic_Basic_Interpret tutorial." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# Read dataset from csv file.\n", "def load_dataset():\n", "    titanic_data = pd.read_csv(dataset_path)\n", "    titanic_data = pd.concat([titanic_data,\n", "                              pd.get_dummies(titanic_data['sex']),\n", "                              pd.get_dummies(titanic_data['embarked'], prefix=\"embark\"),\n", "                              pd.get_dummies(titanic_data['pclass'], prefix=\"class\")], axis=1)\n", "    titanic_data[\"age\"] = titanic_data[\"age\"].fillna(titanic_data[\"age\"].mean())\n", "    titanic_data[\"fare\"] = titanic_data[\"fare\"].fillna(titanic_data[\"fare\"].mean())\n", "    titanic_data = titanic_data.drop(['name','ticket','cabin','boat','body','home.dest','sex','embarked','pclass'], axis=1)\n", "    # Set random seed for reproducibility.\n", "    np.random.seed(131254)\n", "\n", "    # Convert features and labels to numpy arrays.\n", "    labels = titanic_data[\"survived\"].to_numpy()\n", "    titanic_data = titanic_data.drop(['survived'], axis=1)\n", "    feature_names = list(titanic_data.columns)\n", "    data = titanic_data.to_numpy()\n", "\n", "    # 
Separate training and test sets (70% of the examples are drawn at random for training).\n", "    train_indices = np.random.choice(len(labels), int(0.7*len(labels)), replace=False)\n", "    test_indices = list(set(range(len(labels))) - set(train_indices))\n", "\n", "    test_features = data[test_indices]\n", "    test_features_tensor = torch.from_numpy(test_features).type(torch.FloatTensor)\n", "    dataset = TensorDataset(test_features_tensor)\n", "    return dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the following cell, we set the parameters USE_CUDA and WORLD_SIZE. WORLD_SIZE corresponds to the number of processes initialized. USE_CUDA should be set to True only if at least WORLD_SIZE GPUs are available." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "USE_CUDA = True\n", "WORLD_SIZE = 4" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now define the function that runs on each process, which takes the rank (identifier for the current process) and size (total number of processes). The model and the appropriate part of the dataset are loaded, and attributions are computed for this part of the dataset. The attributions are then averaged across processes. Note that DistributedSampler pads the dataset by repeating examples when necessary, so that each partition has the same number of examples.\n", "\n", "Note that this method loads a pretrained Titanic model, which can be downloaded from here: https://github.com/pytorch/captum/blob/master/tutorials/models/titanic_model.pt . Alternatively, the model can be trained from scratch by following the Titanic_Basic_Interpret tutorial."
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "def run(rank, size):\n", "    # Load Dataset\n", "    dataset = load_dataset()\n", "\n", "    # Create TitanicSimpleNNModel and load saved weights.\n", "    net = TitanicSimpleNNModel()\n", "    net.load_state_dict(torch.load('models/titanic_model.pt'))\n", "\n", "    # Create sampler which divides dataset among processes.\n", "    sampler = DistributedSampler(dataset, num_replicas=size, rank=rank, shuffle=False)\n", "    loader = DataLoader(dataset, batch_size=50, sampler=sampler)\n", "\n", "    # If USE_CUDA, move model to CUDA device with id rank.\n", "    if USE_CUDA:\n", "        net = net.cuda(rank)\n", "\n", "    # Initialize IG object\n", "    ig = IntegratedGradients(net)\n", "\n", "    # Compute total attribution\n", "    total_attr = 0\n", "    for batch in loader:\n", "        inp = batch[0]\n", "        if USE_CUDA:\n", "            inp = inp.cuda(rank)\n", "        attr = ig.attribute(inp, target=1)\n", "        total_attr += attr.sum(dim=0)\n", "\n", "    # Divide by number of examples in partition\n", "    total_attr /= len(sampler)\n", "\n", "    # Sum average attributions from each process on rank 0.\n", "    torch.distributed.reduce(total_attr, dst=0)\n", "    if rank == 0:\n", "        # Average across processes, since each partition has same number of examples.\n", "        total_attr = total_attr / size\n", "        print(\"Average Attributions:\", total_attr)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This function performs the required setup and cleanup steps on each process and executes the chosen function (run)." ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "def init_process(rank, size, fn, backend='gloo'):\n", "    \"\"\" Initialize the distributed environment. \"\"\"\n", "    os.environ['MASTER_ADDR'] = '127.0.0.1'\n", "    os.environ['MASTER_PORT'] = '29500'\n", "    dist.init_process_group(backend, rank=rank, world_size=size)\n", "    fn(rank, size)\n", "    dist.destroy_process_group()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We are now ready to initialize and run the processes. The average attributions over the dataset are printed by the rank 0 process upon completion." ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Average Attributions: tensor([-0.4516, -0.1182, -0.0551,  0.1736,  0.1527, -0.3588,  0.0861, -0.0005,\n", "        -0.0811,  0.0622,  0.0207, -0.1581], device='cuda:0')\n" ] } ], "source": [ "size = WORLD_SIZE\n", "processes = []\n", "\n", "for rank in range(size):\n", "    p = Process(target=init_process, args=(rank, size, run))\n", "    p.start()\n", "    processes.append(p)\n", "\n", "for p in processes:\n", "    p.join()\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.7" } }, "nbformat": 4, "nbformat_minor": 4 }