in src/script.py [0:0]
def generate_tf_records(base_folder, input_files, output_file, n_image, slide=None,
                        *, image_shape=(512, 512, 3)):
    """Write ``n_image`` randomly sampled images to a TFRecord file.

    Images are drawn with replacement from ``input_files`` using
    ``random.choice``. Any image whose array shape differs from
    ``image_shape`` is skipped and does not count toward ``n_image``. Each
    accepted image is stored as the raw bytes of its float32 array (no shape
    header), so readers must know ``image_shape`` to decode it.

    Args:
        base_folder: Directory that the filenames in ``input_files`` are
            joined onto.
        input_files: Sequence of ``(filename, label)`` pairs to sample from.
            The label is passed unchanged to ``image_example``. Its type is
            whatever that helper expects (not visible here).
        output_file: Path of the TFRecord file to write.
        n_image: Number of records to write. Values <= 0 write no records.
        slide: Optional identifier. When truthy, it is UTF-8 encoded and
            passed to ``image_example`` with every record. Otherwise ``None``
            is passed.
        image_shape: Exact array shape an image must have to be written.
            Defaults to ``(512, 512, 3)``.

    Raises:
        IndexError: If ``input_files`` is empty and ``n_image`` > 0 (raised
            by ``random.choice``).
        ValueError: If every distinct file in ``input_files`` has been found
            to have the wrong shape, so no further record could ever be
            written. Records already written remain in ``output_file``.
    """
    image_shape = tuple(image_shape)
    # The encoded slide id does not depend on the sampled image; compute once.
    slide_string = slide.encode('utf-8') if slide else None
    # Distinct filenames that can be drawn. Comparing against the rejected set
    # lets us stop instead of looping forever when no file is usable.
    candidates = {filename for filename, _ in input_files}
    rejected = set()
    remaining = n_image
    with tf.io.TFRecordWriter(output_file) as writer:
        # Use an explicit "> 0" test: a negative n_image would never reach 0.
        while remaining > 0:
            filename, label = random.choice(input_files)
            if filename in rejected:
                # Already known to have the wrong shape; skip the disk read.
                continue
            image = plt.imread(os.path.join(base_folder, filename))
            if image.shape != image_shape:
                rejected.add(filename)
                if len(rejected) == len(candidates):
                    raise ValueError(
                        f"no file in input_files has an image of shape {image_shape}"
                    )
                continue
            remaining -= 1
            image_string = np.float32(image).tobytes()
            tf_example = image_example(image_string, label, slide_string)
            writer.write(tf_example.SerializeToString())