in examples/keras_spark3_rossmann.py [0:0]
def train_fn(model_bytes):
# Make sure pyarrow is referenced before anything else to avoid a segfault due to a conflict
# with the TensorFlow libraries. The bare `pa` reference below ensures the module is loaded
# before functions like `deserialize_model`, which are implemented at the top level.
# See https://jira.apache.org/jira/browse/ARROW-3346
pa
import atexit
import horovod.tensorflow.keras as hvd
from horovod.spark.task import get_available_devices
import os
from petastorm import make_batch_reader
from petastorm.tf_utils import make_petastorm_dataset
import tempfile
import tensorflow as tf
import tensorflow.keras.backend as K
import shutil
# Horovod: initialize Horovod inside the trainer.
hvd.init()
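# hvd.init() establishes communication among the Spark task processes running this
# function; hvd.rank() and hvd.size() are available from this point on.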
# Horovod: pin this process to a single GPU (one GPU per process), if GPUs are available.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = get_available_devices()[0]
K.set_session(tf.Session(config=config))
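# get_available_devices() returns the GPU indices that Spark's resource scheduler
# assigned to this task; with one GPU per process, the first entry is the device to pin.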
# Horovod: restore the model from its serialized bytes; deserialize_model uses hvd.load_model under the hood.
model = deserialize_model(model_bytes, hvd.load_model)
# Horovod: adjust learning rate based on number of processes.
scaled_lr = K.get_value(model.optimizer.lr) * hvd.size()
K.set_value(model.optimizer.lr, scaled_lr)
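# Example of the linear scaling rule applied above: a base learning rate of 0.001
# trained on 8 workers becomes an effective learning rate of 0.008; the warmup
# callback below ramps up to this value over the first epochs.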
# Horovod: print summary logs on the first worker.
verbose = 2 if hvd.rank() == 0 else 0
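# Keras verbose=2 prints one log line per epoch; all other ranks use verbose=0 so
# the same progress output is not repeated hvd.size() times.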
callbacks = [
# Horovod: broadcast initial variable states from rank 0 to all other processes.
# This is necessary to ensure consistent initialization of all workers when
# training is started with random weights or restored from a checkpoint.
hvd.callbacks.BroadcastGlobalVariablesCallback(root_rank=0),
# Horovod: average metrics among workers at the end of every epoch.
#
# Note: This callback must be in the list before the ReduceLROnPlateau,
# TensorBoard, or other metrics-based callbacks.
hvd.callbacks.MetricAverageCallback(),
# Horovod: using `lr = 1.0 * hvd.size()` from the very beginning leads to worse final
# accuracy. Scale the learning rate `lr = 1.0` ---> `lr = 1.0 * hvd.size()` during
# the first five epochs. See https://arxiv.org/abs/1706.02677 for details.
hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=5, initial_lr=scaled_lr, verbose=verbose),
# Reduce the learning rate if the monitored metric has not improved for 10 epochs,
# and stop training if it has not improved for 20 epochs.
tf.keras.callbacks.ReduceLROnPlateau(monitor='val_exp_rmspe', patience=10, verbose=verbose),
tf.keras.callbacks.EarlyStopping(monitor='val_exp_rmspe', mode='min', patience=20, verbose=verbose),
tf.keras.callbacks.TerminateOnNaN()
]
# Model checkpoint location.
ckpt_dir = tempfile.mkdtemp()
ckpt_file = os.path.join(ckpt_dir, 'checkpoint.h5')
atexit.register(lambda: shutil.rmtree(ckpt_dir))
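# atexit removes the temporary checkpoint directory when this worker process exits,
# so the best-model file only needs to live long enough to be read back below.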
# Horovod: save checkpoints only on the first worker to prevent other workers from corrupting them.
if hvd.rank() == 0:
callbacks.append(tf.keras.callbacks.ModelCheckpoint(ckpt_file, monitor='val_exp_rmspe', mode='min',
save_best_only=True))
# Make Petastorm readers.
with make_batch_reader('%s/train_df.parquet' % args.data_dir, num_epochs=None,
cur_shard=hvd.rank(), shard_count=hvd.size(),
hdfs_driver=PETASTORM_HDFS_DRIVER) as train_reader:
with make_batch_reader('%s/val_df.parquet' % args.data_dir, num_epochs=None,
cur_shard=hvd.rank(), shard_count=hvd.size(),
hdfs_driver=PETASTORM_HDFS_DRIVER) as val_reader:
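# cur_shard / shard_count give every Horovod rank a distinct 1/size() shard of the
# Parquet data, so workers never train on overlapping rows; num_epochs=None makes the
# readers loop indefinitely, leaving epoch boundaries to the step counts passed to fit().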
# Convert readers to tf.data.Dataset.
train_ds = make_petastorm_dataset(train_reader) \
.apply(tf.data.experimental.unbatch()) \
.shuffle(int(train_rows / hvd.size())) \
.batch(args.batch_size) \
.map(lambda x: (tuple(getattr(x, col) for col in all_cols), tf.log(x.Sales)))
val_ds = make_petastorm_dataset(val_reader) \
.apply(tf.data.experimental.unbatch()) \
.batch(args.batch_size) \
.map(lambda x: (tuple(getattr(x, col) for col in all_cols), tf.log(x.Sales)))
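# Each element produced by Petastorm is a whole Parquet row-group batch: unbatch()
# splits it into individual rows, which are shuffled (training data only), re-batched
# to args.batch_size, and mapped to (feature tuple, log(Sales)) pairs, since the model
# is trained to predict log-sales.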
history = model.fit(train_ds,
validation_data=val_ds,
steps_per_epoch=int(train_rows / args.batch_size / hvd.size()),
validation_steps=int(val_rows / args.batch_size / hvd.size()),
callbacks=callbacks,
verbose=verbose,
epochs=args.epochs)
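# steps_per_epoch and validation_steps divide the global row counts across hvd.size()
# workers, so every rank runs the same number of steps per epoch and the allreduce
# operations stay in sync.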
# The tf.data Dataset API currently prints a wall of errors on process termination.
# Keeping a global reference to the model works around this and ensures clean termination.
# Tracked in https://github.com/tensorflow/tensorflow/issues/24570
globals()['_DATASET_FINALIZATION_HACK'] = model
if hvd.rank() == 0:
with open(ckpt_file, 'rb') as f:
return history.history, f.read()
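# Illustrative driver-side usage (a sketch, not part of this function): the serialized
# Keras model is passed to horovod.spark.run, which executes one copy of train_fn per
# Horovod process and returns a list of per-rank results; rank 0's entry holds the
# training history and the best checkpoint bytes. `args.num_proc` and the origin of
# `model_bytes` are assumptions based on the surrounding example.
#
#   import horovod.spark
#   best_model_history, best_model_bytes = \
#       horovod.spark.run(train_fn, args=(model_bytes,), num_proc=args.num_proc,
#                         verbose=2)[0]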