def flatten2d()

in benchmarks/horovod-resnet/train_imagenet_resnet_hvd.py [0:0]


    def flatten2d(self, inputs):
        """Flatten ``inputs`` to a 2-D tensor of shape ``[batch, features]``.

        Channels-first inputs are first transposed from NCHW to NHWC so that
        the flattened feature order matches that of an NHWC network.
        Channels-last inputs are flattened as-is.

        Args:
            inputs: Activation tensor. When ``self.data_format`` is not a
                channels-last format it must be rank 4 (the transpose below
                uses the permutation ``[0, 2, 3, 1]``).

        Returns:
            A tensor reshaped to ``[-1, product of the non-batch dims]``,
            named "flatten".
        """
        x = inputs
        # TensorFlow spells the format "channels_last" (plural). The previous
        # check compared against "channel_last", so for valid values the
        # NCHW->NHWC transpose also ran on channels-last data and scrambled
        # the feature order. "channel_last" is still accepted so existing
        # callers that used that spelling keep their old behavior.
        if self.data_format not in ("channels_last", "channel_last", "NHWC"):
            # Note: This ensures the output order matches that of NHWC networks
            x = tf.transpose(x, [0, 2, 3, 1])
        feature_dims = x.get_shape().as_list()[1:]
        if any(dim is None for dim in feature_dims):
            # A non-batch dimension is unknown at graph-construction time;
            # multiplying by None would raise, so use the runtime shape.
            num_inputs = tf.reduce_prod(tf.shape(x)[1:])
        else:
            num_inputs = 1
            for dim in feature_dims:
                num_inputs *= dim
        return tf.reshape(x, [-1, num_inputs], name="flatten")