in community-content/tf_keras_image_classification_distributed_multi_worker_with_vertex_sdk/trainer/task.py [0:0]
def _to_gcsfuse_path(path, gs_prefix='gs://', gcsfuse_prefix='/gcs/'):
    """Rewrite a leading ``gs://`` in *path* to the gcsfuse ``/gcs/`` prefix.

    Only the leading prefix is rewritten. Any later ``gs://`` substring is
    left untouched; ``str.replace`` would have rewritten every occurrence.
    Empty/None paths, and paths that do not start with *gs_prefix*, are
    returned unchanged.
    """
    if path and path.startswith(gs_prefix):
        return gcsfuse_prefix + path[len(gs_prefix):]
    return path


def main():
    """Train the image classifier, optionally across several workers, and save it.

    Steps:
      1. Parse CLI args and resolve the model, TensorBoard-log and checkpoint
         directories, falling back to local ``./tmp/...`` paths. ``gs://``
         URIs are rewritten to ``/gcs/...`` paths (the gcsfuse mount prefix
         this script uses).
      2. Obtain worker count / task identity from ``distribution_utils.setup``
         and a strategy from ``distribution_utils.get_strategy``.
      3. Build the model inside ``strategy.scope()``, load weights from the
         latest checkpoint in ``checkpoint_dir`` if one exists, and train.
      4. Save the model under ``<model_dir>/<model_version>``, with the path
         adjusted per task by ``distribution_utils.write_filepath``. Then call
         ``distribution_utils.clean_up`` on that path.
    """
    args = parse_args()

    # Fall back to local directories when no output location is given, then
    # map Cloud Storage URIs onto the gcsfuse mount so they can be used as
    # ordinary filesystem paths.
    model_dir = _to_gcsfuse_path(args.model_dir or './tmp/model')
    tensorboard_log_dir = _to_gcsfuse_path(
        args.tensorboard_log_dir or './tmp/logs')
    checkpoint_dir = _to_gcsfuse_path(
        args.checkpoint_dir or './tmp/checkpoints')

    num_worker, task_type, task_id = distribution_utils.setup()
    print(f'task_type: {task_type}, '
          f'task_id: {task_id}, '
          f'num_worker: {num_worker} \n'
          )

    strategy = distribution_utils.get_strategy(num_worker=num_worker)

    # NOTE(review): this treats args.batch_size as a per-worker batch and
    # assumes one replica per worker. If a worker can host several replicas,
    # strategy.num_replicas_in_sync may be the right multiplier. Confirm.
    global_batch_size = args.batch_size * num_worker
    print(f'Global batch size: {global_batch_size}')

    train_ds = load_dataset(batch_size=global_batch_size)
    if num_worker > 1:
        # Shard by element (DATA) rather than by input file. Presumably the
        # dataset is not backed by enough files for FILE sharding. Confirm
        # against load_dataset.
        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = (
            tf.data.experimental.AutoShardPolicy.DATA)
        train_ds = train_ds.with_options(options)
    train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)

    # Variables must be created inside the strategy scope so they are
    # mirrored across replicas.
    with strategy.scope():
        model = build_model()

    # Resume from the most recent checkpoint, if any, instead of starting over.
    latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_ckpt:
        model.load_weights(latest_ckpt)

    train(
        model=model,
        train_dataset=train_ds,
        epochs=args.epochs,
        tensorboard_log_dir=tensorboard_log_dir,
        checkpoint_dir=checkpoint_dir
    )

    model_path = os.path.join(model_dir, str(args.model_version))
    # write_filepath may return a different path depending on task type/id,
    # presumably a temp location for non-chief workers. Confirm in
    # distribution_utils.
    model_path = distribution_utils.write_filepath(model_path, task_type, task_id)
    model.save(model_path)
    print(f'Model version {args.model_version} is saved to {model_dir}')
    distribution_utils.clean_up(task_type, task_id, model_path)
    print(f'Tensorboard logs are saved to: {tensorboard_log_dir}')