# Network Pruning

Network pruning is a commonly used technique for speeding up model inference. This tutorial introduces the basic concepts and shows how pruning works in TinyNeuralNetwork.

## Basic concept
In most neural networks, the majority of the runtime is spent in general matrix multiply (GEMM) operations. A natural question is whether these operations can be sped up by reducing the number of elements in the matrices. If some weights and biases, along with the corresponding input and output elements, are set to 0, the calculations involving them can be skipped.

There are generally two kinds of pruning: structured and unstructured. Structured pruning removes weight connections in groups, for example by deleting an entire channel. This changes the input and output shapes of layers as well as their weight matrices, so the pruned model is simply a smaller dense model and nearly every system can benefit from it. Unstructured pruning, on the other hand, removes individual weight connections by setting them to 0. The shapes stay the same, so any speedup depends heavily on whether the inference backend can exploit sparsity.
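
To make the difference concrete, here is a minimal sketch in plain Python (no framework) that prunes a toy 4x3 weight matrix both ways, with each row standing for one output channel. The helper names `unstructured_prune` and `structured_prune` are made up for illustration and are not TinyNeuralNetwork APIs; weight magnitude and row L2 norm are assumed as the importance criteria.

```py
import math

def unstructured_prune(weight, sparsity):
    """Zero the smallest-magnitude individual weights; the shape is unchanged."""
    flat = sorted(abs(w) for row in weight for w in row)
    n_prune = int(len(flat) * sparsity)
    if n_prune == 0:
        return [row[:] for row in weight]
    # Ties at the threshold would prune a few extra weights; fine for a sketch
    threshold = flat[n_prune - 1]
    return [[0.0 if abs(w) <= threshold else w for w in row] for row in weight]

def structured_prune(weight, sparsity):
    """Drop whole rows (output channels) with the smallest L2 norm."""
    norms = [math.sqrt(sum(w * w for w in row)) for row in weight]
    n_keep = len(weight) - int(len(weight) * sparsity)
    by_importance = sorted(range(len(weight)), key=lambda i: norms[i], reverse=True)
    keep = sorted(by_importance[:n_keep])
    return [weight[i][:] for i in keep], keep

weight = [
    [0.9, -0.8, 0.7],
    [0.1, 0.05, -0.2],
    [-0.6, 0.5, 0.04],
    [0.02, -0.65, 0.01],
]

sparse = unstructured_prune(weight, 0.5)   # still 4x3, half of the entries are 0
dense, kept = structured_prune(weight, 0.5)  # a smaller 2x3 dense matrix

print(len(sparse), len(sparse[0]))  # 4 3
print(len(dense), len(dense[0]))    # 2 3
print(kept)                         # [0, 2]
```

The unstructured result keeps its shape and only pays off on a backend with sparse kernels, while the structured result is an ordinary smaller matrix.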

Currently, only structured pruning is supported in TinyNeuralNetwork.

### How is structured pruning implemented in DNN frameworks?
```py
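# Pseudocode: helpers such as get_parent, get_mask and register_mask are placeholders
model = Net(pretrained=True)
sparsity = 0.5

# Maps each layer to its output mask; the network input (no parent) has no mask
masks = {None: None}

def register_masks(layer):
    parent_layer = get_parent(layer)
    # The output mask of the parent is the input mask of this layer
    input_mask = masks[parent_layer]
    if is_passthrough_layer(layer):
        # Pass-through layers keep the channel layout of their input
        output_mask = input_mask
    else:
        output_mask = get_mask(layer, sparsity)
        register_mask(layer, input_mask, output_mask)
    masks[layer] = output_mask

model.apply(register_masks)
# Finetune the model with the masks registered
model.fit(train_data)

def apply_masks(layer):
    parent_layer = get_parent(layer)
    input_mask = masks[parent_layer]
    output_mask = masks[layer]
    # Physically remove the masked channels so the layer shrinks
    apply_mask(layer, input_mask, output_mask)

model.apply(apply_masks)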
```
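
As a concrete illustration of how the output mask of one layer becomes the input mask of the next, here is a small runnable sketch on two stacked fully connected layers stored as nested lists (rows are output channels). The helpers `l2_mask`, `prune_outputs` and `prune_inputs` are hypothetical names for this example only, and the L2 norm criterion is an assumption.

```py
import math

def l2_mask(weight, sparsity):
    """Return a 0/1 mask over output channels (rows), pruning the smallest L2 norms."""
    norms = [math.sqrt(sum(w * w for w in row)) for row in weight]
    n_prune = int(len(weight) * sparsity)
    pruned = set(sorted(range(len(weight)), key=lambda i: norms[i])[:n_prune])
    return [0 if i in pruned else 1 for i in range(len(weight))]

def prune_outputs(weight, mask):
    """Remove the rows (output channels) whose mask entry is 0."""
    return [row[:] for row, keep in zip(weight, mask) if keep]

def prune_inputs(weight, mask):
    """Remove the columns (input channels) whose mask entry is 0."""
    return [[w for w, keep in zip(row, mask) if keep] for row in weight]

# fc1: 3 inputs -> 4 outputs, fc2: 4 inputs -> 2 outputs
fc1 = [[0.5, 0.5, 0.5], [0.01, 0.02, 0.0], [1.0, -1.0, 0.3], [0.0, 0.03, 0.01]]
fc2 = [[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]]

mask = l2_mask(fc1, sparsity=0.5)  # output mask of fc1
fc1 = prune_outputs(fc1, mask)     # fc1 loses output channels 1 and 3
fc2 = prune_inputs(fc2, mask)      # the same mask acts as the input mask of fc2

print(mask)  # [1, 0, 1, 0]
print(fc2)   # [[0.1, 0.3], [0.5, 0.7]]
```

After pruning, `fc1` produces 2 channels and `fc2` consumes 2 channels, so the two layers still fit together.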

### Network Pruning in TinyNeuralNetwork
The problem with the previous code example is that it expects every layer to have exactly one parent. Many recent DNN models, however, contain operations with multiple inputs or outputs, such as `cat`, `add` and `split`. The dependencies introduced by these operations need to be resolved as well.

To solve this problem, we first go through some basic definitions. A node whose input shape and output shape are independent of each other during pruning is called a node with isolation. For example, `conv`, `linear` and `lstm` nodes are nodes with isolation. The goal is to find groups of nodes, called subgraphs, that start and end with nodes with isolation and do not contain another subgraph. The nodes with isolation serve as the starting points for finding the candidate subgraphs in the model.

```py
def find_subgraph(layer, input_modify, output_modify, nodes):
    # Stop if this layer has already been visited
    if layer in nodes:
        return

    nodes.append(layer)

    if is_layer_with_isolation(layer):
        # A node with isolation bounds the subgraph, so only walk
        # in the direction(s) that are being modified
        if input_modify:
            for prev_layer in get_parent(layer):
                find_subgraph(prev_layer, False, True, nodes)
        if output_modify:
            for next_layer in get_child(layer):
                find_subgraph(next_layer, True, False, nodes)
    else:
        # Other nodes tie their inputs and outputs together, so walk both ways
        for prev_layer in get_parent(layer):
            find_subgraph(prev_layer, input_modify, output_modify, nodes)
        for next_layer in get_child(layer):
            find_subgraph(next_layer, input_modify, output_modify, nodes)

candidate_subgraphs = []

def construct_candidate_subgraphs(layer):
    if is_layer_with_isolation(layer):
        nodes = []
        find_subgraph(layer, True, False, nodes)
        candidate_subgraphs.append(nodes)

        nodes = []
        find_subgraph(layer, False, True, nodes)
        candidate_subgraphs.append(nodes)

model.apply(construct_candidate_subgraphs)
```
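
To see the traversal at work, the following self-contained sketch runs the same recursion on a hypothetical residual-style graph (`conv1` and `conv2` feed an `add`, followed by `relu` and `conv3`). The graph, the dictionaries `parents` and `children`, and the set `isolated` are stand-ins invented for this example; they replace the `get_parent`, `get_child` and `is_layer_with_isolation` helpers of the pseudocode.

```py
# Hypothetical toy graph: conv1 -> add <- conv2, then add -> relu -> conv3
parents = {"conv1": [], "conv2": [], "add": ["conv1", "conv2"],
           "relu": ["add"], "conv3": ["relu"]}
children = {"conv1": ["add"], "conv2": ["add"], "add": ["relu"],
            "relu": ["conv3"], "conv3": []}
isolated = {"conv1", "conv2", "conv3"}  # nodes with isolation

def find_subgraph(layer, input_modify, output_modify, nodes):
    # Stop if this layer has already been visited
    if layer in nodes:
        return
    nodes.append(layer)

    if layer in isolated:
        if input_modify:
            for prev_layer in parents[layer]:
                find_subgraph(prev_layer, False, True, nodes)
        if output_modify:
            for next_layer in children[layer]:
                find_subgraph(next_layer, True, False, nodes)
    else:
        for prev_layer in parents[layer]:
            find_subgraph(prev_layer, input_modify, output_modify, nodes)
        for next_layer in children[layer]:
            find_subgraph(next_layer, input_modify, output_modify, nodes)

# Candidate subgraph obtained by modifying the output channels of conv1
nodes = []
find_subgraph("conv1", False, True, nodes)
print(nodes)  # ['conv1', 'add', 'conv2', 'relu', 'conv3']
```

Because of the `add`, pruning the output of `conv1` also involves `conv2` (its output must match) and `conv3` (its input must match), so all of them land in the same candidate subgraph.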

Once all candidate subgraphs are collected, the next step is to remove the duplicated and invalid ones. Due to space limitations, this step is not covered in detail here. In each of the final subgraphs, the first node is called the center node. During configuration, the name of the center node is used to represent the subgraph built from it. Some properties, such as sparsity, can be set by the user at the subgraph level.

Although we have the subgraphs, the mapping of channels between nodes is still unknown, so the channel dependencies need to be resolved. As with subgraph discovery, the channel information is passed recursively to obtain the correct mapping at each node. This is somewhat more complicated, since each node type has its own logic for sharing channel mappings. Operations like `add` require a single mapping shared by all input and output tensors, while `cat` allows its inputs to have independent mappings, with the output mapping shared with the combination of the input mappings. The details are beyond the scope of this tutorial.
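
The difference between `add` and `cat` can be sketched with plain 0/1 channel masks. This is only an illustration of the sharing rules described above, not the actual implementation in TinyNeuralNetwork; the function names `cat_output_mask`, `add_masks` and `kept_indices` are hypothetical.

```py
def cat_output_mask(input_masks):
    """`cat`: each input has its own mask; the output mask is their concatenation."""
    output_mask = []
    for mask in input_masks:
        output_mask.extend(mask)
    return output_mask

def add_masks(shared_mask, n_inputs):
    """`add`: every input and the output must use the very same mask."""
    input_masks = [list(shared_mask) for _ in range(n_inputs)]
    return input_masks, list(shared_mask)

def kept_indices(mask):
    """Channel indices that survive pruning."""
    return [i for i, keep in enumerate(mask) if keep]

# cat of a 4-channel branch and a 2-channel branch, pruned independently
branch_a = [1, 0, 1, 1]
branch_b = [0, 1]
cat_out = cat_output_mask([branch_a, branch_b])
print(cat_out)                # [1, 0, 1, 1, 0, 1]
print(kept_indices(cat_out))  # [0, 2, 3, 5]

# add of two 4-channel branches: one mask for both inputs and the output
add_in, add_out = add_masks([1, 0, 1, 1], n_inputs=2)
print(add_in)   # [[1, 0, 1, 1], [1, 0, 1, 1]]
print(add_out)  # [1, 0, 1, 1]
```

The layer that consumes the output of `cat` has to drop input channels 1 and 4, which is why the output mapping must be derived from the combined input mappings.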

After resolving the channel dependencies, we follow the ordinary pruning process: register the masks for the weight and bias tensors, then finetune the model. When training is finished, the masks are applied so that the model actually gets smaller. Alternatively, if the masks won't change during training, you may apply them right after registering them. Training then runs on the already smaller model, which makes it faster.
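
The distinction between registering and applying a mask can be shown with a generic sketch. It assumes that a registered mask simply keeps the pruned channels at zero during training while the tensor shape stays intact, and that applying the mask physically removes those channels. The functions `register_mask` and `apply_mask` below are simplified stand-ins written for this example, and the loop only imitates an optimizer update.

```py
def register_mask(weight, mask):
    """Zero the masked rows but keep the shape unchanged."""
    return [[w * keep for w in row] for row, keep in zip(weight, mask)]

def apply_mask(weight, mask):
    """Physically drop the masked rows so the layer really shrinks."""
    return [row[:] for row, keep in zip(weight, mask) if keep]

weight = [[0.4, -0.2], [0.03, 0.01], [0.9, 0.5]]
mask = [1, 0, 1]  # prune output channel 1

for step in range(3):
    # Stand-in for an optimizer update; it may move pruned weights away from zero
    weight = [[w + 0.01 for w in row] for row in weight]
    # Re-impose the mask so the pruned channel stays at zero
    weight = register_mask(weight, mask)

print(len(weight))  # 3, still the original shape during training

pruned = apply_mask(weight, mask)
print(len(pruned))  # 2, the channel is gone after applying the mask
```

If the mask is fixed, calling `apply_mask` before the loop gives the same final shape while every training step works on fewer weights.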

### Using the pruner in TinyNeuralNetwork
The pruner is simple to use. The following code prunes a pretrained MobileNetV2 with a sparsity of 0.25.


```py
import sys
sys.path.append('../..')

import torch
import torchvision

from tinynn.prune.oneshot_pruner import OneShotChannelPruner

model = torchvision.models.mobilenet_v2(pretrained=True)
model.train()

dummy_input = torch.randn(1, 3, 224, 224)

pruner = OneShotChannelPruner(model, dummy_input, config={'sparsity': 0.25, 'metrics': 'l2_norm'})

st_flops = pruner.calc_flops()
pruner.prune()

ed_flops = pruner.calc_flops()
print(f"Pruning over, reduced FLOPS {100 * (st_flops - ed_flops) / st_flops:.2f}%  ({st_flops} -> {ed_flops})")

# You should start finetuning the model here
```

```
INFO (tinynn.graph.modifier) [CONV] features_0_0: output 32 -> 24
INFO (tinynn.graph.modifier) [BN] features_0_1: channel 32 -> 24
INFO (tinynn.graph.modifier) [DW_CONV] features_1_conv_0_0: input 32 -> 24
INFO (tinynn.graph.modifier) [BN] features_1_conv_0_1: channel 32 -> 24
INFO (tinynn.graph.modifier) [CONV] features_1_conv_1: input 32 -> 24
INFO (tinynn.graph.modifier) [CONV] features_1_conv_1: output 16 -> 12
INFO (tinynn.graph.modifier) [BN] features_1_conv_2: channel 16 -> 12
INFO (tinynn.graph.modifier) [CONV] features_2_conv_0_0: input 16 -> 12
INFO (tinynn.graph.modifier) [CONV] features_2_conv_0_0: output 96 -> 72
INFO (tinynn.graph.modifier) [BN] features_2_conv_0_1: channel 96 -> 72
INFO (tinynn.graph.modifier) [DW_CONV] features_2_conv_1_0: input 96 -> 72
INFO (tinynn.graph.modifier) [BN] features_2_conv_1_1: channel 96 -> 72
INFO (tinynn.graph.modifier) [CONV] features_2_conv_2: input 96 -> 72
INFO (tinynn.graph.modifier) [CONV] features_2_conv_2: output 24 -> 18
...

Pruning over, reduced FLOPS 40.99%  (314130496 -> 185359152)
```
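
The numbers in the log above can be checked with a few lines of arithmetic. The sketch below assumes that the number of remaining channels is simply `channels * (1 - sparsity)`, which matches the layers shown in the log; the exact rounding rule used by the pruner is not covered here.

```py
sparsity = 0.25

# Channel counts that appear in the log before pruning
original_channels = [32, 16, 96, 24]
remaining = [int(c * (1 - sparsity)) for c in original_channels]
print(remaining)  # [24, 12, 72, 18]

# Overall FLOPs reduction reported by the pruner
st_flops, ed_flops = 314130496, 185359152
reduction = 100 * (st_flops - ed_flops) / st_flops
print(f"{reduction:.2f}%")  # 40.99%

# A regular convolution costs roughly in_channels * out_channels, so pruning
# both sides by the same sparsity removes more than `sparsity` of its cost
both_sides = 1 - (1 - sparsity) ** 2
print(both_sides)  # 0.4375
```

This is why a sparsity of 0.25 can cut the FLOPs by about 41%: layers whose input and output channels are both pruned shrink quadratically, while layers such as depthwise convolutions shrink only linearly.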
