def choose_cnn()

in coinrun/policies.py [0:0]


def choose_cnn(images):
    """Build the convolutional trunk selected by ``Config.ARCHITECTURE``.

    The input is cast to ``tf.float32`` and divided by 255 before being fed
    to the chosen network (presumably 8-bit pixel data -- confirm with the
    caller).

    Supported architecture names:
        * ``'nature'``       -- ``nature_cnn``
        * ``'impala'``       -- ``impala_cnn`` with its default depths
        * ``'impalalarge'``  -- ``impala_cnn`` with
          ``depths=[32, 64, 64, 64, 64]``

    Args:
        images: Image tensor to encode.

    Returns:
        A ``(out, dropout_assign_ops)`` tuple. ``out`` is the output of the
        selected network. ``dropout_assign_ops`` is whatever ``impala_cnn``
        returns as its second value, or an empty list for ``'nature'``.

    Raises:
        ValueError: If ``Config.ARCHITECTURE`` is not a supported name.
    """
    arch = Config.ARCHITECTURE
    scaled_images = tf.cast(images, tf.float32) / 255.
    # Only the impala variants return assign ops; 'nature' keeps this empty.
    dropout_assign_ops = []

    if arch == 'nature':
        out = nature_cnn(scaled_images)
    elif arch == 'impala':
        out, dropout_assign_ops = impala_cnn(scaled_images)
    elif arch == 'impalalarge':
        out, dropout_assign_ops = impala_cnn(scaled_images, depths=[32, 64, 64, 64, 64])
    else:
        # Raise explicitly instead of `assert(False)`: asserts are removed
        # under `python -O`, which would let execution fall through to the
        # return below and fail with an unrelated UnboundLocalError on `out`.
        raise ValueError(
            "Unknown Config.ARCHITECTURE {!r}; expected one of "
            "'nature', 'impala', 'impalalarge'".format(arch)
        )

    return out, dropout_assign_ops