def inference_resnet_v1_impl()

in benchmarks/horovod-resnet/train_imagenet_resnet_hvd.py [0:0]


def inference_resnet_v1_impl(builder, inputs, layer_counts, basic=False):
    """Chain the ResNet v1 layers onto ``inputs`` using ``builder``.

    The network is assembled in three parts:

    1. A stem: ``pad2d(.., 3)``, ``conv2d(.., 64, 7, 2, "VALID")`` and
       ``max_pooling2d(.., 3, 2, "SAME")``.
    2. Four stages of ``resnet_bottleneck_v1`` blocks; stage ``k`` repeats
       its block ``layer_counts[k]`` times.
    3. A final ``builder.spatial_average2d`` call.

    Args:
        builder: Object providing ``pad2d``, ``conv2d``, ``max_pooling2d``
            and ``spatial_average2d``; it is also handed to every
            ``resnet_bottleneck_v1`` call.
        inputs: Value fed to the first builder call.
        layer_counts: Indexable with 0..3; each entry is the number of
            blocks in the corresponding stage.
        basic: Forwarded unchanged to ``resnet_bottleneck_v1``.

    Returns:
        Whatever ``builder.spatial_average2d`` returns for the last stage's
        output.
    """
    # Stem.
    net = builder.pad2d(inputs, 3)
    net = builder.conv2d(net, 64, 7, 2, "VALID")
    net = builder.max_pooling2d(net, 3, 2, "SAME")

    # Third and fourth arguments passed to resnet_bottleneck_v1 per stage.
    # NOTE(review): presumably (output depth, bottleneck depth) -- confirm
    # against resnet_bottleneck_v1's signature.
    stage_args = ((256, 64), (512, 128), (1024, 256), (2048, 512))

    for stage, (depth, depth_bottleneck) in enumerate(stage_args):
        # layer_counts is indexed stage by stage, so a too-short sequence
        # only fails once the missing stage is reached.
        for block in range(layer_counts[stage]):
            # Only the first block of stages after the first one gets 2 as
            # the fifth argument (presumably the stride -- TODO confirm).
            if stage > 0 and block == 0:
                step = 2
            else:
                step = 1
            net = resnet_bottleneck_v1(
                builder, net, depth, depth_bottleneck, step, basic
            )

    return builder.spatial_average2d(net)