def replace_conv2d()

in nni/compression/pytorch/speedup/compress_modules.py


def replace_conv2d(conv, masks):
    """
    Replace the original conv with a new one according to the inferred
    masks. The function supports both fine-grained and coarse-grained
    sparsity. In the fine-grained scenario, only the filters that happen
    to be completely covered by the fine-grained sparsity are removed.

    Parameters
    ----------
    conv : torch.nn.Conv2d
        The conv2d module to be replaced
    masks : tuple
        Tuple of the input masks, the output mask and the weight masks,
        for example ([input_m], output_m, {'weight': weight_m})

    Returns
    -------
    torch.nn.Conv2d
        The new conv2d module
    """
    in_masks, output_mask, weight_masks = masks
    assert isinstance(conv, nn.Conv2d)
    # the conv layer should only have one input tensor
    if len(in_masks) != 1:
        raise InputsNumberError()

    in_mask = in_masks[0]

    weight_mask = weight_masks['weight']
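    # convert the (possibly fine-grained) masks into coarse per-channel masks:
    # pruned_* / remained_* hold the indices of the removed / kept channels
    # along dim 1; a channel counts as pruned only when its mask is all zero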
    pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
    pruned_out, remained_out = convert_to_coarse_mask(output_mask, 1)

    n_remained_in = weight_mask.size(1) * conv.groups - pruned_in.size(0)
    n_remained_out = weight_mask.size(0) - pruned_out.size(0)
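    # weight_mask.size(1) equals in_channels / groups, hence the multiplication
    # by conv.groups; these counts must agree with the coarse activation masks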

    if n_remained_in != remained_in.size(0) or n_remained_out != remained_out.size(0):
        raise ShapeMisMatchError()

    k_size1, k_size2 = conv.kernel_size
    # Note: The group dependency of the conv layers should be resolved
    # before running into this function.
    # Check whether the mask tensors meet the group dependency and calculate
    # the new number of groups after pruning.
    # the original step size of the input channel for each group
    ori_inchannel_step = int(conv.in_channels/conv.groups)
    # the original step size of the output channel for each group
    ori_outchannel_step = int(conv.out_channels/conv.groups)
    # calculate the new_in_channel_step and new_outchannel_step first
    new_inchannel_step = new_outchannel_step = None
    for groupid in range(conv.groups):
        in_start = groupid * ori_inchannel_step
        in_end = in_start + ori_inchannel_step
        out_start = groupid * ori_outchannel_step
        out_end = out_start + ori_outchannel_step
        current_input_index = list(
            filter(lambda x: in_start <= x < in_end, remained_in.tolist()))
        current_output_index = list(
            filter(lambda x: out_start <= x < out_end, remained_out.tolist()))
        if len(current_input_index) == 0:
            # the whole group is pruned
            continue
        new_inchannel_step = len(current_input_index)
        new_outchannel_step = len(current_output_index)
        break
    if not new_inchannel_step or not new_outchannel_step:
        # either every group is fully pruned or a group keeps no channel
        raise EmptyLayerError()
    if n_remained_in % new_inchannel_step != 0 or n_remained_out % new_outchannel_step != 0:
        raise UnBalancedGroupError()

    tmp_weight = torch.ones(
        n_remained_out, new_inchannel_step, k_size1, k_size2)
    tmp_weight = tmp_weight.to(conv.weight.device)

    new_groups = 0
    for groupid in range(conv.groups):
        in_start = groupid * ori_inchannel_step
        in_end = in_start + ori_inchannel_step
        out_start = groupid * ori_outchannel_step
        out_end = out_start + ori_outchannel_step
        current_input_index = list(
            filter(lambda x: in_start <= x < in_end, remained_in.tolist()))
        current_output_index = list(
            filter(lambda x: out_start <= x < out_end, remained_out.tolist()))
        # remap the global index to the group index
        current_input_index = [x-in_start for x in current_input_index]
        if len(current_input_index) == 0:
            # the whole group is pruned
            assert len(current_output_index) == 0
            continue
        # check that each remaining group keeps the same number of channels
        if len(current_input_index) != new_inchannel_step or len(current_output_index) != new_outchannel_step:
            raise UnBalancedGroupError()

        # copy the weight into tmp_weight
        new_out_start = new_outchannel_step * new_groups
        new_out_end = new_out_start + new_outchannel_step
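        # pick the kept output channels (rows) of this group, then the kept
        # input channels along dim 1 via index_select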
        tmp_weight[new_out_start:new_out_end] = torch.index_select(
            conv.weight[current_output_index], 1,
            torch.as_tensor(current_input_index, dtype=torch.long).to(conv.weight.device))
        new_groups += 1

    _logger.debug("replace conv2d with in_channels: %d, out_channels: %d",
                  n_remained_in, n_remained_out)

    # need_bias indicates whether the new conv layer needs a bias: if the
    # original conv has no bias and no constant needs to be folded into the
    # bias, need_bias is False.
    need_bias = conv.bias is not None
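    # groups that were completely pruned are dropped, so the new conv uses
    # new_groups instead of conv.groups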
    new_conv = torch.nn.Conv2d(in_channels=n_remained_in,
                               out_channels=n_remained_out,
                               kernel_size=conv.kernel_size,
                               stride=conv.stride,
                               padding=conv.padding,
                               dilation=conv.dilation,
                               groups=new_groups,
                               bias=need_bias,
                               padding_mode=conv.padding_mode)

    new_conv.to(conv.weight.device)
    new_conv.weight.data.copy_(tmp_weight)

    # copy the bias data
    if conv.bias is not None:
        new_conv.bias.data.copy_(torch.index_select(
            conv.bias.data, 0, remained_out))

    return new_conv
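
Usage sketch (not part of the source above): the snippet below builds a toy
nn.Conv2d together with masks in the layout documented in the docstring and
rebuilds the layer with half of its output channels pruned. The import path
(derived from the file path above), the mask shapes and the toy sizes are
assumptions for illustration; in the real speedup pipeline these masks are
inferred automatically rather than written by hand.

import torch
import torch.nn as nn

# assumed import path, following the file location shown above (may differ
# between NNI versions)
from nni.compression.pytorch.speedup.compress_modules import replace_conv2d

# hypothetical layer: 3 input channels, 8 output channels, 3x3 kernel
conv = nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, bias=True)

# masks follow the layout ([input_m], output_m, {'weight': weight_m});
# the activation masks are assumed to share the shapes of the corresponding tensors
input_mask = torch.ones(1, 3, 16, 16)    # nothing pruned on the input side
output_mask = torch.ones(1, 8, 14, 14)
output_mask[:, 4:] = 0                   # prune output channels 4..7
weight_mask = torch.ones_like(conv.weight)
weight_mask[4:] = 0                      # keep the weight mask consistent

new_conv = replace_conv2d(conv, ([input_mask], output_mask, {'weight': weight_mask}))
print(new_conv)  # expected: Conv2d(3, 4, kernel_size=(3, 3), stride=(1, 1))

For grouped convolutions the call is the same, but the masks have to prune
channels group-consistently; otherwise the function raises
UnBalancedGroupError (or EmptyLayerError when every group is removed).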