in benchmarks/horovod-resnet/train_imagenet_resnet_hvd.py [0:0]
def batch_norm(self, inputs, **kwargs):
    """Run fused batch normalization over ``inputs`` via ``tf.contrib.layers``.

    The instance-level ``self.batch_norm_config`` supplies default keyword
    arguments. Any ``kwargs`` given to this call override those defaults key by
    key. The training flag comes from ``self.training``. The tensor layout
    comes from ``self.data_format``: ``"channels_last"`` maps to ``"NHWC"``,
    and anything else maps to ``"NCHW"``. The fused kernel is always requested.

    Args:
        inputs: Tensor to normalize. Its layout is assumed to match
            ``self.data_format`` (TODO confirm at call sites).
        **kwargs: Extra ``tf.contrib.layers.batch_norm`` options, merged on top
            of ``self.batch_norm_config``.

    Returns:
        Whatever ``tf.contrib.layers.batch_norm`` returns for these arguments.

    Raises:
        TypeError: If the merged options contain ``is_training``,
            ``data_format`` or ``fused``. Those are already passed explicitly,
            so Python rejects the duplicate keyword.
    """
    # dict(mapping, **overrides) copies the config, then applies per-call overrides.
    bn_kwargs = dict(self.batch_norm_config, **kwargs)

    # Translate the Keras-style layout name into TensorFlow's tensor layout string.
    if self.data_format == "channels_last":
        layout = "NHWC"
    else:
        layout = "NCHW"

    return tf.contrib.layers.batch_norm(
        inputs,
        is_training=self.training,
        data_format=layout,
        fused=True,
        **bn_kwargs
    )