def backward()

in crypten/gradients.py [0:0]


    def backward(ctx, grad_output):
        """Return the gradients of ``conv2d`` w.r.t. its input and its kernel.

        The gradient code adapts PyTorch's reference implementation
        (``conv2d_input`` / ``conv2d_weight``):
        https://github.com/pytorch/pytorch/blob/master/torch/nn/grad.py

        Args:
            ctx: autograd context; ``ctx.saved_tensors`` must unpack to
                ``(input, kernel, padding, stride, dilation, groups)``.
                ``padding``, ``stride`` and ``dilation`` may be ints or
                ``(y, x)`` pairs.
            grad_output: gradient w.r.t. the conv2d output. Its size along
                dim 0 must equal the input batch size and along dim 1 the
                number of kernel output channels.

        Returns:
            ``(grad_input, grad_kernel)``; ``grad_kernel`` is narrowed to the
            kernel's spatial size ``(kernel_size_y, kernel_size_x)``.

        Raises:
            AssertionError: if channel counts or batch sizes are inconsistent.
            NotImplementedError: if ``groups > 1`` and the input has more
                channels than ``groups``.
            ValueError: if the input's spatial size cannot be produced by a
                transposed convolution of ``grad_output`` with the saved
                hyper-parameters.
        """
        # get input, kernel, and sizes:
        input, kernel, padding, stride, dilation, groups = ctx.saved_tensors
        # Normalize hyper-parameters to (y, x) pairs so they can be indexed
        # per spatial dimension below.
        padding, stride, dilation = (
            (value, value) if isinstance(value, int) else tuple(value)
            for value in (padding, stride, dilation)
        )
        batch_size = input.size(0)
        out_channels, in_channels, kernel_size_y, kernel_size_x = kernel.size()
        in_channels *= groups
        assert input.size(1) == in_channels, "wrong number of input channels"
        assert grad_output.size(1) == out_channels, "wrong number of output channels"
        assert grad_output.size(0) == batch_size, "wrong batch size"

        # TODO: Implement conv2d gradient under following condition:
        if groups > 1 and input.size(1) > groups:
            raise NotImplementedError(
                "conv2d backward with groups > 1 and in_channels > groups not implemented"
            )

        # compute gradient with respect to input:
        # With stride > 1 several input sizes map to the same output size, so
        # the transposed conv needs output_padding to restore the exact input
        # extent. Computed locally instead of via the private (and since
        # removed) torch.nn.grad._grad_input_padding helper.
        kernel_size = (kernel_size_y, kernel_size_x)
        output_padding = []
        for dim in range(2):
            # Spatial extent conv_transpose2d yields with output_padding == 0.
            min_size = (
                (grad_output.size(dim + 2) - 1) * stride[dim]
                - 2 * padding[dim]
                + dilation[dim] * (kernel_size[dim] - 1)
                + 1
            )
            extra = input.size(dim + 2) - min_size
            if not 0 <= extra < stride[dim]:
                raise ValueError(
                    "requested an input grad size of {} along spatial dim {}, "
                    "but valid sizes range from {} to {} (for a grad_output "
                    "size of {})".format(
                        input.size(dim + 2),
                        dim,
                        min_size,
                        min_size + stride[dim] - 1,
                        grad_output.size(dim + 2),
                    )
                )
            output_padding.append(extra)
        output_padding = tuple(output_padding)

        grad_input = grad_output.conv_transpose2d(
            kernel,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            dilation=dilation,
        )

        # compute gradient with respect to kernel:
        # Each (batch, input-channel) plane is convolved with every matching
        # grad_output plane through one grouped conv, mirroring conv2d_weight.
        grad_output = grad_output.repeat(1, in_channels // groups, 1, 1)
        grad_output = grad_output.view(
            grad_output.size(0) * grad_output.size(1),
            1,
            grad_output.size(2),
            grad_output.size(3),
        )
        input = input.view(
            1, input.size(0) * input.size(1), input.size(2), input.size(3)
        )
        # dilation and stride are swapped based on PyTorch's conv2d_weight
        # implementation: the kernel-offset axis is sampled every `dilation`
        # input pixels, while grad_output taps are `stride` pixels apart.
        grad_kernel = input.conv2d(
            grad_output,
            stride=dilation,
            padding=padding,
            dilation=stride,
            groups=in_channels * batch_size,
        )
        grad_kernel = grad_kernel.view(
            batch_size,
            grad_kernel.size(1) // batch_size,
            grad_kernel.size(2),
            grad_kernel.size(3),
        )
        # Sum over the batch, then reorder to (out_channels, in/groups, y, x).
        grad_kernel = (
            grad_kernel.sum(0)
            .view(
                in_channels // groups,
                out_channels,
                grad_kernel.size(2),
                grad_kernel.size(3),
            )
            .transpose(0, 1)
        )
        # The grouped conv may yield a spatial extent larger than the kernel;
        # keep only the leading kernel_size_y x kernel_size_x window.
        grad_kernel = grad_kernel.narrow(2, 0, kernel_size_y)
        grad_kernel = grad_kernel.narrow(3, 0, kernel_size_x)
        return (grad_input, grad_kernel)