def main()

in community-content/tf_keras_image_classification_distributed_multi_worker_with_vertex_sdk/trainer/task.py [0:0]


def _to_gcsfuse_path(path, gs_prefix='gs://', gcsfuse_prefix='/gcs/'):
  """Rewrite a Cloud Storage URI to its Cloud Storage FUSE mount path.

  Only a *leading* ``gs_prefix`` is rewritten; later occurrences of the
  prefix inside the path are left untouched. Values that do not start with
  ``gs_prefix`` (local paths, empty strings, ``None``) are returned unchanged.

  Args:
    path: A filesystem path or ``gs://`` URI, or a falsy value.
    gs_prefix: URI scheme prefix to look for.
    gcsfuse_prefix: Replacement prefix under which buckets are mounted.

  Returns:
    The rewritten path, or ``path`` itself if it does not start with
    ``gs_prefix``.
  """
  if path and path.startswith(gs_prefix):
    return gcsfuse_prefix + path[len(gs_prefix):]
  return path


def main():
  """Train the image classifier with a multi-worker strategy and save it.

  Steps:
    1. Resolve output directories from CLI args, falling back to ``./tmp/...``.
    2. Set up the distribution strategy.
    3. Build the (auto-sharded) input pipeline.
    4. Build the model, resuming from the latest checkpoint if one exists.
    5. Train.
    6. Save the model under ``<model_dir>/<model_version>``.
  """
  args = parse_args()

  local_model_dir = './tmp/model'
  local_tensorboard_log_dir = './tmp/logs'
  local_checkpoint_dir = './tmp/checkpoints'

  # Vertex AI mounts Cloud Storage buckets under /gcs/ via Cloud Storage FUSE,
  # so gs:// URIs are converted to local mount paths for file-based APIs.
  model_dir = _to_gcsfuse_path(args.model_dir or local_model_dir)
  tensorboard_log_dir = _to_gcsfuse_path(
      args.tensorboard_log_dir or local_tensorboard_log_dir)
  checkpoint_dir = _to_gcsfuse_path(args.checkpoint_dir or local_checkpoint_dir)

  num_worker, task_type, task_id = distribution_utils.setup()
  print(f'task_type: {task_type}, '
        f'task_id: {task_id}, '
        f'num_worker: {num_worker} \n'
        )

  strategy = distribution_utils.get_strategy(num_worker=num_worker)

  # args.batch_size is treated as the per-worker batch size; the dataset is
  # batched with the global size and the strategy splits it across workers.
  global_batch_size = args.batch_size * num_worker
  print(f'Global batch size: {global_batch_size}')

  train_ds = load_dataset(batch_size=global_batch_size)

  if num_worker > 1:
    # Shard by elements (DATA) rather than by input files. Presumably
    # load_dataset does not read from multiple files, so FILE sharding would
    # not apply; confirm against load_dataset.
    options = tf.data.Options()
    options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
    train_ds = train_ds.with_options(options)

  train_ds = train_ds.prefetch(tf.data.experimental.AUTOTUNE)

  # Variables must be created (and restored) inside the strategy scope so
  # they are mirrored across workers.
  with strategy.scope():
    model = build_model()
    # latest_checkpoint returns None when checkpoint_dir has no checkpoint,
    # in which case training starts from freshly initialized weights.
    latest_ckpt = tf.train.latest_checkpoint(checkpoint_dir)
    if latest_ckpt:
      model.load_weights(latest_ckpt)

  train(
      model=model,
      train_dataset=train_ds,
      epochs=args.epochs,
      tensorboard_log_dir=tensorboard_log_dir,
      checkpoint_dir=checkpoint_dir
  )

  model_path = os.path.join(model_dir, str(args.model_version))
  # NOTE(review): write_filepath / clean_up presumably redirect non-chief
  # workers to a temporary path and delete it afterwards, so only the chief's
  # copy remains. Confirm in distribution_utils.
  model_path = distribution_utils.write_filepath(model_path, task_type, task_id)
  model.save(model_path)
  print(f'Model version {args.model_version} is saved to {model_dir}')

  distribution_utils.clean_up(task_type, task_id, model_path)

  print(f'Tensorboard logs are saved to: {tensorboard_log_dir}')