in tf-2-data-parallelism/src/train_resnet_sdp_debug.py [0:0]
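
# Context assumed from the rest of the file (not shown in this excerpt): module-level
# imports such as os, tensorflow as tf, and smdistributed.dataparallel.tensorflow as dist,
# plus the get_train_data / get_resnet50 helpers and a configured `logger`.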
def train(args):
    # Load data from S3
    # train_dir = os.environ.get('SM_CHANNEL_TRAIN')
    train_dir = args.train
    batch_size = args.batch_size
    dataset = get_train_data(train_dir, batch_size)

    model = get_resnet50(transfer_learning=True)
    loss_fn = tf.losses.SparseCategoricalCrossentropy()
    acc = tf.metrics.SparseCategoricalAccuracy(name='train_accuracy')

    # SMDataParallel: scale the learning rate by the number of workers, dist.size(),
    # so the per-node values below keep the effective LR at 0.001.
    # LR for 8 node run : 0.000125
    # LR for single node run : 0.001
    opt = tf.optimizers.Adam(args.learning_rate * dist.size())
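
    # SM_MODEL_DIR is the local directory whose contents SageMaker uploads to S3
    # as the model artifact when training completes.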
    checkpoint_dir = os.environ['SM_MODEL_DIR']
    checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
    @tf.function
    def training_step(images, labels, first_batch):
        with tf.GradientTape() as tape:
            probs = model(images, training=True)
            loss_value = loss_fn(labels, probs)
            acc_value = acc(labels, probs)

        # SMDataParallel: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape
        tape = dist.DistributedGradientTape(tape)
        grads = tape.gradient(loss_value, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))

        if first_batch:
            # SMDataParallel: Broadcast model and optimizer variables
            dist.broadcast_variables(model.variables, root_rank=0)
            dist.broadcast_variables(opt.variables(), root_rank=0)

        # SMDataParallel: out-of-band all_reduce to average the reported metrics across
        # workers; gradient averaging is already handled by DistributedGradientTape above.
        loss_value = dist.oob_allreduce(loss_value)
        acc_value = dist.oob_allreduce(acc_value)
        return loss_value, acc_value

    for epoch in range(args.epochs):
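        # Each worker consumes 10000 // dist.size() batches per epoch, splitting the
        # fixed step budget evenly across the cluster.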
        for batch, (images, labels) in enumerate(dataset.take(10000 // dist.size())):
            loss_value, acc_value = training_step(images, labels, batch == 0)

            if batch % 100 == 0 and dist.rank() == 0:
                logger.info(
                    '*** Epoch %d Step #%d Accuracy: %.6f Loss: %.6f ***' % (epoch, batch, acc_value, loss_value))

    # SMDataParallel: Save checkpoints only from master node.
    if dist.rank() == 0:
        model.save(os.path.join(checkpoint_dir, '1'))
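

# ---------------------------------------------------------------------------
# Minimal launch sketch (an assumption for illustration; the script's actual
# argument parsing and entry point are not shown in this excerpt). It follows
# the usual SMDataParallel TF2 setup: initialize the library, pin each process
# to its own GPU via dist.local_rank(), then call train(args). The default
# hyperparameter values below are placeholders, not taken from the source.
# ---------------------------------------------------------------------------
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAIN'))
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--learning_rate', type=float, default=0.001)
    parser.add_argument('--epochs', type=int, default=1)
    args = parser.parse_args()

    # SMDataParallel: initialize the library and pin this process to its local GPU
    dist.init()
    gpus = tf.config.experimental.list_physical_devices('GPU')
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        tf.config.experimental.set_visible_devices(gpus[dist.local_rank()], 'GPU')

    train(args)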