def deserialize_image_record()

in benchmarks/horovod-resnet/train_imagenet_resnet_hvd.py [0:0]


def deserialize_image_record(record):
    """Parse one serialized example into image, label, bbox and text tensors.

    Args:
        record: Serialized example, handed unchanged to
            ``tf.parse_single_example``.

    Returns:
        A 4-tuple ``(imgdata, label, bbox, text)``:

        * ``imgdata`` -- the string feature stored under ``image/encoded``
          (defaults to ``""`` when absent).
        * ``label`` -- the ``image/class/label`` feature (length-1 int64,
          default ``-1``) cast to ``tf.int32``.
        * ``bbox`` -- the four ``image/object/bbox/*`` value vectors stacked
          in ``(ymin, xmin, ymax, xmax)`` order and rearranged to shape
          ``[1, num_boxes, 4]``.
        * ``text`` -- the string feature stored under ``image/class/text``
          (defaults to ``""`` when absent).
    """
    feature_spec = {
        "image/encoded": tf.FixedLenFeature([], tf.string, ""),
        "image/class/label": tf.FixedLenFeature([1], tf.int64, -1),
        "image/class/text": tf.FixedLenFeature([], tf.string, ""),
        "image/object/bbox/xmin": tf.VarLenFeature(dtype=tf.float32),
        "image/object/bbox/ymin": tf.VarLenFeature(dtype=tf.float32),
        "image/object/bbox/xmax": tf.VarLenFeature(dtype=tf.float32),
        "image/object/bbox/ymax": tf.VarLenFeature(dtype=tf.float32),
    }
    with tf.name_scope("deserialize_image_record"):
        parsed = tf.parse_single_example(record, feature_spec)
        class_id = tf.cast(parsed["image/class/label"], tf.int32)

        # One row per coordinate, in (ymin, xmin, ymax, xmax) order; the
        # sparse features contribute only their dense `.values` vectors.
        coord_rows = []
        for coord in ("ymin", "xmin", "ymax", "xmax"):
            coord_rows.append(parsed["image/object/bbox/%s" % coord].values)
        boxes = tf.stack(coord_rows)  # [4, num_boxes]

        # Add a leading axis, then swap the last two: [1, 4, N] -> [1, N, 4].
        boxes = tf.expand_dims(boxes, 0)
        boxes = tf.transpose(boxes, [0, 2, 1])

        return (
            parsed["image/encoded"],
            class_id,
            boxes,
            parsed["image/class/text"],
        )