def training_step()

in distributed_training/train_code/tf_mnist_smdp.py [0:0]


def training_step(images, labels, first_batch):
    with tf.GradientTape() as tape:
        probs = mnist_model(images, training=True)
        loss_value = loss(labels, probs)

    ########################################################
    ####### 4. SageMaker Distributed Data Parallel  ########
    #######  - Optimize AllReduce operation         ########
    ########################################################
    # SMDataParallel: Wrap tf.GradientTape with SMDataParallel's DistributedGradientTape
    tape = smdp.DistributedGradientTape(tape)

    #######################################################
    
    grads = tape.gradient(loss_value, mnist_model.trainable_variables)
    opt.apply_gradients(zip(grads, mnist_model.trainable_variables))

    ########################################################
    ####### 5. SageMaker Distributed Data Parallel   #######
    #######  - Broadcast the initial model variables ####### 
    #######    from rank 0 to ranks 1 ~ n            #######
    ########################################################
    if first_batch:
        # SMDataParallel: Broadcast model and optimizer variables
        smdp.broadcast_variables(mnist_model.variables, root_rank=0)
        smdp.broadcast_variables(opt.variables(), root_rank=0)

    #######################################################

    # SMDataParallel: all_reduce call
    loss_value = smdp.oob_allreduce(loss_value)  # Average the loss across workers
    return loss_value