def compute_gradients()

in benchmarks/horovod-resnet/train_imagenet_resnet_hvd.py [0:0]


    def compute_gradients(self, loss, var_list=None, *args, **kwargs):
        """Compute gradients of a scaled loss and undo the scaling on the result.

        When ``self._scale`` is not ``1.0`` the loss is multiplied by it before
        being handed to the wrapped ``self._optimizer``; every gradient that
        comes back is then multiplied by ``1 / self._scale`` so the returned
        gradients correspond to the unscaled loss.

        Args:
            loss: Loss passed to the wrapped optimizer's ``compute_gradients``.
            var_list: Variables to differentiate against. When ``None``, uses
                ``tf.trainable_variables()`` plus the
                ``tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES`` collection.
            *args: Forwarded unchanged to the wrapped optimizer.
            **kwargs: Forwarded unchanged to the wrapped optimizer.

        Returns:
            A list of ``(gradient, variable)`` pairs built by pairing
            ``var_list`` positionally with the wrapped optimizer's output.
            The variable in each pair is the entry from ``var_list``; the
            gradient is ``None`` when the wrapped optimizer reported none.
        """
        if var_list is None:
            var_list = tf.trainable_variables() + tf.get_collection(
                tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES
            )

        # Loop-invariant: decide once whether loss scaling is active.
        rescale = self._scale != 1.0
        if rescale:
            loss = tf.scalar_mul(self._scale, loss)

        gradvar = self._optimizer.compute_gradients(loss, var_list, *args, **kwargs)

        final_gradvar = []
        for orig_var, (grad, var) in zip(var_list, gradvar):
            # A gradient may be None for a variable the loss does not depend
            # on (documented for tf.train.Optimizer.compute_gradients, which
            # self._optimizer is presumed to implement). Casting or scaling
            # None would raise, so such entries are passed through untouched.
            if grad is not None:
                if var is not orig_var:
                    # The wrapped optimizer returned a different variable
                    # object; bring the gradient to the original's dtype.
                    grad = tf.cast(grad, orig_var.dtype)
                if rescale:
                    grad = tf.scalar_mul(1.0 / self._scale, grad)
            final_gradvar.append((grad, orig_var))

        return final_gradvar