def custom_loop()

in tensorflow_examples/models/densenet/distributed_train.py [0:0]


  def custom_loop(self, train_dist_dataset, test_dist_dataset,
                  strategy):
    """Custom training and testing loop.

    Runs `self.epochs` epochs. Each epoch sets the optimizer learning rate
    from `self.decay(epoch)`, makes one pass over the training dataset and one
    pass over the testing dataset, and prints the per-batch average losses
    together with the current accuracy metric results.

    Args:
      train_dist_dataset: Training dataset created using strategy.
      test_dist_dataset: Testing dataset created using strategy.
      strategy: Distribution strategy.

    Returns:
      train_loss, train_accuracy, test_loss, test_accuracy from the final
      epoch. Each loss is the epoch total divided by the number of batches
      seen in that epoch; each accuracy is the metric result converted with
      `.numpy()`.

    Note:
      `self.epochs` must be at least 1: the returned values are only assigned
      inside the epoch loop.
    """

    def distributed_train_epoch(ds):
      """Runs one training pass; returns (summed loss, batch count)."""
      total_loss = 0.0
      num_train_batches = 0.0
      for one_batch in ds:
        per_replica_loss = strategy.run(self.train_step, args=(one_batch,))
        # Combine the per-replica losses into a single value for this batch.
        total_loss += strategy.reduce(
            tf.distribute.ReduceOp.SUM, per_replica_loss, axis=None)
        num_train_batches += 1
      return total_loss, num_train_batches

    def distributed_test_epoch(ds):
      """Runs one testing pass; returns (test loss metric result, batches)."""
      num_test_batches = 0.0
      for one_batch in ds:
        # The return value of test_step is not used here; the epoch loss is
        # read back from self.test_loss_metric once the pass is finished.
        strategy.run(self.test_step, args=(one_batch,))
        num_test_batches += 1
      return self.test_loss_metric.result(), num_test_batches

    if self.enable_function:
      distributed_train_epoch = tf.function(distributed_train_epoch)
      distributed_test_epoch = tf.function(distributed_test_epoch)

    for epoch in range(self.epochs):
      self.optimizer.learning_rate = self.decay(epoch)

      train_total_loss, num_train_batches = distributed_train_epoch(
          train_dist_dataset)
      test_total_loss, num_test_batches = distributed_test_epoch(
          test_dist_dataset)

      template = ('Epoch: {}, Train Loss: {}, Train Accuracy: {}, '
                  'Test Loss: {}, Test Accuracy: {}')

      print(
          template.format(epoch,
                          train_total_loss / num_train_batches,
                          self.train_acc_metric.result(),
                          test_total_loss / num_test_batches,
                          self.test_acc_metric.result()))

      # Clear every per-epoch metric before the next epoch. The test loss
      # metric must be cleared as well: its result is read as this epoch's
      # total, so leftover state would be counted again in the next epoch's
      # reported test loss. State from the final epoch is left in place so
      # the accuracy results can still be read after the loop.
      if epoch != self.epochs - 1:
        self.train_acc_metric.reset_states()
        self.test_loss_metric.reset_states()
        self.test_acc_metric.reset_states()

    return (train_total_loss / num_train_batches,
            self.train_acc_metric.result().numpy(),
            test_total_loss / num_test_batches,
            self.test_acc_metric.result().numpy())