sagemaker/07_tensorflow_distributed_training_data_parallelism/scripts/train.py
def fit(model, loss, opt, train_dataset, epochs, train_batch_size, max_steps=None):
    # Assumes module-level imports in train.py: tensorflow as tf, tqdm, and
    # smdistributed.dataparallel.tensorflow as sdp (imported when SDP_ENABLED is set).
    pbar = tqdm(train_dataset)
    for i, batch in enumerate(pbar):
        with tf.GradientTape() as tape:
            inputs, targets = batch
            # Forward pass on the unpacked features.
            outputs = model(inputs, training=True)
            loss_value = loss(targets, outputs.logits)

        if SDP_ENABLED:
            # Wrap the tape so gradients are all-reduced across workers.
            tape = sdp.DistributedGradientTape(tape, sparse_as_dense=True)

        grads = tape.gradient(loss_value, model.trainable_variables)
        opt.apply_gradients(zip(grads, model.trainable_variables))

        pbar.set_description(f"Loss: {loss_value:.4f}")

        if SDP_ENABLED and i == 0:
            # After the first optimizer step, broadcast model and optimizer state
            # from rank 0 so every worker starts from identical values.
            sdp.broadcast_variables(model.variables, root_rank=0)
            sdp.broadcast_variables(opt.variables(), root_rank=0)

        if max_steps and i >= max_steps:
            break

    train_results = {"loss": loss_value.numpy()}
    return train_results
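
For context, below is a minimal sketch of how a loop like this is typically driven when SMDataParallel is enabled: sdp.init(), per-rank GPU pinning, and scaling the learning rate by sdp.size(). The model, placeholder dataset, and hyperparameter values are illustrative assumptions, not excerpts from train.py.

# Driver sketch (illustrative; names and values are assumptions, not from train.py).
import tensorflow as tf
import smdistributed.dataparallel.tensorflow as sdp
from transformers import TFAutoModelForSequenceClassification

sdp.init()  # initialize the SMDataParallel communication backend

# Pin each worker process to a single GPU based on its local rank.
gpus = tf.config.experimental.list_physical_devices("GPU")
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
if gpus:
    tf.config.experimental.set_visible_devices(gpus[sdp.local_rank()], "GPU")

model = TFAutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# Scale the learning rate by the number of workers, as is conventional for data parallelism.
opt = tf.keras.optimizers.Adam(learning_rate=3e-5 * sdp.size())

# Placeholder dataset yielding (features, labels) batches, matching the unpacking in fit();
# the real script loads its dataset elsewhere.
features = {
    "input_ids": tf.zeros((8, 16), dtype=tf.int32),
    "attention_mask": tf.ones((8, 16), dtype=tf.int32),
}
labels = tf.zeros((8,), dtype=tf.int32)
train_dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(4)

results = fit(model, loss, opt, train_dataset, epochs=1, train_batch_size=4)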