in tensorflow_managed_spot_training_checkpointing/generate_cifar10_tfrecords.py [0:0]
def convert_to_tfrecord(input_files, output_file):
    """Write every entry of the given pickled batch files into one TFRecord file.

    Args:
      input_files: iterable of paths; each is loaded with
        ``read_pickle_from_file`` and is expected to be a dict exposing
        ``b'data'`` and ``b'labels'``.
      output_file: path of the TFRecord file to write.

    Each entry becomes one serialized ``tf.train.Example`` with two features:
    ``'image'`` (the raw bytes of ``data[i]`` via ``.tobytes()``) and
    ``'label'`` (``labels[i]`` as an int64 feature).
    """
    print('Generating %s' % output_file)
    with tf.io.TFRecordWriter(output_file) as writer:
        for path in input_files:
            batch = read_pickle_from_file(path)
            images = batch[b'data']
            batch_labels = batch[b'labels']
            # The label count drives the loop; images is indexed in parallel.
            # NOTE(review): presumably images is a numpy array with one row per
            # label (hence .tobytes()) -- confirm against the pickled format.
            for idx, label in enumerate(batch_labels):
                features = tf.train.Features(feature={
                    'image': _bytes_feature(images[idx].tobytes()),
                    'label': _int64_feature(label),
                })
                record = tf.train.Example(features=features)
                writer.write(record.SerializeToString())