in tensorflow_examples/models/densenet/distributed_train.py [0:0]
def custom_loop(self, train_dist_dataset, test_dist_dataset,
                strategy):
    """Custom training and testing loop.

    Runs ``self.epochs`` epochs. Each epoch does the following:

    * sets the optimizer learning rate from ``self.decay(epoch)``;
    * runs ``self.train_step`` on every batch of ``train_dist_dataset`` via
      ``strategy.run``;
    * runs ``self.test_step`` on every batch of ``test_dist_dataset``;
    * prints that epoch's losses and accuracies.

    When ``self.enable_function`` is true, the per-epoch loops are wrapped in
    ``tf.function``.

    Args:
      train_dist_dataset: Training dataset created using strategy.
      test_dist_dataset: Testing dataset created using strategy.
      strategy: Distribution strategy.

    Returns:
      train_loss, train_accuracy, test_loss, test_accuracy for the final
      epoch. Each loss is that epoch's loss total divided by its batch count.
      The accuracies are the metric results converted with ``.numpy()``.

    Raises:
      ValueError: If ``self.epochs`` is less than 1. No epoch would run, so
        there would be no results to return.
    """
    if self.epochs < 1:
        raise ValueError(
            'epochs must be at least 1, got {}'.format(self.epochs))

    def distributed_train_epoch(ds):
        """Runs one training pass; returns (summed loss, batch count)."""
        total_loss = 0.0
        num_train_batches = 0.0
        for one_batch in ds:
            per_replica_loss = strategy.run(self.train_step, args=(one_batch,))
            # Collapse this batch's per-replica losses into a single value.
            total_loss += strategy.reduce(
                tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)
            num_train_batches += 1
        return total_loss, num_train_batches

    def distributed_test_epoch(ds):
        """Runs one evaluation pass; returns (test loss result, batch count)."""
        num_test_batches = 0.0
        for one_batch in ds:
            # test_step's return value is unused; the loss is read from
            # self.test_loss_metric after the loop.
            strategy.run(self.test_step, args=(one_batch,))
            num_test_batches += 1
        return self.test_loss_metric.result(), num_test_batches

    if self.enable_function:
        distributed_train_epoch = tf.function(distributed_train_epoch)
        distributed_test_epoch = tf.function(distributed_test_epoch)

    template = ('Epoch: {}, Train Loss: {}, Train Accuracy: {}, '
                'Test Loss: {}, Test Accuracy: {}')

    for epoch in range(self.epochs):
        # Start every epoch from clean metric state so each reported value
        # covers only this epoch. test_loss_metric is read once per epoch,
        # so it needs a reset too; otherwise it would carry over results
        # from earlier epochs. No reset happens after the final epoch, so
        # the metric results remain readable for the return value below.
        self.train_acc_metric.reset_states()
        self.test_acc_metric.reset_states()
        self.test_loss_metric.reset_states()

        self.optimizer.learning_rate = self.decay(epoch)

        train_total_loss, num_train_batches = distributed_train_epoch(
            train_dist_dataset)
        test_total_loss, num_test_batches = distributed_test_epoch(
            test_dist_dataset)

        # NOTE(review): dividing test_loss_metric's result by the batch count
        # only gives a per-batch average if that metric accumulates a sum
        # (e.g. tf.keras.metrics.Sum). Confirm against its definition.
        train_loss = train_total_loss / num_train_batches
        test_loss = test_total_loss / num_test_batches

        print(
            template.format(epoch,
                            train_loss,
                            self.train_acc_metric.result(),
                            test_loss,
                            self.test_acc_metric.result()))

    return (train_loss,
            self.train_acc_metric.result().numpy(),
            test_loss,
            self.test_acc_metric.result().numpy())