def backward()

in pytorch/sagemakercv/layers/nhwc/conv.py
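
Backward pass for the NHWC convolution autograd function. When the output channel count is not a multiple of 8, it zero-pads the incoming gradient's channel dimension (mirroring the forward pass), dispatches to the NHWC cuDNN backward bindings (with or without a bias gradient), and slices the padded channels back off the returned weight and bias gradients.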


    def backward(ctx, grad_y):
        x, w = ctx.saved_tensors
        # If padding was used in the forward pass, pad grad_y here as well
        # for better performance: cuDNN's NHWC kernels are fastest when the
        # channel count is a multiple of 8.
        K = grad_y.shape[3]
        is_padded = (K % 8) != 0
        if is_padded:
            K_padded = 8 * ((K + 7) // 8)
            padded_grad_shape = [grad_y.shape[0], grad_y.shape[1], grad_y.shape[2], K_padded]
            padded_grad = torch.zeros(padded_grad_shape, dtype=grad_y.dtype, device=grad_y.device)
            padded_grad[:, :, :, :K] = grad_y

        if ctx.need_bias_grad:
            # Single call site: pass the padded gradient when one was built.
            grad = padded_grad if is_padded else grad_y
            dx, dw, db = NHWC.cudnn_convolution_backward_with_bias_nhwc(x, grad, w,
                                                   ctx.padding, ctx.stride, ctx.dilation, ctx.groups,
                                                   torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic,
                                                   list(ctx.needs_input_grad[0:3]))
            if is_padded:
                # Drop the gradients of the zero-padded output channels.
                dw = dw[:K, :, :, :].contiguous()
                db = db[:K].contiguous()
            if not ctx.needs_input_grad[0]:
                dx = None
            return dx, dw, db, None, None, None, None
        else:
            # No bias gradient requested: only dx and dw can be needed.
            if not (ctx.needs_input_grad[0] or ctx.needs_input_grad[1]):
                return None, None, None, None, None, None, None
            grad = padded_grad if is_padded else grad_y
            dx, dw = NHWC.cudnn_convolution_backward_nhwc(x, grad, w,
                                                   ctx.padding, ctx.stride, ctx.dilation, ctx.groups,
                                                   torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic,
                                                   list(ctx.needs_input_grad[0:2]))
            if is_padded:
                dw = dw[:K, :, :, :].contiguous()
            if not ctx.needs_input_grad[0]:
                dx = None
            if not ctx.needs_input_grad[1]:
                dw = None
            return dx, dw, None, None, None, None, None
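
For reference, here is a minimal, self-contained sketch of the channel-padding trick used above, outside of the NHWC extension. The helper name pad_channels_to_multiple_of_8 is illustrative, not part of the repository; the real code inlines this logic and hands the result to the NHWC.cudnn_convolution_backward_* bindings.

    import torch

    def pad_channels_to_multiple_of_8(grad_y):
        # Zero-pad the last (channel) dim of an NHWC tensor up to a multiple
        # of 8. Returns the padded tensor plus the original channel count so
        # the caller can slice gradients back down afterwards.
        K = grad_y.shape[3]
        if K % 8 == 0:
            return grad_y, K
        K_padded = 8 * ((K + 7) // 8)
        padded = torch.zeros(grad_y.shape[0], grad_y.shape[1], grad_y.shape[2], K_padded,
                             dtype=grad_y.dtype, device=grad_y.device)
        padded[:, :, :, :K] = grad_y
        return padded, K

    # A gradient with 20 output channels is padded to 24 (= 8 * 3).
    g = torch.randn(2, 7, 7, 20)
    g_padded, K = pad_channels_to_multiple_of_8(g)
    assert g_padded.shape[3] == 24
    assert torch.equal(g_padded[..., :K], g)

Because the extra channels are all zeros, the corresponding slices of dw and db come out as zeros, so backward() can simply cut them off with dw[:K] and db[:K] before returning and callers never see the padding.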