in tensorflow_managed_spot_training_checkpointing/source_dir/cifar10_keras_main.py [0:0]
def _dataset_parser(value):
    """Decode one serialized CIFAR-10 example into an ``(image, label)`` pair.

    Args:
        value: A serialized ``tf.train.Example`` string tensor (as consumed by
            ``tf.parse_single_example``) carrying an ``'image'`` bytes feature
            and a ``'label'`` int64 feature.

    Returns:
        A tuple ``(image, label)``. ``image`` is the float32
        ``[HEIGHT, WIDTH, DEPTH]`` tensor after it has been passed through
        ``_train_preprocess_fn``. ``label`` is the class index one-hot encoded
        over ``NUM_CLASSES`` classes.
    """
    feature_spec = {
        'image': tf.FixedLenFeature([], tf.string),
        'label': tf.FixedLenFeature([], tf.int64),
    }
    parsed = tf.parse_single_example(value, feature_spec)

    # The image bytes form a flat uint8 buffer of DEPTH * HEIGHT * WIDTH
    # values; declaring that length gives the tensor a static shape.
    flat_pixels = tf.decode_raw(parsed['image'], tf.uint8)
    flat_pixels.set_shape([DEPTH * HEIGHT * WIDTH])

    # Pixels are stored channel-major ([depth, height, width]); move the
    # channel axis last ([height, width, depth]) and convert to float32.
    chw_pixels = tf.reshape(flat_pixels, [DEPTH, HEIGHT, WIDTH])
    hwc_pixels = tf.transpose(chw_pixels, [1, 2, 0])
    image = tf.cast(hwc_pixels, tf.float32)

    class_index = tf.cast(parsed['label'], tf.int32)
    # NOTE(review): training preprocessing is applied unconditionally here;
    # if this parser also feeds validation/eval data, confirm that is intended.
    image = _train_preprocess_fn(image)
    return image, tf.one_hot(class_index, NUM_CLASSES)