def create_distributed_optimizer()

in horovod/_keras/__init__.py
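
# Note: this excerpt relies on module-level names defined elsewhere in
# horovod/_keras/__init__.py: `tf` (tensorflow), `hvd` (horovod.tensorflow),
# and `_PRE_TF_2_4_0`, a flag that is True when the installed TensorFlow is
# older than 2.4.0.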


def create_distributed_optimizer(keras, optimizer, name, device_dense, device_sparse,
                                 compression, sparse_as_dense, gradient_predivide_factor):
    class _DistributedOptimizer(keras.optimizers.Optimizer):
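        # Signals to tf.keras that this optimizer performs its own gradient
        # aggregation (via _aggregate_gradients), so Keras routes gradients
        # through it instead of aggregating them itself.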
        _HAS_AGGREGATE_GRAD = True

        def __init__(self, **kwargs):
            self._name = name or "Distributed%s" % self.__class__.__base__.__name__
            self._device_dense = device_dense
            self._device_sparse = device_sparse
            self._compression = compression
            self._sparse_as_dense = sparse_as_dense
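            # Tracks whether gradients were actually allreduced; checked in
            # apply_gradients() to catch misconfigured TF 2.x executions.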
            self._aggregated_gradients = False
            self._gradient_predivide_factor = gradient_predivide_factor
            super(self.__class__, self).__init__(**kwargs)

        def get_gradients(self, loss, params):
            """
            Compute gradients of all trainable variables.

            See Optimizer.get_gradients() for more info.

            In DistributedOptimizer, get_gradients() is overridden to also
            allreduce the gradients before returning them.
            """
            gradients = super(self.__class__, self).get_gradients(loss, params)
            return self._allreduce(gradients)

        def _aggregate_gradients(self, grads_and_vars):
            grads, vars = list(zip(*grads_and_vars))
            aggregated_grads = self._allreduce(grads)
            if _PRE_TF_2_4_0:
                # Prior to TF 2.4.0, this function was expected to return only a list of
                # grads, not a list of (grad, var) tuples.
                return aggregated_grads
            return list(zip(aggregated_grads, vars))

        def _allreduce(self, gradients):
            self._aggregated_gradients = True
            if hvd.size() > 1:
                if self._gradient_predivide_factor != 1.0:
                    # Perform averaging via pre/postscaling factors.
                    # Split average operation across pre/postscale factors
                    prescale_factor = 1.0 / self._gradient_predivide_factor
                    postscale_factor = self._gradient_predivide_factor / hvd.size()
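                    # Worked example: with N = hvd.size() workers and predivide
                    # factor f, each worker contributes grad / f to the summed
                    # allreduce, and the sum is then scaled by f / N:
                    #   (sum_i grad_i / f) * (f / N) == (1 / N) * sum_i grad_i,
                    # i.e. the result is still the average gradient.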
                    do_average = False
                else:
                    prescale_factor = 1.0
                    postscale_factor = 1.0
                    do_average = True

                averaged_gradients = []
                with tf.name_scope(self._name + "_Allreduce"):
                    for grad in gradients:
                        if grad is not None:
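                            # Optionally densify sparse IndexedSlices gradients
                            # so they take the regular dense allreduce path.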
                            if self._sparse_as_dense and \
                                    isinstance(grad, tf.IndexedSlices):
                                grad = tf.convert_to_tensor(grad)
                            avg_grad = hvd.allreduce(grad,
                                                     average=do_average,
                                                     device_dense=self._device_dense,
                                                     device_sparse=self._device_sparse,
                                                     compression=self._compression,
                                                     prescale_factor=prescale_factor,
                                                     postscale_factor=postscale_factor)
                            averaged_gradients.append(avg_grad)
                        else:
                            averaged_gradients.append(None)
                    return averaged_gradients
            else:
                return gradients

        def apply_gradients(self, *args, **kwargs):
            results = super(self.__class__, self).apply_gradients(*args, **kwargs)
            if not self._aggregated_gradients:
                raise Exception('`apply_gradients()` was called without a call to '
                                '`get_gradients()` or `_aggregate_gradients`. If you\'re '
                                'using TensorFlow 2.0, please specify '
                                '`experimental_run_tf_function=False` in `compile()`.')
            return results

    # We dynamically create a new class that inherits from the optimizer that was passed in.
    # The goal is to override the get_gradients() method with an allreduce implementation.
    # This class will have the same name as the optimizer it wraps, so that a saved
    # model can easily be restored without Horovod.
    cls = type(optimizer.__class__.__name__, (optimizer.__class__,),
               dict(_DistributedOptimizer.__dict__))
    return cls.from_config(optimizer.get_config())
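
# Usage sketch (illustrative, not from the source file). It assumes a standard
# tf.keras / horovod.tensorflow setup; the argument names follow the signature
# above, and the public horovod.keras.DistributedOptimizer wrapper is the usual
# entry point that delegates to this factory.
import tensorflow as tf
import horovod.tensorflow as hvd

hvd.init()
base_opt = tf.keras.optimizers.SGD(learning_rate=0.01 * hvd.size())
dist_opt = create_distributed_optimizer(
    keras=tf.keras,
    optimizer=base_opt,
    name=None,                      # falls back to "DistributedSGD"
    device_dense='',                # let Horovod choose devices for allreduce
    device_sparse='',
    compression=hvd.Compression.none,
    sparse_as_dense=False,
    gradient_predivide_factor=1.0,  # plain averaging, no pre/post scaling split
)
# dist_opt looks like an ordinary SGD optimizer to Keras, but allreduces
# gradients across workers in get_gradients() / _aggregate_gradients().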