def write_records()

in 06_preprocessing/jpeg_to_tfrecord_tft.py [0:0]


def write_records(OUTPUT_DIR, splits, split):
    """Attach the TFRecord-writing stages for one split to the pipeline.

    Each element of ``splits`` is passed, together with ``split``, to
    ``yield_records_for_split``; whatever that helper yields is written
    out as sharded TFRecord files.

    Args:
        OUTPUT_DIR: Directory joined with ``split`` to form the output
            file path prefix.
        splits: Input collection the Beam transforms are applied to
            (presumably a PCollection covering all splits -- confirm
            against the caller).
        split: Name of the split to write; used in the stage labels and
            in the output path, and ``'train'`` selects the larger
            shard count.

    Returns:
        None. The resulting pipeline stage is intentionally discarded.
    """
    # Shard counts follow the same 80:10:10 split: 16 for train and 2 for
    # any other split. The flowers dataset takes about 1GB, so 20 files
    # in total means roughly 50MB each.
    if split == 'train':
        nshards = 16
    else:
        nshards = 2

    def records_of_split(x):
        # Delegate to the module-level helper for the requested split.
        return yield_records_for_split(x, split)

    output_prefix = os.path.join(OUTPUT_DIR, split)

    selected = splits | 'only_{}'.format(split) >> beam.FlatMap(
        records_of_split)
    _ = selected | 'write_{}'.format(split) >> (
        beam.io.tfrecordio.WriteToTFRecord(
            output_prefix,
            file_name_suffix='.gz',
            num_shards=nshards))