5-appinfra/modules/hpc-ai-training-infra/tensorflow_mnist_train_distributed.py (73 lines of code) (raw):

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# flake8: noqa

import os
# Configure TensorFlow's native log level before TensorFlow itself is
# imported, so the setting is already in the environment at import time.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow_datasets as tfds
import tensorflow as tf
import keras
import glob

# Fetch MNIST as (image, label) pairs together with the dataset metadata.
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train = datasets['train']
mnist_test = datasets['test']

print(
    '******************',
    'MNIST TRAINING JOB',
    '******************',
    sep='\n',
)

# MirroredStrategy replicates the model across the devices it discovers.
strategy = tf.distribute.MirroredStrategy()
print(f'Number of devices: {strategy.num_replicas_in_sync}')

num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

BUFFER_SIZE = 10000

# The global batch is split across replicas, so it grows with the replica
# count to keep the per-replica batch size constant.
BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync


def scale(image, label):
    """Cast the image to float32 and divide by 255; pass the label through."""
    return tf.cast(image, tf.float32) / 255, label


# Training input: scale, cache the scaled examples, shuffle, then batch.
train_dataset = (
    mnist_train
    .map(scale)
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
)
# Evaluation input: scale and batch only.
eval_dataset = (
    mnist_test
    .map(scale)
    .batch(BATCH_SIZE)
)

# Build and compile the model inside the strategy scope so that its
# variables are created under the distribution strategy.
with strategy.scope():
    model = keras.Sequential([
        keras.Input(shape=(28, 28, 1)),
        keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        keras.layers.MaxPooling2D(),
        keras.layers.Flatten(),
        keras.layers.Dense(64, activation='relu'),
        # Final layer emits raw logits; the loss below uses from_logits=True.
        keras.layers.Dense(10),
    ])

    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.Adam(),
        metrics=['accuracy'],
    )

# Define the checkpoint directory to store the checkpoints.
checkpoint_dir = './training_checkpoints'
# File-name pattern for the per-epoch weight files; the "{epoch}" placeholder
# is filled in by the ModelCheckpoint callback when it saves.
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}.weights.h5")


def decay(epoch):
    """Return the learning rate for the given epoch index.

    Epoch indices below 3 use 1e-3, indices 3 through 6 use 1e-4, and
    every later epoch uses 1e-5.
    """
    if epoch < 3:
        return 1e-3
    if epoch < 7:
        return 1e-4
    return 1e-5


class PrintLR(keras.callbacks.Callback):
    """Callback that prints the optimizer's learning rate after each epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # Read the rate from the model Keras attached to this callback
        # instead of a module-level global, so the callback reports on
        # whichever model it is actually used with.
        current_lr = self.model.optimizer.learning_rate.numpy()
        print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                          current_lr))


# All callbacks come from the same `keras` namespace used to build the model.
callbacks = [
    keras.callbacks.TensorBoard(log_dir='./logs'),
    keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                    save_weights_only=True),
    keras.callbacks.LearningRateScheduler(decay),
    PrintLR(),
]

EPOCHS = 12

model.fit(train_dataset, epochs=EPOCHS, callbacks=callbacks)


def find_latest_h5_checkpoint(checkpoint_dir):
    """Return the most recently written ``*.h5`` file in ``checkpoint_dir``.

    Args:
        checkpoint_dir: Directory to search (not searched recursively).

    Returns:
        The path of the ``.h5`` file with the newest modification time, or
        ``None`` when the directory contains no ``.h5`` files.
    """
    candidates = glob.glob(os.path.join(checkpoint_dir, '*.h5'))
    # Modification time tracks when the file content was last written;
    # getctime means "metadata change" on Unix and "creation" on Windows.
    return max(candidates, key=os.path.getmtime, default=None)


# Reload the newest checkpoint before evaluating, failing with a clear
# message if training produced no checkpoint files.
latest_checkpoint = find_latest_h5_checkpoint(checkpoint_dir)
if latest_checkpoint is None:
    raise FileNotFoundError(
        'No .h5 checkpoint found in {}'.format(checkpoint_dir))
model.load_weights(latest_checkpoint)

eval_loss, eval_acc = model.evaluate(eval_dataset)

print('Eval loss: {}, Eval accuracy: {}'.format(eval_loss, eval_acc))

# Export the full model; the target directory is created if missing.
path = '/data/mnist_saved_model'
os.makedirs(path, exist_ok=True)
model_file = os.path.join(path, 'mnist.keras')
model.save(model_file)

print('Training finished. Model saved')