in benchmarks/horovod-resnet/train_imagenet_resnet_hvd.py [0:0]
def spatial_average2d(self, inputs):
    """Average a rank-4 feature map over its two spatial axes.

    Args:
      inputs: rank-4 tensor whose layout is given by ``self.data_format``:
        ``"channels_last"`` means (N, H, W, C), ``"channels_first"`` means
        (N, C, H, W).

    Returns:
      A rank-2 tensor of shape ``[N, C]`` holding the mean over H and W.

    Raises:
      ValueError: if ``inputs`` has a statically known rank other than 4, or
        if ``self.data_format`` is neither ``"channels_last"`` nor
        ``"channels_first"``.
    """
    # Fail early on a wrong static rank (the old 4-way shape unpack did too).
    inputs.get_shape().assert_has_rank(4)
    if self.data_format == "channels_last":
        spatial_axes = [1, 2]
    elif self.data_format == "channels_first":
        spatial_axes = [2, 3]
    else:
        raise ValueError("Unsupported data_format: %r" % (self.data_format,))
    # A mean over H/W is exactly global average pooling. Unlike
    # average_pooling2d with an (h, w) window followed by reshape([n, c]),
    # it does not need H, W or C to be known statically and yields [N, C]
    # directly, so no reshape is required.
    return tf.reduce_mean(inputs, axis=spatial_axes)