def main()

in tensorflow_examples/models/densenet/distributed_train.py [0:0]


def main(epochs,
         enable_function,
         buffer_size,
         batch_size,
         mode,
         growth_rate,
         output_classes,
         depth_of_model=None,
         num_of_blocks=None,
         num_layers_in_each_block=None,
         data_format='channels_last',
         bottleneck=True,
         compression=0.5,
         weight_decay=1e-4,
         dropout_rate=0.,
         pool_initial=False,
         include_top=True,
         train_mode='custom_loop',
         data_dir=None,
         num_gpu=1):
  """Builds a DenseNet under `tf.distribute.MirroredStrategy` and trains it.

  Args:
    epochs: Forwarded to `Train`.
    enable_function: Forwarded to `Train`.
    buffer_size: Forwarded to `utils.create_dataset`.
    batch_size: Forwarded to `utils.create_dataset` and to `Train`.
    mode: Forwarded positionally to `densenet.DenseNet`.
    growth_rate: Forwarded positionally to `densenet.DenseNet`.
    output_classes: Forwarded positionally to `densenet.DenseNet`.
    depth_of_model: Forwarded positionally to `densenet.DenseNet`.
    num_of_blocks: Forwarded positionally to `densenet.DenseNet`.
    num_layers_in_each_block: Forwarded positionally to `densenet.DenseNet`.
    data_format: Forwarded to both `utils.create_dataset` and
      `densenet.DenseNet`.
    bottleneck: Forwarded positionally to `densenet.DenseNet`.
    compression: Forwarded positionally to `densenet.DenseNet`.
    weight_decay: Forwarded positionally to `densenet.DenseNet`.
    dropout_rate: Forwarded positionally to `densenet.DenseNet`.
    pool_initial: Forwarded positionally to `densenet.DenseNet`.
    include_top: Forwarded positionally to `densenet.DenseNet`.
    train_mode: Must be 'custom_loop'. 'keras_fit' is recognised but
      rejected because the strategy does not support subclassed models here.
    data_dir: Forwarded to `utils.create_dataset`.
    num_gpu: Number of GPUs to mirror across. The devices used are
      '/device:GPU:0' through '/device:GPU:{num_gpu - 1}'. Must be >= 1.

  Returns:
    Whatever `Train.custom_loop` returns.

  Raises:
    ValueError: If `train_mode` is not 'custom_loop', or `num_gpu` < 1.
  """
  # Validate the cheap arguments before building datasets and the model,
  # so a bad flag fails immediately instead of after all the setup work.
  if train_mode == 'keras_fit':
    raise ValueError(
        '`tf.distribute.Strategy` does not support subclassed models yet.')
  if train_mode != 'custom_loop':
    raise ValueError(
        'Please enter either "keras_fit" or "custom_loop" as the argument.')
  # With num_gpu < 1 the device list below would be empty. MirroredStrategy
  # cannot mirror across zero devices, so reject this with a clear message
  # rather than failing later inside TensorFlow.
  if num_gpu < 1:
    raise ValueError('num_gpu must be at least 1, got {}.'.format(num_gpu))

  devices = ['/device:GPU:{}'.format(i) for i in range(num_gpu)]
  strategy = tf.distribute.MirroredStrategy(devices)

  # The third value returned by create_dataset is unused here.
  train_dataset, test_dataset, _ = utils.create_dataset(
      buffer_size, batch_size, data_format, data_dir)

  # Variables must be created inside the strategy scope so that they are
  # mirrored across the selected devices.
  with strategy.scope():
    model = densenet.DenseNet(
        mode, growth_rate, output_classes, depth_of_model, num_of_blocks,
        num_layers_in_each_block, data_format, bottleneck, compression,
        weight_decay, dropout_rate, pool_initial, include_top)

    trainer = Train(epochs, enable_function, model, batch_size, strategy)

    train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
    test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

    print('Training...')
    return trainer.custom_loop(train_dist_dataset,
                               test_dist_dataset,
                               strategy)