def batch_norm()

in benchmarks/horovod-resnet/train_imagenet_resnet_hvd.py [0:0]


    def batch_norm(self, inputs, **kwargs):
        """Run fused batch normalization over ``inputs``.

        The entries of ``self.batch_norm_config`` act as defaults; any keyword
        arguments given to this call override them. The layer always receives
        ``is_training=self.training`` and ``fused=True``. Its ``data_format`` is
        ``"NHWC"`` when ``self.data_format == "channels_last"`` and ``"NCHW"``
        otherwise.

        Args:
            inputs: Tensor to normalize (layout assumed to match
                ``self.data_format`` -- TODO confirm against callers).
            **kwargs: Extra options for ``tf.contrib.layers.batch_norm``,
                layered on top of ``self.batch_norm_config``.

        Returns:
            Whatever ``tf.contrib.layers.batch_norm`` returns for these inputs.

        Note:
            ``is_training``, ``data_format`` and ``fused`` are passed
            explicitly, so supplying any of them via the config or ``kwargs``
            makes Python raise ``TypeError`` for a duplicate keyword argument.
        """
        # Config first, per-call overrides second (later keys win).
        options = dict(self.batch_norm_config, **kwargs)
        # Translate the Keras-style layout name into the TF op's layout string.
        if self.data_format == "channels_last":
            layout = "NHWC"
        else:
            layout = "NCHW"
        return tf.contrib.layers.batch_norm(
            inputs,
            is_training=self.training,
            data_format=layout,
            fused=True,
            **options
        )