in tinynn/graph/modifier.py [0:0]
def modify_input(self, remove_idx):
    """Shrink the wrapped normalization layer to the channels not in ``remove_idx``.

    The full channel set is ``range(self.weight_mask["weight"].shape[0])``; the
    kept channels are its complement with ``remove_idx``. If ``bn.weight``
    already has that many entries the method returns without changing anything.

    Otherwise, when ``self.graph_modifier.bn_compensation`` is set, the constant
    ``act(bn.bias)`` contribution of the removed channels is first folded into
    the bias of the following ``Conv2d`` (pattern: norm -> supported activation
    -> dense Conv2d, each with exactly one successor). Then the norm layer is
    sliced: affine weight/bias always, running statistics for BatchNorm1d/2d/3d,
    and ``normalized_shape`` for a LayerNorm with a one-dimensional shape.

    Args:
        remove_idx: Channel indices to drop (integer-like; presumably a list of
            ints -- confirm against callers).
    """
    bn = self.node.module
    preserve_idx = complementary_list(list(range(self.weight_mask["weight"].shape[0])), remove_idx)

    # Already at the target width (presumably pruned by an earlier call): nothing to do.
    if bn.weight.shape[0] == len(preserve_idx):
        return

    # nn.Module activations are matched by exact type, so subclasses are rejected.
    supported_act_modules = (
        nn.ReLU,
        nn.ReLU6,
        nn.LeakyReLU,
        nn.Sigmoid,
        nn.Tanh,
        nn.Hardsigmoid,
        nn.Hardtanh,
        nn.Hardswish,
        nn.LogSigmoid,
    )
    # Traced functional activations are matched by node kind.
    supported_act_kinds = (
        'relu',
        'relu6',
        'leaky_relu',
        'sigmoid',
        'tanh',
        'hardsigmoid',
        'hardtanh',
        'hardswish',
        'logsigmoid',
    )

    def fold_removed_channels_into_next_conv():
        # Treat every removed channel as emitting the constant act(bn.bias)
        # (presumably these channels were chosen because their scale is ~0 --
        # confirm) and add that constant's effect to the next conv's bias.
        # Summing conv weights over the kernel gives the response to a constant
        # input, which is exact only away from zero-padded borders.
        act_modifiers = self.next_modifiers()
        if len(act_modifiers) != 1:
            return
        act_modifier = act_modifiers[0]
        conv_modifiers = act_modifier.next_modifiers()
        if len(conv_modifiers) != 1:
            return

        act = act_modifier.module()
        conv = conv_modifiers[0].module()

        if isinstance(act, nn.Module) and type(act) not in supported_act_modules:
            log.debug(f"unsupported activation for bn compensation: {type(act)}")
            return
        if isinstance(act, TraceNode):
            act_kind = act_modifier.node.kind()
            if act_kind not in supported_act_kinds:
                log.debug(f"unsupported activation for bn compensation: {act_kind}")
                return

        # Only dense (groups == 1) Conv2d consumers are compensated.
        if type(conv) is not torch.nn.Conv2d or conv.groups != 1:
            return
        # The compensation term is built from bn.bias; skip rather than crash without one.
        if bn.bias is None:
            log.debug("skip bn compensation: norm layer has no bias")
            return

        # Set lookup keeps mask construction linear in the channel count.
        removed = {int(i) for i in remove_idx}
        with torch.no_grad():
            # detach().clone() copies the values without the copy-construct warning
            # that torch.tensor(<Parameter>) emits.
            activation_bias = act(bn.bias.detach().clone())
            fuse_weight = torch.sum(conv.weight, dim=[2, 3])
            contribution = fuse_weight * activation_bias
            removed_mask = [i in removed for i in range(contribution.shape[1])]
            bn_bias = torch.sum(contribution[:, removed_mask], dim=[1])
            if conv.bias is None:
                conv.bias = torch.nn.Parameter(bn_bias)
            else:
                conv.bias = torch.nn.Parameter(conv.bias + bn_bias)

    # Must run before the norm layer is sliced: it reads the full-width bn.bias.
    if self.graph_modifier.bn_compensation:
        fold_removed_channels_into_next_conv()

    if isinstance(bn, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
        log.info(f'[BN] {self.unique_name()}: channel {bn.num_features} -> {len(preserve_idx)}')
        # Running statistics are None when the layer was built with track_running_stats=False.
        if bn.running_mean is not None:
            bn.register_buffer('running_mean', bn.running_mean[preserve_idx])
        if bn.running_var is not None:
            bn.register_buffer('running_var', bn.running_var[preserve_idx])
        if bn.num_batches_tracked is not None:
            bn.num_batches_tracked.zero_()
        bn.num_features = len(preserve_idx)
    elif isinstance(bn, nn.LayerNorm):
        if len(bn.normalized_shape) == 1:
            log.info(f'[LN] {self.unique_name()}: channel {bn.normalized_shape} -> ({len(preserve_idx)},)')
            bn.normalized_shape = (len(preserve_idx),)
        else:
            log.error("The Layer Normalization (LN) Modifier supports only one-dimensional normalized_shape.")
    else:
        log.error("Unsupported Norm Type")

    bn.weight = torch.nn.Parameter(bn.weight[preserve_idx])
    # bias may be absent (e.g. LayerNorm(bias=False)).
    if bn.bias is not None:
        bn.bias = torch.nn.Parameter(bn.bias[preserve_idx])