def main()

in community-content/vertex_model_garden/model_oss/keras/train.py [0:0]


def main(_) -> None:
  """Fine-tunes the diffusion model and uploads the saved weights to GCS.

  The steps are: validate the GCS-related flag values, optionally enable
  mixed precision, build and compile the trainer, fit it on the dataset
  prepared from the input csv, checkpoint the weights into
  constants.LOCAL_MODEL_DIR, and finally copy that directory to the output
  model dir.

  Args:
    _: Unused positional argument; this function never reads it.

  Raises:
    ValueError: If the input csv path or the output model dir is set but does
      not start with constants.GCS_URI_PREFIX.
  """
  # _INPUT_CSV_PATH and _OUTPUT_MODEL_DIR should have the format as
  # gs://<bucket_name>/<object_name>. Unset (falsy) values skip the check.
  input_csv_path = _INPUT_CSV_PATH.value
  if input_csv_path and not input_csv_path.startswith(
      constants.GCS_URI_PREFIX
  ):
    raise ValueError('The input csv path should be a gcs path like gs://<>')
  output_model_dir = _OUTPUT_MODEL_DIR.value
  if output_model_dir and not output_model_dir.startswith(
      constants.GCS_URI_PREFIX
  ):
    raise ValueError('The output model dir should be a gcs path like gs://<>')

  if _USE_MP.value:
    keras.mixed_precision.set_global_policy('mixed_float16')

  image_encoder = ImageEncoder(RESOLUTION, RESOLUTION)
  diffusion_ft_trainer = Trainer(
      diffusion_model=DiffusionModel(RESOLUTION, RESOLUTION, MAX_PROMPT_LENGTH),
      # Remove the top layer from the encoder, which cuts off the variance and
      # only returns the mean.
      vae=tf.keras.Model(
          image_encoder.input,
          image_encoder.layers[-2].output,
      ),
      noise_scheduler=NoiseScheduler(),
      use_mixed_precision=_USE_MP.value,
  )

  optimizer = tf.keras.optimizers.experimental.AdamW(
      learning_rate=_LEARNING_RATE.value,
      weight_decay=_WEIGHT_DECAY.value,
      beta_1=_BETA_1.value,
      beta_2=_BETA_2.value,
      epsilon=_EPSILON.value,
  )
  diffusion_ft_trainer.compile(optimizer=optimizer, loss='mse')

  training_dataset = prepare_training_dataset(input_csv_path)

  # Note: gcsfuse does not work for Keras. We save the trained models locally
  # first, and then copy them to gcs storage.
  # exist_ok=True avoids the check-then-create race of a separate
  # os.path.exists() test.
  os.makedirs(constants.LOCAL_MODEL_DIR, exist_ok=True)
  # The default saved model is in HDF5.
  ckpt_path = os.path.join(constants.LOCAL_MODEL_DIR, 'saved_model.h5')
  # NOTE(review): save_best_only is not set, so the checkpoint file is
  # presumably overwritten every epoch and monitor/mode have no effect on
  # which weights are kept -- confirm this is intended.
  ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
      ckpt_path,
      save_weights_only=True,
      monitor='loss',
      mode='min',
  )
  diffusion_ft_trainer.fit(
      training_dataset, epochs=_EPOCHS.value, callbacks=[ckpt_callback]
  )

  # Copies the files in constants.LOCAL_MODEL_DIR to output_model_dir.
  fileutils.upload_local_dir_to_gcs(
      constants.LOCAL_MODEL_DIR, output_model_dir
  )