03_image_models/flowers.py (132 lines of code) (raw):

import math
import os
import re
import sys

import numpy as np
import tensorflow as tf
from sklearn.metrics import f1_score, precision_score, recall_score, confusion_matrix

print("Tensorflow version " + tf.__version__)

# Shorthand used by the tf.data pipeline to let TensorFlow pick parallelism / buffer sizes.
AUTO = tf.data.experimental.AUTOTUNE

# Pick a distribution strategy: try a TPU first, otherwise fall back to
# MirroredStrategy (GPUs or multi-GPU machines).
try:
    tpu = tf.distribute.cluster_resolver.TPUClusterResolver.connect()
    strategy = tf.distribute.TPUStrategy(tpu)
except ValueError:
    strategy = tf.distribute.MirroredStrategy()
print("REPLICAS: ", strategy.num_replicas_in_sync)

GCS_DS_PATH = "gs://practical-ml-vision-book-data/flowers_104_tfr"

# Settings for TPUv3. When running on hardware with less memory such as a TPUv2 (Colab)
# or a GPU, you might have to use lower BATCH_SIZE and/or IMAGE_SIZE values.
# Available image sizes in the flowers104 dataset: 512x512, 331x331, 224x224, 192x192.
IMAGE_SIZE = [512, 512]
BATCH_SIZE = 32 * strategy.num_replicas_in_sync  # global batch size, scaled with the replica count
EPOCHS = 13

# Learning rate schedule: the peak rate is scaled with the number of replicas.
LR_START = 0.00001
LR_MAX = 0.0001 * strategy.num_replicas_in_sync
LR_MIN = 0.00001
LR_RAMPUP_EPOCHS = 3
LR_SUSTAIN_EPOCHS = 3
LR_EXP_DECAY = .5

# Maps an image edge length to the GCS folder holding TFRecords of that size.
GCS_PATH_SELECT = {
    edge: GCS_DS_PATH + '/tfrecords-jpeg-{0}x{0}'.format(edge)
    for edge in (192, 224, 331, 512)
}
GCS_PATH = GCS_PATH_SELECT[IMAGE_SIZE[0]]  # the folder is selected by the first IMAGE_SIZE entry

# This dataset is split three ways, training, validation, test
# but we will use it split two ways only: training and validation.
# The dataset's 'train' and 'val' folders are merged for training; its 'test' folder is used for validation.
TRAINING_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/train/*.tfrec') + tf.io.gfile.glob(GCS_PATH + '/val/*.tfrec')
VALIDATION_FILENAMES = tf.io.gfile.glob(GCS_PATH + '/test/*.tfrec')

# Class names; the list length sets the size of the model's output layer.
CLASSES = ['pink primrose',    'hard-leaved pocket orchid', 'canterbury bells', 'sweet pea',     'wild geranium',     'tiger lily',           'moon orchid',              'bird of paradise', 'monkshood',        'globe thistle',         # 00 - 09
           'snapdragon',       "colt's foot",               'king protea',      'spear thistle', 'yellow iris',       'globe-flower',         'purple coneflower',        'peruvian lily',    'balloon flower',   'giant white arum lily', # 10 - 19
           'fire lily',        'pincushion flower',         'fritillary',       'red ginger',    'grape hyacinth',    'corn poppy',           'prince of wales feathers', 'stemless gentian', 'artichoke',        'sweet william',         # 20 - 29
           'carnation',        'garden phlox',              'love in the mist', 'cosmos',        'alpine sea holly',  'ruby-lipped cattleya', 'cape flower',              'great masterwort', 'siam tulip',       'lenten rose',           # 30 - 39
           'barberton daisy',  'daffodil',                  'sword lily',       'poinsettia',    'bolero deep blue',  'wallflower',           'marigold',                 'buttercup',        'daisy',            'common dandelion',      # 40 - 49
           'petunia',          'wild pansy',                'primula',          'sunflower',     'lilac hibiscus',    'bishop of llandaff',   'gaura',                    'geranium',         'orange dahlia',    'pink-yellow dahlia',    # 50 - 59
           'cautleya spicata', 'japanese anemone',          'black-eyed susan', 'silverbush',    'californian poppy', 'osteospermum',         'spring crocus',            'iris',             'windflower',       'tree poppy',            # 60 - 69
           'gazania',          'azalea',                    'water lily',       'rose',          'thorn apple',       'morning glory',        'passion flower',           'lotus',            'toad lily',        'anthurium',             # 70 - 79
           'frangipani',       'clematis',                  'hibiscus',         'columbine',     'desert-rose',       'tree mallow',          'magnolia',                 'cyclamen ',        'watercress',       'canna lily',            # 80 - 89
           'hippeastrum ',     'bee balm',                  'pink quill',       'foxglove',      'bougainvillea',     'camellia',             'mallow',                   'mexican petunia',  'bromelia',         'blanket flower',        # 90 - 99
           'trumpet creeper',  'blackberry lily',           'common tulip',     'wild rose']                                                                                                                                               # 100 - 103


def lrfn(epoch):
    """Return the learning rate to use for the given (0-based) epoch.

    The schedule ramps linearly from LR_START to LR_MAX over LR_RAMPUP_EPOCHS,
    holds LR_MAX for LR_SUSTAIN_EPOCHS, then decays exponentially (factor
    LR_EXP_DECAY per epoch) towards LR_MIN.
    """
    if epoch < LR_RAMPUP_EPOCHS:
        lr = (LR_MAX - LR_START) / LR_RAMPUP_EPOCHS * epoch + LR_START
    elif epoch < LR_RAMPUP_EPOCHS + LR_SUSTAIN_EPOCHS:
        lr = LR_MAX
    else:
        lr = (LR_MAX - LR_MIN) * LR_EXP_DECAY**(epoch - LR_RAMPUP_EPOCHS - LR_SUSTAIN_EPOCHS) + LR_MIN
    return lr


lr_callback = tf.keras.callbacks.LearningRateScheduler(lrfn, verbose=True)


# DATASET

def decode_image(image_data):
    """Decode a JPEG byte string into a uint8 tensor of shape [*IMAGE_SIZE, 3]."""
    image = tf.image.decode_jpeg(image_data, channels=3)  # decoded image in uint8 format range [0,255]
    image = tf.reshape(image, [*IMAGE_SIZE, 3])  # explicit size needed for TPU
    return image


def read_tfrecord(example):
    """Parse one serialized example and return an (image, label) pair.

    The "id" feature is still declared in the parsing spec, so records are
    required to carry it exactly as before, but its value is not returned.
    """
    TFREC_FORMAT = {
        "image": tf.io.FixedLenFeature([], tf.string),  # tf.string means bytestring
        "class": tf.io.FixedLenFeature([], tf.int64),   # shape [] means single element
        "id": tf.io.FixedLenFeature([], tf.string),     # shape [] means single element
    }
    example = tf.io.parse_single_example(example, TFREC_FORMAT)
    image = decode_image(example['image'])
    label = tf.cast(example['class'], tf.int32)
    return image, label  # returns a dataset of (image, label) pairs


def load_dataset(filenames, ordered=False):
    """Build a dataset of (image, label) pairs from the given TFRecord files.

    Args:
        filenames: list of TFRecord file paths.
        ordered: when False (default), deterministic ordering is switched off
            so that records are used as soon as they stream in.
    """
    # Read from TFRecords. For optimal performance, reading from multiple files at once and
    # disregarding data order. Order does not matter since we will be shuffling the data anyway.
    ignore_order = tf.data.Options()
    if not ordered:
        ignore_order.experimental_deterministic = False  # disable order, increase speed

    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=AUTO)  # automatically interleaves reads from multiple files
    dataset = dataset.with_options(ignore_order)  # uses data as soon as it streams in, rather than in its original order
    dataset = dataset.map(read_tfrecord, num_parallel_calls=AUTO)
    # returns a dataset of (image, label) pairs
    return dataset


def data_augment(image, label):
    """Apply a random horizontal flip to the image; the label passes through unchanged."""
    # data augmentation. Thanks to the dataset.prefetch(AUTO) statement in the next function (below),
    # this happens essentially for free on TPU. Data pipeline code is executed on the "CPU" part
    # of the TPU while the TPU itself is computing gradients.
    image = tf.image.random_flip_left_right(image)
    #image = tf.image.random_saturation(image, 0, 2)
    return image, label


def get_training_dataset():
    """Return the endlessly repeating, augmented, shuffled and batched training dataset."""
    dataset = load_dataset(TRAINING_FILENAMES)
    dataset = dataset.map(data_augment, num_parallel_calls=AUTO)
    dataset = dataset.repeat()  # the training dataset must repeat for several epochs
    dataset = dataset.shuffle(2048)  # NOTE: applied after repeat(), so the buffer can hold records from adjacent epochs
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(AUTO)  # prefetch next batch while training (autotune prefetch buffer size)
    return dataset


def get_validation_dataset(ordered=False):
    """Return the batched, cached validation dataset (a single pass, no repeat).

    Args:
        ordered: forwarded to load_dataset(); pass True when the element order
            must be reproducible across separate iterations.
    """
    dataset = load_dataset(VALIDATION_FILENAMES, ordered=ordered)
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.cache()
    dataset = dataset.prefetch(AUTO)  # prefetch next batch while training (autotune prefetch buffer size)
    return dataset


# Matches the item count embedded in a TFRecord file name: the digits between a '-' and a '.'.
# Compiled once here instead of once per filename.
_TFREC_COUNT_PATTERN = re.compile(r"-([0-9]+)\.")


def count_data_items(filenames):
    """Return the total number of data items across the given .tfrec files.

    The number of data items is written in the name of the .tfrec files,
    i.e. flowers00-230.tfrec = 230 data items.

    Args:
        filenames: iterable of TFRecord file names or paths.

    Returns:
        The sum of the per-file counts as a plain int (0 for an empty iterable).

    Raises:
        ValueError: if a file name does not contain a '-<digits>.' item count.
    """
    total = 0
    for filename in filenames:
        match = _TFREC_COUNT_PATTERN.search(filename)
        if match is None:
            raise ValueError(
                "Cannot find an item count of the form '-<digits>.' in file name: {!r}".format(filename))
        total += int(match.group(1))
    return total


NUM_TRAINING_IMAGES = count_data_items(TRAINING_FILENAMES)
NUM_VALIDATION_IMAGES = count_data_items(VALIDATION_FILENAMES)
STEPS_PER_EPOCH = NUM_TRAINING_IMAGES // BATCH_SIZE
VALIDATION_STEPS = -(-NUM_VALIDATION_IMAGES // BATCH_SIZE)  # The "-(-//)" trick rounds up instead of down :-)
print('Dataset: {} training images, {} validation images'.format(NUM_TRAINING_IMAGES, NUM_VALIDATION_IMAGES))


# MODEL

# The model is created and compiled inside the strategy scope.
with strategy.scope():
    pretrained_model = tf.keras.applications.Xception(weights='imagenet', include_top=False)
    #pretrained_model = efficientnet.tfkeras.EfficientNetB7(weights='imagenet', include_top=False, input_shape=[*IMAGE_SIZE, 3])
    pretrained_model.trainable = True  # fine-tuning

    model = tf.keras.Sequential([
        # convert image format from int [0,255] to the format expected by this model
        tf.keras.layers.Lambda(lambda data: tf.keras.applications.xception.preprocess_input(tf.cast(data, tf.float32)), input_shape=[*IMAGE_SIZE, 3]),
        pretrained_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(len(CLASSES), activation='softmax', name='flower_prob')
    ])

    model.compile(
        optimizer='adam',
        loss = 'sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy'],
        steps_per_execution=8
    )

model.summary()


# TRAINING

history = model.fit(get_training_dataset(), steps_per_epoch=STEPS_PER_EPOCH, epochs=EPOCHS,
                    validation_data=get_validation_dataset(), validation_steps=VALIDATION_STEPS,
                    callbacks=[lr_callback])


# CONFUSION MATRIX

cmdataset = get_validation_dataset(ordered=True)  # since we are splitting the dataset and iterating separately on images and labels, order matters.
# Split the ordered validation dataset into an image stream (fed to the model)
# and a label stream (the ground truth), which are iterated separately.
images_ds = cmdataset.map(lambda image, label: image)
labels_ds = cmdataset.map(lambda image, label: label).unbatch()
cm_correct_labels = next(iter(labels_ds.batch(NUM_VALIDATION_IMAGES))).numpy()  # get everything as one batch
cm_probabilities = model.predict(images_ds, steps=VALIDATION_STEPS)
cm_predictions = np.argmax(cm_probabilities, axis=-1)  # most probable class index per image
print("Correct labels: ", cm_correct_labels.shape, cm_correct_labels)
print("Predicted labels: ", cm_predictions.shape, cm_predictions)

# Passing the full label range keeps one row/column per entry of CLASSES,
# including classes that have no samples in this split.
cmat = confusion_matrix(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)))
score = f1_score(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)), average='macro')
precision = precision_score(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)), average='macro')
recall = recall_score(cm_correct_labels, cm_predictions, labels=range(len(CLASSES)), average='macro')

# Normalize each row by the number of true samples of that class. A class with
# no true samples has a row total of 0; dividing by it would fill the row with
# NaN (and emit a RuntimeWarning), so such rows are left at 0 instead.
row_totals = cmat.sum(axis=1, keepdims=True)
cmat = np.divide(cmat, row_totals, out=np.zeros(cmat.shape, dtype=float), where=row_totals != 0)  # normalized
print('f1 score: {:.3f}, precision: {:.3f}, recall: {:.3f}'.format(score, precision, recall))