def replace_convtranspose2d()

in nni/compression/pytorch/speedup/compress_modules.py


def replace_convtranspose2d(convtrans, masks):
    """
    We need anothor replace function for
    convtranspose2d, because the layout of
    the weight is different from traditional
    conv layers. The layout of the weight is [N_in, N_out, ksize_1, ksize_2]
    Parameters
    ----------
    convtrans : torch.nn.ConvTranspose2d
        The conv2d module to be replaced
    masks : Tuple of the input masks, output masks and weight masks
        Tuple of the masks, for example
        ([input_m1, input_m2], [output_m], {'weight':weight_m})
    Returns
    -------
    torch.nn.ConvTranspose2d
        The new conv2d module
    """
    in_masks, output_mask, weight_masks = masks
    assert isinstance(convtrans, torch.nn.ConvTranspose2d)
    if len(in_masks) != 1:
        raise InputsNumberError()
    in_mask = in_masks[0]

    weight_mask = weight_masks['weight']
    pruned_in, remained_in = convert_to_coarse_mask(in_mask, 1)
    pruned_out, remained_out = convert_to_coarse_mask(output_mask, 1)
    # ConvTranspose2d has the weight shape of [N_in, N_out/groups, k1, k2]
    n_remained_in = weight_mask.size(0) - pruned_in.size(0)
    n_remained_out = weight_mask.size(1) * convtrans.groups - pruned_out.size(0)
    if n_remained_in != remained_in.size(0) or n_remained_out != remained_out.size(0):
        raise ShapeMisMatchError()

    k_size1, k_size2 = convtrans.kernel_size
    # Note: the group dependency of the convtrans layers should be resolved
    # before entering this function
    ori_inchannel_step = int(convtrans.in_channels/convtrans.groups)
    ori_outchannel_step = int(convtrans.out_channels/convtrans.groups)
    new_inchannel_step = new_outchannel_step = None
    for groupid in range(convtrans.groups):
        in_start = groupid * ori_inchannel_step
        in_end = in_start + ori_inchannel_step
        out_start = groupid * ori_outchannel_step
        out_end = out_start + ori_outchannel_step
        current_input_index = list(
            filter(lambda x: in_start <= x < in_end, remained_in.tolist()))
        current_output_index = list(
            filter(lambda x: out_start <= x < out_end, remained_out.tolist()))
        if len(current_input_index) == 0:
            # the whole group is pruned
            continue
        else:
            new_inchannel_step = len(current_input_index)
            new_outchannel_step = len(current_output_index)
            break
    if not new_inchannel_step or not new_outchannel_step:
        # no group has any remaining channel
        raise EmptyLayerError()
    if n_remained_in % new_inchannel_step != 0 or n_remained_out % new_outchannel_step != 0:
        raise UnBalancedGroupError()

    tmp_weight = torch.ones(
        n_remained_in, new_outchannel_step, k_size1, k_size2)
    tmp_weight = tmp_weight.to(convtrans.weight.device)

    new_groups = 0
    for groupid in range(convtrans.groups):
        # copy the weights of this group
        in_start = groupid * ori_inchannel_step
        in_end = in_start + ori_inchannel_step
        out_start = groupid * ori_outchannel_step
        out_end = out_start + ori_outchannel_step
        current_input_index = list(
            filter(lambda x: in_start <= x < in_end, remained_in.tolist()))
        current_output_index = list(
            filter(lambda x: out_start <= x < out_end, remained_out.tolist()))
        # remap the global index to the group index
        # in the convtranspose layer, the groups are on
        # the output channel dimension
        current_output_index = [x-out_start for x in current_output_index]
        if len(current_input_index) == 0:
            # the whole group is pruned
            assert len(current_output_index) == 0
            continue
        # check that each group keeps the same number of remaining channels
        if len(current_input_index) != new_inchannel_step or len(current_output_index) != new_outchannel_step:
            raise UnBalancedGroupError()

        # copy the weight into tmp_weight
        new_in_start = new_inchannel_step * new_groups
        new_in_end = new_in_start + new_inchannel_step
        out_index = torch.as_tensor(
            current_output_index, dtype=torch.long).to(convtrans.weight.device)
        tmp_weight[new_in_start:new_in_end] = torch.index_select(
            convtrans.weight[current_input_index], 1, out_index)
        new_groups += 1

    _logger.debug('Replace convtranspose2d with in_channels:%d out_channels:%d',
                  n_remained_in, n_remained_out)
    new_convtrans = torch.nn.ConvTranspose2d(in_channels=n_remained_in,
                                             out_channels=n_remained_out,
                                             kernel_size=convtrans.kernel_size,
                                             stride=convtrans.stride,
                                             padding=convtrans.padding,
                                             output_padding=convtrans.output_padding,
                                             dilation=convtrans.dilation,
                                             groups=new_groups,
                                             bias=convtrans.bias is not None,
                                             padding_mode=convtrans.padding_mode)
    new_convtrans.to(convtrans.weight.device)
    # copy the selected weights; use .data to avoid an in-place op on a
    # Parameter that requires grad
    new_convtrans.weight.data.copy_(tmp_weight)
    if convtrans.bias is not None:
        if output_mask is not None:
            new_convtrans.bias.data[:] = torch.index_select(
                convtrans.bias.data, 0, remained_out)
        else:
            new_convtrans.bias.data.copy_(convtrans.bias.data)
    return new_convtrans
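
Below is a minimal usage sketch, not part of the original file: it assumes the function can be imported from the path in the header, that convert_to_coarse_mask treats all-zero channels as pruned, and the module size, pruned channels and mask shapes are examples chosen for illustration only.

import torch

from nni.compression.pytorch.speedup.compress_modules import replace_convtranspose2d

# a small ConvTranspose2d with 4 input channels, 6 output channels, groups=1
convtrans = torch.nn.ConvTranspose2d(in_channels=4, out_channels=6, kernel_size=3)

# activation-shaped masks: an all-zero channel marks that channel as pruned (assumption)
in_mask = torch.ones(1, 4, 8, 8)
in_mask[:, 1] = 0                      # prune input channel 1
out_mask = torch.ones(1, 6, 10, 10)    # 10 = (8 - 1) * 1 + 3 for stride 1, kernel 3
out_mask[:, 0] = 0                     # prune output channels 0 and 3
out_mask[:, 3] = 0
weight_mask = torch.ones(4, 6, 3, 3)   # same layout as convtrans.weight

masks = ([in_mask], out_mask, {'weight': weight_mask})
new_convtrans = replace_convtranspose2d(convtrans, masks)
print(new_convtrans)  # expected: ConvTranspose2d(3, 4, kernel_size=(3, 3), stride=(1, 1))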