# TensorFlow Model Optimization Toolkit (TMO)

In this notebook, we demonstrate how to use TMO to optimize a model for deployment. We train a small CNN on the MNIST dataset, optimize it with post-training quantization and with pruning, and then compare the size and accuracy of the optimized models with the original model.

## Set Up TMO

First, we install TensorFlow and TMO and import the required packages.

In [None]:
%pip install -q tensorflow
%pip install -q tensorflow-model-optimization

In [None]:
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow import keras
import pathlib
import numpy as np


## Post-Training Quantization

Post-training quantization converts the weights of an already-trained model from 32-bit floating point to 8-bit integer precision. It is applied to a trained float TensorFlow model at the moment we convert it to TensorFlow Lite format with the [TensorFlow Lite Converter](https://www.tensorflow.org/lite/models/convert/), so no retraining is needed.
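To make the 32-bit to 8-bit conversion concrete, here is a minimal NumPy sketch of symmetric, per-tensor int8 quantization of a single weight tensor. It is an illustration under stated assumptions, not the converter's exact scheme (TFLite may, for example, use per-channel scales for convolution kernels), and `quantize_symmetric_int8` and `dequantize` are hypothetical helpers, not TMO or TFLite APIs. It shows the 4x storage reduction and that the rounding error per weight is bounded by half the scale.

```python
import numpy as np


def quantize_symmetric_int8(weights):
    """Hypothetical helper: map float32 weights to int8 with one scale and zero point 0."""
    max_abs = float(np.max(np.abs(weights)))
    # Guard against an all-zero tensor, which would otherwise give scale = 0.
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(weights / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q, scale):
    """Hypothetical helper: recover approximate float32 values from the int8 codes."""
    return q.astype(np.float32) * scale


# Random weights shaped like the 3x3 Conv2D kernel with 12 filters used below.
rng = np.random.default_rng(0)
w = rng.normal(0.0, 0.1, size=(3, 3, 1, 12)).astype(np.float32)

q, scale = quantize_symmetric_int8(w)
w_hat = dequantize(q, scale)

print("float32 bytes:", w.nbytes)  # 432 (4 bytes per weight)
print("int8 bytes:   ", q.nbytes)  # 108 (1 byte per weight)
print("max abs error:", np.max(np.abs(w - w_hat)), "vs. scale / 2 =", scale / 2)
```

The precision lost to rounding is what can cause the small accuracy difference between the original and quantized models.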

### Load the MNIST Dataset

We load the MNIST dataset from Keras and prepare it for training.

In [None]:
# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

### Train the Model

Next, we define a CNN model and train it on the MNIST dataset.

In [None]:
# Define the model architecture
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation=tf.nn.relu),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_data=(test_images, test_labels)
)

### Convert Model to TFLite

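After training, we convert the model to [TFLite](https://www.tensorflow.org/lite/guide) format twice: once without any optimization as a baseline, and once with `tf.lite.Optimize.DEFAULT`, which applies post-training quantization during the conversion.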

In [None]:
tflite_models_dir = pathlib.Path("notebooks/Unit 9 - Model Optimization/models")
tflite_models_dir.mkdir(exist_ok=True, parents=True)
converter = tf.lite.TFLiteConverter.from_keras_model(model)

# without quantization
tflite_model = converter.convert()
tflite_model_file = tflite_models_dir/"original_model.tflite"
tflite_model_file.write_bytes(tflite_model)

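# With quantization: Optimize.DEFAULT without a representative dataset applies
# dynamic range quantization, which stores the weights as 8-bit integers.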
converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()
tflite_model_quant_file = tflite_models_dir/"quantized_model.tflite"
tflite_model_quant_file.write_bytes(tflite_quant_model)

### Check Model Size

The quantized model should be roughly 4x smaller than the original model, since its weights are stored as 8-bit integers instead of 32-bit floats.

In [None]:
%ls -lh {tflite_models_dir}
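As a rough sanity check on the `ls` output, we can count the CNN's parameters by hand from the layer definitions. This is a back-of-the-envelope estimate that assumes every parameter takes 4 bytes (float32) or 1 byte (int8) and ignores the graph and metadata stored in the `.tflite` file; some tensors (for example biases) may not be quantized, so the real ratio can fall short of 4x.

```python
# Layer shapes taken from the model defined above:
# Conv2D(12 filters, 3x3, default 'valid' padding) -> MaxPooling2D(2x2) -> Flatten -> Dense(10).
input_h, input_w, input_c = 28, 28, 1
filters, k = 12, 3
num_classes = 10

conv_params = k * k * input_c * filters + filters   # kernel + bias
conv_h, conv_w = input_h - k + 1, input_w - k + 1   # 'valid' padding -> 26 x 26
pool_h, pool_w = conv_h // 2, conv_w // 2           # 2x2 pooling -> 13 x 13
flat = pool_h * pool_w * filters                    # features entering the Dense layer
dense_params = flat * num_classes + num_classes     # kernel + bias

total = conv_params + dense_params
print("parameters:", total)                         # 20410
print("float32 weights (KB):", total * 4 / 1024)    # ~79.7
print("int8 weights (KB):", total / 1024)           # ~19.9
```

For this model, `model.count_params()` should report the same total.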

### Check Model Accuracy

Next, we evaluate both TFLite models on the test dataset with the TFLite interpreter. The quantized model's accuracy is expected to be very close to that of the original model.

In [None]:
# A helper function to evaluate the TF Lite model using "test" dataset.
def evaluate_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for test_image in test_images:
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove the batch dimension and pick the digit with the
    # highest logit (the model outputs logits, so this is also the most probable class).
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  # Compare prediction results with ground truth labels to calculate accuracy.
  accurate_count = 0
  for index in range(len(prediction_digits)):
    if prediction_digits[index] == test_labels[index]:
      accurate_count += 1
  accuracy = accurate_count * 1.0 / len(prediction_digits)

  return accuracy


interpreter = tf.lite.Interpreter(model_path=str(tflite_model_file))
interpreter.allocate_tensors()
print("Original model accuracy = ", evaluate_model(interpreter))


interpreter_quant = tf.lite.Interpreter(model_path=str(tflite_model_quant_file))
interpreter_quant.allocate_tensors()
print("Quantized model accuracy = ", evaluate_model(interpreter_quant))
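The counting loop at the end of `evaluate_model` can be replaced by a single vectorized NumPy comparison. A minimal sketch, where `accuracy_from_predictions` is a hypothetical helper name: inside `evaluate_model`, `return accuracy_from_predictions(prediction_digits, test_labels)` would replace the loop.

```python
import numpy as np


def accuracy_from_predictions(prediction_digits, labels):
    """Hypothetical helper: fraction of predictions that match the ground-truth labels."""
    predictions = np.asarray(prediction_digits)
    labels = np.asarray(labels)
    if predictions.shape != labels.shape:
        raise ValueError("predictions and labels must have the same shape")
    # Element-wise comparison gives a boolean array; its mean is the accuracy.
    return float(np.mean(predictions == labels))


# Toy example: 3 of 4 predictions are correct.
print(accuracy_from_predictions([7, 2, 1, 0], [7, 2, 1, 4]))  # 0.75
```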

## Pruning

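Pruning reduces model size by removing weights that contribute little to the model's output. Magnitude-based pruning, used here, treats the weights with the smallest absolute values as the least important and sets them to zero. We apply it while fine-tuning the already-trained model, so the remaining weights can adapt as the sparsity gradually increases. Because the zeroed weights are still stored explicitly, the size savings become visible once the model file is compressed.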
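The core idea can be sketched in a few lines of NumPy as one-shot magnitude pruning of a single weight tensor. This is an illustration only: `magnitude_prune` is a hypothetical helper, whereas `prune_low_magnitude` applies a mask like this repeatedly during training and raises the target sparsity according to a schedule.

```python
import numpy as np


def magnitude_prune(weights, sparsity):
    """Hypothetical helper: zero out the `sparsity` fraction of smallest-magnitude entries.

    Returns the pruned weights and the boolean mask of kept entries.
    Ties at the threshold may remove slightly more than the requested fraction.
    """
    magnitudes = np.abs(weights).ravel()
    k = int(round(sparsity * magnitudes.size))  # number of weights to remove
    if k == 0:
        return weights.copy(), np.ones(weights.shape, dtype=bool)
    # The k-th smallest magnitude becomes the pruning threshold.
    threshold = np.partition(magnitudes, k - 1)[k - 1]
    mask = np.abs(weights) > threshold
    return np.where(mask, weights, np.zeros_like(weights)), mask


w = np.array([[0.9, -0.05, 0.3],
              [-0.6, 0.01, -0.2]], dtype=np.float32)
pruned, mask = magnitude_prune(w, sparsity=0.5)
print(pruned)  # the three smallest-magnitude entries are now zero
```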

In [None]:
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after 2 epochs.
batch_size = 128
epochs = 2
validation_split = 0.1 # 10% of training set will be used for validation set. 

num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs

# Define model for pruning.
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                               final_sparsity=0.80,
                                                               begin_step=0,
                                                               end_step=end_step)
}

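# Prune a copy of the trained model so that pruning cannot modify the weights of
# the baseline `model`, which is evaluated again below.
model_copy = keras.models.clone_model(model)
model_copy.set_weights(model.get_weights())
model_for_pruning = prune_low_magnitude(model_copy, **pruning_params)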

# `prune_low_magnitude` requires a recompile.
model_for_pruning.compile(optimizer='adam',
                          loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                          metrics=['accuracy'])

# `summary()` prints the table itself and returns None, so it is not wrapped in print().
model_for_pruning.summary()

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
]

model_for_pruning.fit(train_images, train_labels,
                      batch_size=batch_size, epochs=epochs,
                      validation_split=validation_split,
                      callbacks=callbacks)
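The pruning cell above sets `end_step` so that the sparsity target rises from 50% to 80% over the two training epochs. The sketch below reproduces that arithmetic and the shape of the polynomial schedule, assuming the formula `final + (initial - final) * (1 - progress) ** power` with `power=3`, which we take to be the default of `tfmot.sparsity.keras.PolynomialDecay`. `polynomial_decay_sparsity` is a hypothetical helper, and the real schedule only updates the pruning mask every `frequency` steps (100 by default), so this is an approximation of the behavior rather than the library's code.

```python
import math

# Same values as in the pruning cell above (60000 = train_images.shape[0] for MNIST).
num_train_images = 60000
validation_split = 0.1
batch_size = 128
epochs = 2

steps_per_epoch = math.ceil(num_train_images * (1 - validation_split) / batch_size)
end_step = steps_per_epoch * epochs


def polynomial_decay_sparsity(step, initial_sparsity=0.50, final_sparsity=0.80,
                              begin_step=0, end_step=end_step, power=3):
    """Hypothetical helper: target sparsity at `step` under a polynomial schedule."""
    progress = (step - begin_step) / (end_step - begin_step)
    progress = min(max(progress, 0.0), 1.0)  # clamp to the pruning window
    return final_sparsity + (initial_sparsity - final_sparsity) * (1 - progress) ** power


print("steps per epoch:", steps_per_epoch)  # 422
print("end_step:", end_step)                # 844
for step in (0, steps_per_epoch, end_step):
    print(step, round(polynomial_decay_sparsity(step), 4))
```

With `power=3`, most of the increase in sparsity happens early in training, leaving the later steps for the surviving weights to recover accuracy.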

### Compare Accuracy

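The pruned model's test accuracy should be close to the baseline's. Note that the pruned model was fine-tuned for two additional epochs, so the comparison is not strictly like-for-like.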

In [None]:
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)
_, model_for_pruning_accuracy = model_for_pruning.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Pruned test accuracy:', model_for_pruning_accuracy)

### Compare Model Size

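Lastly, we remove the pruning wrappers with `strip_pruning`, convert the pruned model to TFLite, and compare its file size with the original model. Because the zeroed weights are still stored as dense tensors, the uncompressed `.tflite` file is expected to be about the same size as the original; the benefit of pruning appears once the file is compressed (for example with gzip).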

In [None]:
model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

pruning_converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
pruned_tflite_model = pruning_converter.convert()
pruned_model_file = tflite_models_dir/"pruned_model.tflite"
pruned_model_file.write_bytes(pruned_tflite_model)
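To confirm that pruning took effect, we can measure the fraction of weights that are exactly zero. A minimal sketch with a hypothetical helper `fraction_of_zeros`: applied to the kernels of `model_for_export` (the first entry of `layer.get_weights()` for its `Conv2D` and `Dense` layers), the result should be close to the 80% final sparsity. Biases are not pruned, so including them dilutes the figure, as the toy example shows.

```python
import numpy as np


def fraction_of_zeros(arrays):
    """Hypothetical helper: fraction of exactly-zero entries across a list of arrays."""
    total = sum(a.size for a in arrays)
    zeros = sum(int(np.count_nonzero(a == 0)) for a in arrays)
    return zeros / total if total else 0.0


# Toy stand-ins for a pruned kernel and its (unpruned) bias.
kernel = np.array([[0.0, 0.4], [0.0, 0.0]], dtype=np.float32)
bias = np.array([0.1, -0.2], dtype=np.float32)
print(fraction_of_zeros([kernel]))        # 0.75
print(fraction_of_zeros([kernel, bias]))  # 0.5
```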

In [None]:
%ls -lh {tflite_models_dir}
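Since the pruned weights are stored as explicit zeros, the uncompressed sizes listed by `ls` are expected to be similar; the difference shows up after compression. Below is a minimal sketch of a helper (hypothetical name `gzipped_size`, standard library only) that reports a file's gzip-compressed size; applied to `tflite_model_file` and `pruned_model_file`, the pruned model should compress noticeably better. The demo uses two synthetic files of equal size, because the real numbers depend on the training run.

```python
import gzip
import os
import shutil
import tempfile


def gzipped_size(path):
    """Hypothetical helper: size in bytes of `path` after gzip compression."""
    fd, zipped_path = tempfile.mkstemp(suffix=".gz")
    os.close(fd)
    try:
        with open(path, "rb") as src, gzip.open(zipped_path, "wb") as dst:
            shutil.copyfileobj(src, dst)
        return os.path.getsize(zipped_path)
    finally:
        os.remove(zipped_path)


# Demo: an all-zero file (an extreme stand-in for heavily pruned weights)
# versus random bytes of the same length (incompressible data).
with tempfile.TemporaryDirectory() as tmp:
    sparse_path = os.path.join(tmp, "sparse.bin")
    dense_path = os.path.join(tmp, "dense.bin")
    with open(sparse_path, "wb") as f:
        f.write(bytes(80_000))
    with open(dense_path, "wb") as f:
        f.write(os.urandom(80_000))
    sparse_gz = gzipped_size(sparse_path)
    dense_gz = gzipped_size(dense_path)

print("zeros: ", sparse_gz, "bytes gzipped")
print("random:", dense_gz, "bytes gzipped")
```

For the real models, `gzipped_size(str(pruned_model_file))` and `gzipped_size(str(tflite_model_file))` would give the comparison that `ls` alone cannot show.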