in benchmarks/horovod-resnet/train_imagenet_resnet_hvd.py [0:0]
def compute_gradients(self, loss, var_list=None, *args, **kwargs):
    """Compute (gradient, variable) pairs with optional static loss scaling.

    When ``self._scale`` differs from 1.0, the loss is multiplied by that
    factor before delegating to ``self._optimizer.compute_gradients``. Each
    returned gradient is then multiplied by ``1.0 / self._scale``, so callers
    receive unscaled gradients.

    Args:
        loss: Loss tensor to differentiate.
        var_list: Variables to differentiate with respect to. Defaults to
            ``tf.trainable_variables()`` plus the
            ``TRAINABLE_RESOURCE_VARIABLES`` collection.
        *args, **kwargs: Forwarded unchanged to the wrapped optimizer's
            ``compute_gradients``.

    Returns:
        A list of ``(grad, var)`` pairs. The wrapped optimizer's results are
        matched positionally against ``var_list``, and each pair carries the
        caller's variable from ``var_list``. ``grad`` is left as ``None``
        (unscaled, uncast) when the wrapped optimizer reports no gradient for
        that variable.
    """
    if var_list is None:
        var_list = tf.trainable_variables() + tf.get_collection(
            tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES
        )
    scaled = self._scale != 1.0
    if scaled:
        loss = tf.scalar_mul(self._scale, loss)
    gradvar = self._optimizer.compute_gradients(loss, var_list, *args, **kwargs)
    inv_scale = 1.0 / self._scale if scaled else None
    final_gradvar = []
    for orig_var, (grad, var) in zip(var_list, gradvar):
        # A None gradient means the loss does not depend on this variable.
        # tf.cast / tf.scalar_mul cannot take None, so pass it through as-is.
        if grad is not None:
            if var is not orig_var:
                # The wrapped optimizer returned a different variable object
                # (presumably a substitute copy - confirm against the wrapped
                # optimizer); cast the gradient to the caller's variable dtype.
                grad = tf.cast(grad, orig_var.dtype)
            if scaled:
                # Undo the loss scaling applied above.
                grad = tf.scalar_mul(inv_scale, grad)
        final_gradvar.append((grad, orig_var))
    return final_gradvar