in pytorch/sagemakercv/layers/nhwc/conv.py [0:0]
def backward(ctx, grad_y):
    x, w = ctx.saved_tensors
    # If the output channels were padded in the forward pass, pad grad_y here
    # as well so cuDNN can keep using the faster aligned kernels.
    K = grad_y.shape[3]  # NHWC layout: dim 3 is the channel dimension
    is_padded = (K % 8) != 0
    if is_padded:
        # Round the channel count up to the next multiple of 8 and copy
        # grad_y into a zero-initialized, padded buffer.
        K_padded = 8 * ((K + 7) // 8)
        padded_grad_shape = [grad_y.shape[0], grad_y.shape[1], grad_y.shape[2], K_padded]
        padded_grad = torch.zeros(padded_grad_shape, dtype=grad_y.dtype, device=grad_y.device)
        padded_grad[:, :, :, :K] = grad_y
        grad = padded_grad
    else:
        grad = grad_y
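    # Example of the rounding (illustrative numbers, not from the source):
    # K = 60 gives K_padded = 8 * ((60 + 7) // 8) = 64; the 4 extra channels
    # stay zero, and the matching rows of dw/db are sliced away below.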
    if ctx.need_bias_grad:
        dx, dw, db = NHWC.cudnn_convolution_backward_with_bias_nhwc(
            x, grad, w,
            ctx.padding, ctx.stride, ctx.dilation, ctx.groups,
            torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic,
            list(ctx.needs_input_grad[0:3]))
        if is_padded:
            # Drop the gradients that belong to the zero-padded channels.
            dw = dw[:K, :, :, :].contiguous()
            db = db[:K].contiguous()
        if ctx.needs_input_grad[0]:
            return dx, dw, db, None, None, None, None
        else:
            return None, dw, db, None, None, None, None
    else:
        # Shortcut: when the weight gradient is not needed, skip the cuDNN
        # call entirely (this path never computes dx on its own).
        if not ctx.needs_input_grad[1]:
            return None, None, None, None, None, None, None
        dx, dw = NHWC.cudnn_convolution_backward_nhwc(
            x, grad, w,
            ctx.padding, ctx.stride, ctx.dilation, ctx.groups,
            torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic,
            list(ctx.needs_input_grad[0:2]))
        if is_padded:
            # Drop the gradients that belong to the zero-padded channels.
            dw = dw[:K, :, :, :].contiguous()
        if ctx.needs_input_grad[0]:
            return dx, dw, None, None, None, None, None
        else:
            return None, dw, None, None, None, None, None
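
The zero-padding round trip itself does not depend on the custom NHWC extension: cuDNN's Tensor Core paths generally prefer fp16 channel counts that are multiples of 8, which is the performance motivation behind is_padded above. A minimal self-contained sketch of the same trick with stock PyTorch (the helper name and tensor sizes are illustrative, not part of this module):

import torch

def pad_channels_to_multiple_of_8(t):
    # t is NHWC; zero-pad the last (channel) dim up to a multiple of 8,
    # mirroring the padded_grad construction in backward() above.
    K = t.shape[3]
    if K % 8 == 0:
        return t, K
    K_padded = 8 * ((K + 7) // 8)
    padded = torch.zeros(*t.shape[:3], K_padded, dtype=t.dtype, device=t.device)
    padded[:, :, :, :K] = t
    return padded, K

# Round trip: pad, then slice back to the true channel width, just as the
# backward above does with dw[:K] and db[:K].
grad_y = torch.randn(2, 4, 4, 60)  # N, H, W, K with K = 60
padded, K = pad_channels_to_multiple_of_8(grad_y)
assert padded.shape == (2, 4, 4, 64)
assert torch.equal(padded[:, :, :, :K], grad_y)

The seven values returned from backward() line up one-to-one with the arguments the matching forward() presumably takes (x, w, bias, padding, stride, dilation, groups); the trailing Nones are the required placeholders for the non-tensor arguments, which never receive gradients.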