# fit() — training loop from
# sagemaker/07_tensorflow_distributed_training_data_parallelism/scripts/train.py

def fit(model, loss, opt, train_dataset, epochs, train_batch_size, max_steps=None):
    """Run a single training pass over ``train_dataset``, updating ``model`` in place.

    Args:
        model: Callable model whose output exposes a ``.logits`` attribute
            (as Hugging Face TF models do).
        loss: Loss callable invoked as ``loss(targets, logits)``.
        opt: Optimizer with ``apply_gradients``.
        train_dataset: Iterable yielding ``(inputs, targets)`` batches.
        epochs: Unused in this body; kept for interface compatibility — TODO confirm callers.
        train_batch_size: Unused in this body; kept for interface compatibility.
        max_steps: If truthy, stop after batch index reaches this value.

    Returns:
        dict: ``{"loss": <scalar loss of the last processed batch>}``.

    Raises:
        ValueError: If ``train_dataset`` yields no batches (the original code
            would have raised ``NameError`` on the undefined loss value).
    """
    loss_value = None
    pbar = tqdm(train_dataset)
    for i, batch in enumerate(pbar):
        # Unpacking is plain tuple destructuring; it needs no gradient taping.
        inputs, targets = batch
        with tf.GradientTape() as tape:
            # BUG FIX: forward the unpacked inputs, not the whole
            # (inputs, targets) tuple — `inputs` was previously unused.
            outputs = model(inputs)
            loss_value = loss(targets, outputs.logits)

        if SDP_ENABLED:
            # Wrap the tape so gradients are all-reduced across workers.
            tape = sdp.DistributedGradientTape(tape, sparse_as_dense=True)

        grads = tape.gradient(loss_value, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))

        pbar.set_description(f"Loss: {loss_value:.4f}")

        if SDP_ENABLED and i == 0:
            # After the first step, sync initial model/optimizer state from
            # rank 0 so every worker starts from identical weights.
            # (Removed the dead `first_batch = False` assignment — the local
            # was never read.)
            sdp.broadcast_variables(model.variables, root_rank=0)
            sdp.broadcast_variables(opt.variables(), root_rank=0)

        if max_steps and i >= max_steps:
            break

    if loss_value is None:
        # Empty dataset: fail loudly instead of NameError on `loss_value`.
        raise ValueError("train_dataset yielded no batches; cannot report a loss")
    return {"loss": loss_value.numpy()}