def main()

in community-content/tf_keras_text_classification_distributed_single_worker_gpus_with_gcloud_local_run_and_vertex_sdk/trainer/task.py
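
For context, main() relies on module-level imports and helpers defined elsewhere in task.py: parse_args, download_data, load_dataset, build_model, train, the distribution_utils module, and the constants VOCAB_SIZE and MAX_SEQUENCE_LENGTH. A minimal sketch of the imports the function body assumes (the exact TextVectorization import path varies by TensorFlow version):

import os

import tensorflow as tf
from tensorflow.keras import layers, losses
from tensorflow.keras.layers import TextVectorization  # tf.keras.layers.experimental.preprocessing in TF < 2.6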


def main():
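  """Trains the text classifier across the available GPUs and saves the model."""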

  args = parse_args()

  # Local scratch directories: data is always staged locally, while the model,
  # checkpoint, and TensorBoard directories are fallbacks for unset flags.
  local_data_dir = './tmp/data'

  local_model_dir = './tmp/model'
  local_checkpoint_dir = './tmp/checkpoints'
  local_tensorboard_log_dir = './tmp/logs'

  model_dir = args.model_dir or local_model_dir
  tensorboard_log_dir = args.tensorboard_log_dir or local_tensorboard_log_dir
  checkpoint_dir = args.checkpoint_dir or local_checkpoint_dir

  gs_prefix = 'gs://'
  gcsfuse_prefix = '/gcs/'
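  # Vertex AI training mounts Cloud Storage buckets via Cloud Storage FUSE, so a
  # gs://bucket path can be read and written as /gcs/bucket on the local filesystem.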
  if model_dir and model_dir.startswith(gs_prefix):
    model_dir = model_dir.replace(gs_prefix, gcsfuse_prefix)
  if tensorboard_log_dir and tensorboard_log_dir.startswith(gs_prefix):
    tensorboard_log_dir = tensorboard_log_dir.replace(gs_prefix, gcsfuse_prefix)
  if checkpoint_dir and checkpoint_dir.startswith(gs_prefix):
    checkpoint_dir = checkpoint_dir.replace(gs_prefix, gcsfuse_prefix)

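  # Target classes are the programming-language tags assigned to each question.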
  class_names = ['csharp', 'java', 'javascript', 'python']
  class_indices = dict(zip(class_names, range(len(class_names))))
  num_classes = len(class_names)
  print(f' class names: {class_names}')
  print(f' class indices: {class_indices}')
  print(f' num classes: {num_classes}')

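  # get_distribution_mirrored_strategy (defined elsewhere in the trainer) returns a
  # tf.distribute.MirroredStrategy: one model replica per GPU, with gradients
  # all-reduced across replicas each step.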
  strategy = distribution_utils.get_distribution_mirrored_strategy(
      num_gpus=args.num_gpus)
  print('Number of devices: {}'.format(strategy.num_replicas_in_sync))

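  # Each replica processes args.batch_size examples per step, so the effective
  # (global) batch size scales with the number of replicas in sync.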
  global_batch_size = args.batch_size * strategy.num_replicas_in_sync
  print(f'Global batch size: {global_batch_size}')

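  # Stage the raw text data locally, then build batched train/validation/test splits.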
  dataset_dir = download_data(local_data_dir)
  raw_train_ds, raw_val_ds, raw_test_ds = load_dataset(dataset_dir, global_batch_size)

  vectorize_layer = TextVectorization(
      max_tokens=VOCAB_SIZE,
      output_mode='int',
      output_sequence_length=MAX_SEQUENCE_LENGTH)

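  # adapt() builds the vocabulary from the training text only, so nothing leaks
  # from the validation or test splits.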
  train_text = raw_train_ds.map(lambda text, labels: text)
  vectorize_layer.adapt(train_text)
  print('The vectorize_layer is adapted')

  def vectorize_text(text, label):
    # Add a batch dimension so the vectorizer receives a rank-1 string tensor.
    text = tf.expand_dims(text, -1)
    return vectorize_layer(text), label

  # Retrieve one batch of questions and labels from the raw training set
  text_batch, label_batch = next(iter(raw_train_ds))
  first_question, first_label = text_batch[0], label_batch[0]
  print("Question", first_question)
  print("Label", first_label)
  print("Vectorized question:", vectorize_text(first_question, first_label)[0])

  train_ds = raw_train_ds.map(vectorize_text)
  val_ds = raw_val_ds.map(vectorize_text)
  test_ds = raw_test_ds.map(vectorize_text)

  AUTOTUNE = tf.data.AUTOTUNE

  def configure_dataset(dataset):
    # Cache vectorized batches after the first epoch and overlap input
    # preprocessing with training via prefetch.
    return dataset.cache().prefetch(buffer_size=AUTOTUNE)

  train_ds = configure_dataset(train_ds)
  val_ds = configure_dataset(val_ds)
  test_ds = configure_dataset(test_ds)

  print('Build model')
  # The model outputs raw logits (export_model below appends a softmax),
  # hence from_logits=True.
  loss = losses.SparseCategoricalCrossentropy(from_logits=True)
  optimizer = 'adam'
  metrics = ['accuracy']

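  # Model variables must be created inside strategy.scope() so that they are
  # mirrored across all replicas.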
  with strategy.scope():
    model = build_model(
        num_classes=num_classes,
        loss=loss,
        optimizer=optimizer,
        metrics=metrics,
    )

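  # train() (defined elsewhere in the trainer) fits the model, writing TensorBoard
  # logs and checkpoints to the resolved directories.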
  train(
      model=model,
      train_dataset=train_ds,
      validation_dataset=val_ds,
      epochs=args.epochs,
      tensorboard_log_dir=tensorboard_log_dir,
      checkpoint_dir=checkpoint_dir
  )

  test_loss, test_accuracy = model.evaluate(test_ds)
  # 'Int' refers to the vectorizer's output_mode='int' configured above.
  print("Int model accuracy: {:2.2%}".format(test_accuracy))

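  # Wrap the trained model with the vectorizer and a softmax so the end-to-end
  # model accepts raw strings and returns class probabilities.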
  with strategy.scope():
    export_model = tf.keras.Sequential(
        [vectorize_layer, model,
         layers.Activation('softmax')])

    export_model.compile(
        loss=losses.SparseCategoricalCrossentropy(from_logits=False),
        optimizer='adam',
        metrics=['accuracy'])

  # Evaluate the end-to-end model on the raw (string) test split.
  loss, accuracy = export_model.evaluate(raw_test_ds)
  print("Accuracy: {:2.2%}".format(accuracy))

  # Note that the logits-only model (without the vectorizer) is what gets saved;
  # export_model is evaluated above but not exported.
  model_path = os.path.join(model_dir, str(args.model_version))
  model.save(model_path)
  print(f'Model version {args.model_version} is saved to {model_dir}')

  print(f'Tensorboard logs are saved to: {tensorboard_log_dir}')

  print(f'Checkpoints are saved to: {checkpoint_dir}')

  return
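
The excerpt ends here; task.py presumably invokes main() under the usual entry-point guard (not shown in this excerpt):

if __name__ == '__main__':
  main()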