def train()

in tf-2-data-parallelism/src/train_resnet_sdp_debug.py [0:0]
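
This excerpt shows only the train() function. The script presumably begins with module-level setup along the lines of the sketch below, following the standard SMDataParallel TensorFlow 2 initialization pattern; the module-level logger and the helpers get_train_data and get_resnet50 are assumed to be defined elsewhere in the same file:

import os
import logging

import tensorflow as tf

# SMDataParallel: import and initialize the TensorFlow 2 data parallel library
import smdistributed.dataparallel.tensorflow as dist

dist.init()

# Pin each worker process to a single GPU using its local rank
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[dist.local_rank()], 'GPU')

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)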


def train(args):
    # Load training data from the SageMaker training channel; args.train holds the
    # local path that SageMaker stages from S3 (the SM_CHANNEL_TRAIN directory)
    train_dir = args.train
    batch_size = args.batch_size
    dataset = get_train_data(train_dir, batch_size)

    model = get_resnet50(transfer_learning=True)

    loss_fn = tf.losses.SparseCategoricalCrossentropy()
    acc = tf.metrics.SparseCategoricalAccuracy(name='train_accuracy')

    # SMDataParallel: scale the learning rate by the number of workers, dist.size()
    # LR for an 8-node run: 0.000125
    # LR for a single-node run: 0.001
    opt = tf.optimizers.Adam(args.learning_rate * dist.size())

    checkpoint_dir = os.environ['SM_MODEL_DIR']
    checkpoint = tf.train.Checkpoint(model=model, optimizer=opt)
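    # NOTE: the Checkpoint object is constructed here but never written in this excerpt;
    # the trained model is exported with model.save() at the end of train()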

    @tf.function
    def training_step(images, labels, first_batch):
        with tf.GradientTape() as tape:
            probs = model(images, training=True)
            loss_value = loss_fn(labels, probs)
            acc_value = acc(labels, probs)

        # SMDataParallel: wrap tf.GradientTape with DistributedGradientTape, which
        # averages gradients across all workers when tape.gradient() is called
        tape = dist.DistributedGradientTape(tape)

        grads = tape.gradient(loss_value, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))

        if first_batch:
            # SMDataParallel: broadcast the initial model and optimizer variables from
            # rank 0 so every worker starts from identical state; done after the first
            # apply_gradients so the optimizer's slot variables have been created
            dist.broadcast_variables(model.variables, root_rank=0)
            dist.broadcast_variables(opt.variables(), root_rank=0)

        # SMDataParallel: average the loss and accuracy across workers with an
        # out-of-band allreduce, separate from the gradient allreduce done by the tape
        loss_value = dist.oob_allreduce(loss_value)
        acc_value = dist.oob_allreduce(acc_value)

        return loss_value, acc_value

    for epoch in range(args.epochs):
        for batch, (images, labels) in enumerate(dataset.take(10000 // dist.size())):
            loss_value, acc_value = training_step(images, labels, batch == 0)

            if batch % 100 == 0 and dist.rank() == 0:
                logger.info('*** Epoch %d Step #%d Accuracy: %.6f Loss: %.6f ***',
                            epoch, batch, acc_value, loss_value)

    # SMDataParallel: save the trained model only from the leader worker (rank 0)
    if dist.rank() == 0:
        model.save(os.path.join(checkpoint_dir, '1'))
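
For completeness, a hypothetical entry point is sketched below. SageMaker passes hyperparameters as command-line flags and exposes data channels through SM_* environment variables, so train() is typically driven by an argparse block like this (the argument names match how args is used above, but the defaults shown are assumptions, not taken from the actual script):

if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--train', type=str, default=os.environ.get('SM_CHANNEL_TRAIN'))
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--learning_rate', type=float, default=0.000125)
    parser.add_argument('--epochs', type=int, default=1)
    args, _ = parser.parse_known_args()

    train(args)

When the job is launched through the SageMaker Python SDK, SMDataParallel is enabled on the estimator with distribution={'smdistributed': {'dataparallel': {'enabled': True}}}, which starts one such training process per GPU across the requested instances.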