cli/jobs/single-step/tensorflow/mnist-distributed/src/train.py (84 lines of code) (raw):

# @title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Script adapted from: https://github.com/tensorflow/docs/blob/master/site/en/tutorials/distribute/multi_worker_with_keras.ipynb
# =========================================================================
import argparse
import json
import os

import numpy as np
import tensorflow as tf


def mnist_dataset(batch_size):
    """Return an infinitely repeating, shuffled MNIST training dataset.

    Args:
        batch_size: Number of examples per batch. Under
            MultiWorkerMirroredStrategy this is the *global* batch size.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, labels)`` batches, with
        images as float32 scaled to [0, 1] and labels as int64.
    """
    (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()
    # The `x` arrays are in uint8 and have values in the range [0, 255].
    # We need to convert them to float32 with values in the range [0, 1].
    x_train = x_train / np.float32(255)
    y_train = y_train.astype(np.int64)
    train_dataset = (
        tf.data.Dataset.from_tensor_slices((x_train, y_train))
        .shuffle(60000)
        .repeat()
        .batch(batch_size)
    )
    return train_dataset


def build_and_compile_cnn_model():
    """Build and compile a small CNN classifier for 28x28 MNIST images.

    The final Dense layer emits raw logits, so the loss is configured with
    ``from_logits=True``. Must be called inside ``strategy.scope()`` when
    training with a distribution strategy.
    """
    model = tf.keras.Sequential(
        [
            tf.keras.Input(shape=(28, 28)),
            tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
            tf.keras.layers.Conv2D(32, 3, activation="relu"),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(128, activation="relu"),
            tf.keras.layers.Dense(10),
        ]
    )
    model.compile(
        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
        metrics=["accuracy"],
    )
    return model


def _is_chief(task_type, task_id, has_chief=False):
    """Return True if this task should write the model to the real output path.

    Args:
        task_type: Task type from TF_CONFIG (e.g. ``"chief"``, ``"worker"``),
            or None when no cluster is configured.
        task_id: Index of this task within its task type.
        has_chief: Whether the cluster defines a dedicated ``chief`` task.

    A ``chief`` task is always the chief. When the cluster has no ``chief``
    task, worker 0 takes that role. ``task_type`` None is treated as a single
    worker, which is effectively the chief.
    """
    if task_type is None or task_type == "chief":
        return True
    # Worker 0 is only the chief when no dedicated chief exists; otherwise the
    # chief and worker 0 would both write to the same model directory.
    return task_type == "worker" and task_id == 0 and not has_chief


def _get_temp_dir(dirpath, task_id):
    """Create (if needed) and return ``<dirpath>/workertemp_<task_id>``."""
    base_dirpath = "workertemp_" + str(task_id)
    temp_dir = os.path.join(dirpath, base_dirpath)
    tf.io.gfile.makedirs(temp_dir)
    return temp_dir


def write_filepath(filepath, task_type, task_id, has_chief=False):
    """Return the path this task should save the model to.

    The chief writes to ``filepath`` itself. Every other task writes to a
    per-task temporary directory next to it, so that concurrent saves do not
    collide.

    Args:
        filepath: Desired final model path.
        task_type: Task type from TF_CONFIG, or None for single-worker runs.
        task_id: Index of this task within its task type.
        has_chief: Whether the cluster defines a dedicated ``chief`` task.
    """
    dirpath = os.path.dirname(filepath)
    base = os.path.basename(filepath)
    if not _is_chief(task_type, task_id, has_chief):
        dirpath = _get_temp_dir(dirpath, task_id)
    return os.path.join(dirpath, base)


def fix_tf_config():
    """Load TF_CONFIG as a dict, removing an empty ``ps`` job if present.

    An empty ``ps`` job is dropped from the cluster spec and the cleaned
    config is written back to ``os.environ["TF_CONFIG"]``. This is necessary
    for TensorFlow 2.13 and later.

    Returns:
        The parsed (and possibly cleaned) TF_CONFIG dict, or ``{}`` when
        TF_CONFIG is unset or empty (single-worker run).
    """
    raw_config = os.environ.get("TF_CONFIG")
    if not raw_config:
        # No cluster configuration: behave as a single worker.
        return {}
    tf_config = json.loads(raw_config)
    cluster = tf_config.get("cluster")
    if cluster and "ps" in cluster and len(cluster["ps"]) == 0:
        cluster.pop("ps")
        os.environ["TF_CONFIG"] = json.dumps(tf_config)
    return tf_config


def _cluster_info(tf_config):
    """Return ``(num_workers, task_type, task_id, has_chief)`` from TF_CONFIG.

    The worker count includes a ``chief`` task if one is defined, since the
    chief also trains. It is at least 1, so a missing or empty cluster spec
    behaves like single-worker training. A missing ``task`` entry yields
    ``task_type`` None and ``task_id`` 0.
    """
    cluster = tf_config.get("cluster") or {}
    task = tf_config.get("task") or {}
    has_chief = bool(cluster.get("chief"))
    num_workers = len(cluster.get("worker", [])) + len(cluster.get("chief", []))
    return max(1, num_workers), task.get("type"), task.get("index", 0), has_chief


def main():
    """Train the MNIST CNN with MultiWorkerMirroredStrategy and save it."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--steps-per-epoch", type=int, default=70)
    parser.add_argument("--per-worker-batch-size", type=int, default=64)
    parser.add_argument(
        "--model-dir",
        type=str,
        default="outputs",
        help="directory to save the model to",
    )
    args = parser.parse_args()

    # TF_CONFIG must be cleaned before the strategy is constructed, because
    # the strategy reads the cluster spec from the environment.
    tf_config = fix_tf_config()
    num_workers, task_type, task_id, has_chief = _cluster_info(tf_config)

    strategy = tf.distribute.MultiWorkerMirroredStrategy()

    # Here the batch size scales up by number of workers since
    # `tf.data.Dataset.batch` expects the global batch size.
    global_batch_size = args.per_worker_batch_size * num_workers
    multi_worker_dataset = mnist_dataset(global_batch_size)

    with strategy.scope():
        # Model building/compiling need to be within `strategy.scope()`.
        multi_worker_model = build_and_compile_cnn_model()

    # Keras' `model.fit()` trains the model with specified number of epochs and
    # number of steps per epoch.
    multi_worker_model.fit(
        multi_worker_dataset, epochs=args.epochs, steps_per_epoch=args.steps_per_epoch
    )

    # All workers call save(); the TensorFlow multi-worker tutorial this script
    # is adapted from notes that saving may involve collective ops. Only the
    # chief's copy is kept.
    write_model_path = write_filepath(args.model_dir, task_type, task_id, has_chief)
    multi_worker_model.save(write_model_path)
    if not _is_chief(task_type, task_id, has_chief):
        # Discard the non-chief's temporary copy so it does not accumulate
        # alongside the real model output.
        tf.io.gfile.rmtree(os.path.dirname(write_model_path))


if __name__ == "__main__":
    main()