in fast_grad_conv.py [0:0]
def forward(self, x):
    """Apply the convolution, using per-sample expanded weights while training.

    Training mode:
    - ``FastGradConv2dFunction.apply`` expands ``self.weight`` / ``self.bias``
      for ``batch_size`` samples.
    - The whole batch is convolved in one grouped ``F.conv2d`` call with
      ``groups=batch_size``, so each sample is convolved with its own slice of
      the expanded weight.
    - The expanded tensors are stored on ``self.expanded_weight`` /
      ``self.expanded_bias`` and marked with ``retain_grad()``, so their
      ``.grad`` is populated after backward. This is presumably to read
      per-sample gradients; confirm against callers.

    Eval mode: a plain ``F.conv2d`` with ``self.weight`` / ``self.bias`` is
    used. ``groups`` is not forwarded, so it runs with ``groups=1``.

    Args:
        x: Input batch, indexed as ``(N, C_in, H, W)``.

    Returns:
        A 4-D output tensor whose first dimension is the batch size ``N``.
    """
    # Use the `training` flag toggled by train()/eval(). `self.train` on an
    # nn.Module is a method, which is always truthy, so it cannot select the
    # eval path.
    if not self.training:
        return F.conv2d(x, self.weight, self.bias, stride=self.stride,
                        padding=self.padding, dilation=self.dilation)

    batch_size = x.size(0)

    # Drop the previous step's expanded tensors (and any retained grads)
    # before building new ones. This presumably lowers peak memory. The
    # guard lets the first call work even if they were never assigned.
    for name in ("expanded_weight", "expanded_bias"):
        if hasattr(self, name):
            delattr(self, name)

    self.expanded_weight, self.expanded_bias = FastGradConv2dFunction.apply(
        self.weight, self.bias, batch_size)
    # Ensure the expanded tensors take part in autograd and keep their .grad
    # after backward (retain_grad is what preserves .grad on non-leaf tensors).
    self.expanded_weight.requires_grad_(True)
    self.expanded_weight.retain_grad()
    if self.expanded_bias is not None:
        self.expanded_bias.requires_grad_(True)
        self.expanded_bias.retain_grad()

    # Fold the batch into the channel axis, (N, C_in, H, W) -> (1, N*C_in, H, W),
    # so that groups=batch_size pairs each sample with its own weight group.
    # reshape (rather than view) also accepts non-contiguous inputs.
    output = F.conv2d(x.reshape(1, -1, x.size(2), x.size(3)),
                      self.expanded_weight, bias=self.expanded_bias,
                      stride=self.stride, padding=self.padding,
                      dilation=self.dilation, groups=batch_size)
    # Unfold the grouped result back into a per-sample batch dimension.
    return output.view(batch_size, -1, output.size(2), output.size(3))