def residual2d()

in benchmarks/horovod-resnet/train_imagenet_resnet_hvd.py [0:0]


    def residual2d(self, inputs, network, units=None, scale=1.0, activate=False):
        """Apply ``network`` to ``inputs`` and add a skip connection.

        The result is ``shortcut + scale * network(inputs)``, where
        ``shortcut`` is ``inputs`` itself when the static shapes already
        line up, and otherwise ``inputs`` passed through
        ``self.conv2d_linear`` so that it can be added to the network output.

        Args:
            inputs: Input tensor; its layout is taken from
                ``self.data_format`` ("channels_last" puts channels on the
                last axis and H/W on axes 1/2, anything else is treated as
                channels on axis 1 and H/W on axes 2/3).
            network: Callable mapping ``inputs`` to the residual branch
                output.
            units: Channel count handed to ``self.conv2d_linear`` for the
                projected shortcut. When ``None`` (the default) the channel
                count of the network output is used, since the shortcut
                must match it for the addition to be valid.
            scale: Multiplier applied to the network output before the
                addition.
            activate: When true, ``self.activate`` is applied to the sum.

        Returns:
            The combined tensor (activated if ``activate`` is set).

        Note:
            Strides are derived from the *static* shapes returned by
            ``get_shape().as_list()``; the spatial dimensions are assumed to
            be statically known -- TODO confirm callers never pass
            unknown-size H/W.
        """
        outputs = network(inputs)

        channels_last = self.data_format == "channels_last"
        c_axis = -1 if channels_last else 1
        h_axis = 1 if channels_last else 2
        w_axis = h_axis + 1

        ishape = inputs.get_shape().as_list()
        oshape = outputs.get_shape().as_list()
        ichans, ochans = ishape[c_axis], oshape[c_axis]

        # Previously the ``None`` default was forwarded verbatim to
        # conv2d_linear; fall back to the output channel count, which is the
        # only value that makes the addition below shape-compatible.
        if units is None:
            units = ochans

        # Ceiling division: how much the network shrank each spatial axis.
        strides = (
            (ishape[h_axis] - 1) // oshape[h_axis] + 1,
            (ishape[w_axis] - 1) // oshape[w_axis] + 1,
        )

        with tf.name_scope("residual"):
            # Project the shortcut only when channels or spatial size differ.
            # NOTE(review): the positional args are presumably
            # (filters, kernel_size=1, strides, padding) -- confirm against
            # conv2d_linear's signature.
            if ochans != ichans or strides != (1, 1):
                inputs = self.conv2d_linear(inputs, units, 1, strides, "SAME")
            x = inputs + scale * outputs
            if activate:
                x = self.activate(x)
        return x