in benchmarks/horovod-resnet/train_imagenet_resnet_hvd.py [0:0]
def flatten2d(self, inputs):
    """Flatten a feature map into a 2-D ``[batch, features]`` tensor.

    When the builder is not configured for a channels-last layout, the input
    is first transposed with ``[0, 2, 3, 1]`` (NCHW -> NHWC) so that the
    flattened feature order matches that of an NHWC network.

    Args:
        inputs: Feature-map tensor. For non-channels-last layouts it must be
            rank 4 (the transpose permutation has four axes). Every
            non-batch dimension must be statically known.
            NOTE(review): ``self.data_format`` is presumably
            "channels_first"/"channels_last" as in ``tf.layers`` — confirm.

    Returns:
        A tensor reshaped to ``[-1, product of non-batch dims]``, named
        "flatten".

    Raises:
        ValueError: if any non-batch dimension is not statically known.
    """
    x = inputs
    # TensorFlow spells the layout "channels_last". The previous check compared
    # against "channel_last" only, which never matched it and so transposed
    # NHWC input as if it were NCHW. Both spellings are accepted to stay
    # compatible with any caller that used the singular form.
    if self.data_format not in ("channels_last", "channel_last"):
        # Note: This ensures the output order matches that of NHWC networks
        x = tf.transpose(x, [0, 2, 3, 1])
    input_shape = x.get_shape().as_list()
    num_inputs = 1
    for dim in input_shape[1:]:
        if dim is None:
            # A static feature count is required for the reshape target.
            raise ValueError(
                "flatten2d requires statically known non-batch dimensions, "
                "got shape %s" % (input_shape,)
            )
        num_inputs *= dim
    return tf.reshape(x, [-1, num_inputs], name="flatten")