def modify_input()

in tinynn/graph/modifier.py
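
Shrinks the input channels of a normalization layer in place. The channels listed in `remove_idx` are dropped from the layer's weight, bias and running statistics; when `bn_compensation` is enabled and the layer feeds an activation followed by a plain `Conv2d`, the bias contribution of the removed channels is first folded into that conv's bias.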


    def modify_input(self, remove_idx):
        bn = self.node.module
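        # Indices of the input channels to keep: everything except remove_idx.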
        preserve_idx = complementary_list([i for i in range(self.weight_mask["weight"].shape[0])], remove_idx)

        if bn.weight.shape[0] != len(preserve_idx):
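            # `while` is used here as a breakable block: compensation is attempted only
            # for the exact pattern norm -> activation -> Conv2d; any unsupported case
            # breaks out and the channels are pruned without compensation.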
            while self.graph_modifier.bn_compensation:
                if len(self.next_modifiers()) != 1:
                    break

                if len(self.next_modifiers()[0].next_modifiers()) != 1:
                    break

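                # `act` may be an nn.Module or a traced functional op (TraceNode);
                # both forms are checked against the supported-activation whitelist below.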
                act = self.next_modifiers()[0].module()
                conv = self.next_modifiers()[0].next_modifiers()[0].module()

                if isinstance(act, nn.Module):
                    if type(act) not in [
                        nn.ReLU,
                        nn.ReLU6,
                        nn.LeakyReLU,
                        nn.Sigmoid,
                        nn.Tanh,
                        nn.Hardsigmoid,
                        nn.Hardtanh,
                        nn.Hardswish,
                        nn.LogSigmoid,
                    ]:
                        log.debug(f"unsupported activation for bn compensation: {type(act)}")
                        break

                if isinstance(act, TraceNode):
                    if self.next_modifiers()[0].node.kind() not in [
                        'relu',
                        'relu6',
                        'leaky_relu',
                        'sigmoid',
                        'tanh',
                        'hardsigmoid',
                        'hardtanh',
                        'hardswish',
                        'logsigmoid',
                    ]:
                        log.debug(f"unsupported activation for bn compensation: {self.next_modifiers()[0].node.kind()}")
                        break

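                # Only a plain (non-grouped) Conv2d downstream is compensated.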
                if type(conv) is not torch.nn.Conv2d:
                    break

                if conv.groups == 1:
                    with torch.no_grad():
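                        # Approximate the removed channels' output by their BN bias passed
                        # through the activation, then multiply by the conv weights summed
                        # over the spatial dimensions to get each output channel's lost
                        # contribution.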
                        bias = bn.bias.detach().clone()
                        activation_bias = act(bias)
                        fuse_weight = torch.sum(conv.weight, dim=[2, 3])
                        bn_bias = fuse_weight * activation_bias
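                        # Keep only the columns of the channels being removed and sum them
                        # into a per-output-channel bias correction.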
                        bn_bias = bn_bias[:, [i in remove_idx for i in range(bn_bias.shape[1])]]
                        bn_bias = torch.sum(bn_bias, dim=[1])
                        if conv.bias is None:
                            conv.bias = torch.nn.Parameter(bn_bias)
                        else:
                            conv.bias = torch.nn.Parameter(conv.bias + bn_bias)
                break

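            # Shrink the normalization layer itself to the preserved channels.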
            if isinstance(bn, nn.BatchNorm1d) or isinstance(bn, nn.BatchNorm2d):
                log.info(f'[BN] {self.unique_name()}: channel {bn.num_features} -> {len(preserve_idx)}')
                bn.register_buffer('running_mean', bn.running_mean[preserve_idx])
                bn.register_buffer('running_var', bn.running_var[preserve_idx])
                bn.num_batches_tracked = bn.num_batches_tracked.zero_()
                bn.num_features = len(preserve_idx)
            elif isinstance(bn, nn.LayerNorm):
                if len(bn.normalized_shape) == 1:
                    log.info(f'[LN] {self.unique_name()}: channel {bn.normalized_shape} -> ({len(preserve_idx)},)')
                    bn.normalized_shape = (len(preserve_idx),)
                else:
                    log.error("The Layer Normalization (LN) Modifier supports only one-dimensional normalized_shape.")
            else:
                log.error("Unsupported Norm Type")

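            # The affine weight and bias are pruned for every supported norm type.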
            bn.weight = torch.nn.Parameter(bn.weight[preserve_idx])
            bn.bias = torch.nn.Parameter(bn.bias[preserve_idx])
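
The compensation step can be checked in isolation. The snippet below is not tinynn API; it is a standalone sketch that reproduces the same math on plain PyTorch modules, under the assumption that the removed channels' inputs sit exactly at the BN running mean (so their BN output equals the bias) and that the downstream conv uses no padding (at padded borders the compensation is only an approximation). All names (`C`, `remove_idx`, `keep_idx`, ...) are illustrative.

    import torch
    import torch.nn as nn

    torch.manual_seed(0)

    C, H, W = 4, 6, 6
    remove_idx = [1, 3]
    keep_idx = [i for i in range(C) if i not in remove_idx]

    # Original pattern: BN -> ReLU -> Conv2d (no padding, groups == 1).
    bn = nn.BatchNorm2d(C).eval()
    bn.weight.data.uniform_(0.5, 1.5)
    bn.bias.data.uniform_(-1.0, 1.0)
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 1.5)
    act = nn.ReLU()
    conv = nn.Conv2d(C, 3, kernel_size=3, bias=True)

    # Input whose removed channels sit exactly at the running mean, so their BN
    # output equals bn.bias -- the assumption behind the compensation.
    x = torch.randn(1, C, H, W)
    x[:, remove_idx] = bn.running_mean[remove_idx].view(1, -1, 1, 1)

    with torch.no_grad():
        reference = conv(act(bn(x)))

        # Same math as modify_input(): activated BN bias times spatially-summed
        # conv weights, restricted to the removed channels, summed per output channel.
        fuse_weight = torch.sum(conv.weight, dim=[2, 3])  # (out_ch, in_ch)
        correction = (fuse_weight * act(bn.bias))[:, remove_idx].sum(dim=1)

        # Pruned pattern: copy the kept channels, add the correction to the conv bias.
        bn_p = nn.BatchNorm2d(len(keep_idx)).eval()
        bn_p.weight.data = bn.weight.data[keep_idx].clone()
        bn_p.bias.data = bn.bias.data[keep_idx].clone()
        bn_p.running_mean = bn.running_mean[keep_idx].clone()
        bn_p.running_var = bn.running_var[keep_idx].clone()

        conv_p = nn.Conv2d(len(keep_idx), 3, kernel_size=3, bias=True)
        conv_p.weight.data = conv.weight.data[:, keep_idx].clone()
        conv_p.bias.data = conv.bias.data + correction

        pruned = conv_p(act(bn_p(x[:, keep_idx])))

    print(torch.allclose(reference, pruned, atol=1e-5))  # expected: True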