in models/densenet_efficient_multi_gpu.py
def backward(self, grad_output):
# Save bn training status and the current running statistics; temporarily restore the pre-forward statistics for the recomputation below
training = self.efficient_batch_norm.training
self.curr_running_mean.copy_(self.efficient_batch_norm.running_mean)
self.curr_running_var.copy_(self.efficient_batch_norm.running_var)
# self.efficient_batch_norm.training = False
self.efficient_batch_norm.running_mean.copy_(self.prev_running_mean)
self.efficient_batch_norm.running_var.copy_(self.prev_running_var)
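# at this point curr_running_* holds the post-forward statistics, while running_mean/var hold the pre-forward ones, so the recomputation below starts from the same BN state as the original forward pass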
# Recompute concat and BN
cat_output = self.efficient_cat.forward(*self.inputs)
bn_output = self.efficient_batch_norm.forward(self.bn_weight, self.bn_bias, cat_output)
relu_output = self.efficient_relu.forward(bn_output)
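# (these intermediates were not kept during forward; they live in shared storage and are cheap to recompute, which is what makes this implementation memory-efficient)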
# Conv backward
conv_weight_grad, _, conv_grad_output = self.efficient_conv.backward(
self.conv_weight, None, relu_output, grad_output)
# ReLU backward
relu_grad_output = self.efficient_relu.backward(bn_output, conv_grad_output)
# BN backward
cat_output = self.efficient_cat.forward(*self.inputs)  # recompute cat_output because bn_output overwrote its storage (L481)
# the multi-GPU version differs slightly from the single-GPU one in that
# we use only one shared_allocation for both BN and Cat
self.efficient_batch_norm.running_mean.copy_(self.curr_running_mean)
self.efficient_batch_norm.running_var.copy_(self.curr_running_var)
bn_weight_grad, bn_bias_grad, bn_grad_output = self.efficient_batch_norm.backward(
self.bn_weight, self.bn_bias, cat_output, relu_grad_output)
# Input backward
grad_inputs = self.efficient_cat.backward(bn_grad_output)
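# efficient_cat.backward splits the concatenated gradient back into one gradient per input tensor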
# Reset bn training status and statistics
self.efficient_batch_norm.training = training
self.efficient_batch_norm.running_mean.copy_(self.curr_running_mean)
self.efficient_batch_norm.running_var.copy_(self.curr_running_var)
return tuple([bn_weight_grad, bn_bias_grad, conv_weight_grad] + list(grad_inputs))
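# --- Illustration (not part of the original file) ---------------------------
# A minimal sketch, assuming the forward pass mirrors the structure recomputed
# above, of how the attributes used in backward() would be populated: the
# running statistics are snapshotted into prev_running_* before BN updates
# them, the raw inputs are kept, and only the conv output is returned, since
# the concat/BN/ReLU intermediates sit in shared storage and are recomputed
# during backward. Signatures are inferred from the calls above, not taken
# verbatim from the repository.
def forward(self, bn_weight, bn_bias, conv_weight, *inputs):
    self.bn_weight, self.bn_bias, self.conv_weight = bn_weight, bn_bias, conv_weight
    self.inputs = inputs
    # snapshot the pre-forward statistics so backward() can restore them
    self.prev_running_mean.copy_(self.efficient_batch_norm.running_mean)
    self.prev_running_var.copy_(self.efficient_batch_norm.running_var)
    cat_output = self.efficient_cat.forward(*inputs)  # written into shared storage
    bn_output = self.efficient_batch_norm.forward(bn_weight, bn_bias, cat_output)
    relu_output = self.efficient_relu.forward(bn_output)
    return self.efficient_conv.forward(conv_weight, None, relu_output)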