in distributed_training/train_code/tf_mnist_smdp.py [0:0]
def training_step(images, labels, first_batch):
with tf.GradientTape() as tape:
probs = mnist_model(images, training=True)
loss_value = loss(labels, probs)
########################################################
####### 4. SageMaker Distributed Data Parallel ########
####### - Optimize AllReduce operation ########
########################################################
# SMDataParallel: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape
tape = smdp.DistributedGradientTape(tape)
#######################################################
grads = tape.gradient(loss_value, mnist_model.trainable_variables)
opt.apply_gradients(zip(grads, mnist_model.trainable_variables))
########################################################
####### 5. SageMaker Distributed Data Parallel #######
####### - Broadcast the initial model variables #######
####### from rank 0 to ranks 1 ~ n #######
########################################################
if first_batch:
# SMDataParallel: Broadcast model and optimizer variables
smdp.broadcast_variables(mnist_model.variables, root_rank=0)
smdp.broadcast_variables(opt.variables(), root_rank=0)
#######################################################
# SMDataParallel: all_reduce call
loss_value = smdp.oob_allreduce(loss_value) # Average the loss across workers
return loss_value