def backward()

in crypten/gradients.py


    def backward(ctx, grad_output):
        # Gradient function adapts code from:
        # https://github.com/pytorch/pytorch/blob/master/torch/nn/grad.py

        # get input, kernel, and sizes:
        input, kernel, padding, stride, dilation, groups = ctx.saved_tensors
        batch_size = input.size(0)
        out_channels, in_channels, kernel_size = kernel.size()
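        # the kernel stores in_channels // groups channels per group; recover
        # the full number of input channels: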
        in_channels *= groups
        assert input.size(1) == in_channels, "wrong number of input channels"
        assert grad_output.size(1) == out_channels, "wrong number of output channels"
        assert grad_output.size(0) == batch_size, "wrong batch size"

        # TODO: Implement conv1d gradient under the following condition:
        if groups > 1 and input.size(1) > groups:
            raise NotImplementedError(
                "conv1d backward with groups > 1 and in_channels > groups not implemented"
            )

        # compute gradient with respect to input:
        # TODO: Eliminate dependency on torch internal function by implementing in util
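        # the input gradient is a transposed convolution of grad_output with the
        # same kernel; output_padding resolves the ambiguity in input length that
        # arises when stride > 1 maps several input sizes to the same output size: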
        output_padding = torch.nn.grad._grad_input_padding(
            grad_output,
            input.size(),
            stride,
            padding,
            (kernel_size,),
            dilation=dilation,
        )
        grad_input = grad_output.conv_transpose1d(
            kernel,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            dilation=dilation,
        )

        # compute gradient with respect to kernel:
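        # the kernel gradient is a correlation of the input with grad_output:
        # run conv1d with stride and dilation swapped, treating each
        # (example, channel) pair as its own group, then sum over the batch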
        grad_output = grad_output.repeat(1, in_channels // groups, 1)
        grad_output = grad_output.view(
            grad_output.size(0) * grad_output.size(1), 1, grad_output.size(2)
        )
        input = input.view(1, input.size(0) * input.size(1), input.size(2))
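        # each group convolves one (example, channel) slice of the input with
        # the grad_output channels of the corresponding example: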
        grad_kernel = input.conv1d(
            grad_output,
            stride=dilation,
            padding=padding,
            dilation=stride,
            groups=in_channels * batch_size,
        )
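        # un-fold the batch dimension and sum the per-example kernel gradients: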
        grad_kernel = grad_kernel.view(
            batch_size, grad_kernel.size(1) // batch_size, grad_kernel.size(2)
        )
        grad_kernel = grad_kernel.sum(dim=0)
        grad_kernel = grad_kernel.view(
            in_channels // groups, out_channels, grad_kernel.size(1)
        )
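        # the grouped convolution can produce a few extra trailing positions
        # (its output length can exceed kernel_size); keep only the first
        # kernel_size entries: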
        grad_kernel = grad_kernel.transpose(0, 1).narrow(2, 0, kernel_size)
        return (grad_input, grad_kernel)
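
A quick way to sanity-check this recipe is against plain PyTorch autograd. The sketch below is illustrative only, not CrypTen code: it runs the same two formulas on public torch tensors with groups=1, using assumed shapes, stride, padding, and dilation chosen so that a nonzero output_padding is exercised, and compares the results to x.grad and w.grad.

    import torch
    import torch.nn.functional as F

    batch_size, in_channels, length = 2, 3, 12
    out_channels, kernel_size = 4, 3
    stride, padding, dilation = 2, 1, 1

    x = torch.randn(batch_size, in_channels, length, requires_grad=True)
    w = torch.randn(out_channels, in_channels, kernel_size, requires_grad=True)

    y = F.conv1d(x, w, stride=stride, padding=padding, dilation=dilation)
    grad_output = torch.randn_like(y)
    y.backward(grad_output)

    # gradient w.r.t. input: transposed convolution with the same kernel
    out_len = y.size(2)
    output_padding = length - (
        (out_len - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + 1
    )
    grad_input = F.conv_transpose1d(
        grad_output, w, stride=stride, padding=padding,
        output_padding=output_padding, dilation=dilation,
    )

    # gradient w.r.t. kernel: conv1d with stride and dilation swapped, one
    # group per (example, channel) pair, summed over the batch
    go = grad_output.repeat(1, in_channels, 1)
    go = go.view(go.size(0) * go.size(1), 1, go.size(2))
    xr = x.detach().view(1, batch_size * in_channels, length)
    gk = F.conv1d(xr, go, stride=dilation, padding=padding, dilation=stride,
                  groups=in_channels * batch_size)
    gk = gk.view(batch_size, gk.size(1) // batch_size, gk.size(2)).sum(dim=0)
    gk = gk.view(in_channels, out_channels, gk.size(1))
    grad_kernel = gk.transpose(0, 1).narrow(2, 0, kernel_size)

    assert torch.allclose(grad_input, x.grad, atol=1e-4)
    assert torch.allclose(grad_kernel, w.grad, atol=1e-4)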