in tensorflow_examples/models/densenet/distributed_train.py [0:0]
def main(epochs,
         enable_function,
         buffer_size,
         batch_size,
         mode,
         growth_rate,
         output_classes,
         depth_of_model=None,
         num_of_blocks=None,
         num_layers_in_each_block=None,
         data_format='channels_last',
         bottleneck=True,
         compression=0.5,
         weight_decay=1e-4,
         dropout_rate=0.,
         pool_initial=False,
         include_top=True,
         train_mode='custom_loop',
         data_dir=None,
         num_gpu=1):
  """Build a DenseNet and train it under a multi-GPU MirroredStrategy.

  Creates one replica per GPU in ``[0, num_gpu)``, constructs the model and
  datasets inside the strategy scope, and runs the custom training loop.

  Args:
    epochs: Number of training epochs.
    enable_function: Whether to wrap the train/test steps in `tf.function`.
    buffer_size: Shuffle buffer size for the input pipeline.
    batch_size: Global batch size.
    mode: DenseNet construction mode passed through to `densenet.DenseNet`.
    growth_rate: DenseNet growth rate.
    output_classes: Number of output classes.
    depth_of_model: Optional overall depth of the model.
    num_of_blocks: Optional number of dense blocks.
    num_layers_in_each_block: Optional layers per dense block.
    data_format: 'channels_last' or 'channels_first'.
    bottleneck: Whether to use bottleneck layers.
    compression: Compression factor in transition layers.
    weight_decay: L2 regularization coefficient.
    dropout_rate: Dropout rate.
    pool_initial: Whether to pool after the initial convolution.
    include_top: Whether to include the classification head.
    train_mode: Only 'custom_loop' is runnable; 'keras_fit' is rejected.
    data_dir: Optional dataset directory.
    num_gpu: Number of GPUs to mirror across.

  Returns:
    The result of the trainer's custom loop.

  Raises:
    ValueError: If `train_mode` is 'keras_fit' (unsupported for subclassed
      models) or any other unrecognized value.
  """
  gpu_devices = ['/device:GPU:{}'.format(i) for i in range(num_gpu)]
  strategy = tf.distribute.MirroredStrategy(gpu_devices)
  train_dataset, test_dataset, _ = utils.create_dataset(
      buffer_size, batch_size, data_format, data_dir)

  with strategy.scope():
    model = densenet.DenseNet(
        mode, growth_rate, output_classes, depth_of_model, num_of_blocks,
        num_layers_in_each_block, data_format, bottleneck, compression,
        weight_decay, dropout_rate, pool_initial, include_top)
    trainer = Train(epochs, enable_function, model, batch_size, strategy)

    # Shard the input pipelines across replicas.
    train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset)
    test_dist_dataset = strategy.experimental_distribute_dataset(test_dataset)

    print('Training...')
    # Guard clauses: reject unsupported modes before launching training.
    if train_mode == 'keras_fit':
      raise ValueError(
          '`tf.distribute.Strategy` does not support subclassed models yet.')
    if train_mode != 'custom_loop':
      raise ValueError(
          'Please enter either "keras_fit" or "custom_loop" as the argument.')
    return trainer.custom_loop(train_dist_dataset,
                               test_dist_dataset,
                               strategy)