in community-content/vertex_model_garden/model_oss/keras/train.py [0:0]
def main(_) -> None:
# _INPUT_CSV_PATH and _OUTPUT_MODEL_DIR must be GCS paths of the form
# gs://<bucket_name>/<object_name>.
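# constants.GCS_URI_PREFIX is assumed to be the 'gs://' scheme prefix.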
if _INPUT_CSV_PATH.value:
if not _INPUT_CSV_PATH.value.startswith(constants.GCS_URI_PREFIX):
raise ValueError('The input csv path should be a gcs path like gs://<>')
if _OUTPUT_MODEL_DIR.value:
if not _OUTPUT_MODEL_DIR.value.startswith(constants.GCS_URI_PREFIX):
raise ValueError('The output model dir should be a gcs path like gs://<>')
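# The 'mixed_float16' policy keeps variables in float32 but runs compute in
# float16, which typically speeds up training on GPUs with float16 support.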
if _USE_MP.value:
keras.mixed_precision.set_global_policy('mixed_float16')
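# Assemble the fine-tuning trainer from the Stable Diffusion pieces: a VAE
# image encoder, the denoising diffusion model, and a noise scheduler.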
image_encoder = ImageEncoder(RESOLUTION, RESOLUTION)
diffusion_ft_trainer = Trainer(
diffusion_model=DiffusionModel(RESOLUTION, RESOLUTION, MAX_PROMPT_LENGTH),
# Remove the top layer from the encoder, which cuts off the variance and
# only returns the mean.
vae=tf.keras.Model(
image_encoder.input,
image_encoder.layers[-2].output,
),
noise_scheduler=NoiseScheduler(),
use_mixed_precision=_USE_MP.value,
)
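# AdamW applies weight decay directly to the parameters instead of folding it
# into the gradient (decoupled weight decay).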
optimizer = tf.keras.optimizers.experimental.AdamW(
learning_rate=_LEARNING_RATE.value,
weight_decay=_WEIGHT_DECAY.value,
beta_1=_BETA_1.value,
beta_2=_BETA_2.value,
epsilon=_EPSILON.value,
)
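# 'mse' matches the usual denoising objective: the trainer's train step is
# assumed to regress the predicted noise against the sampled noise.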
diffusion_ft_trainer.compile(optimizer=optimizer, loss='mse')
training_dataset = prepare_training_dataset(_INPUT_CSV_PATH.value)
# Note: gcsfuse does not work for Keras. We save the trained models locally
# first, and then copy them to GCS.
if not os.path.exists(constants.LOCAL_MODEL_DIR):
os.makedirs(constants.LOCAL_MODEL_DIR)
# The checkpoint is saved in HDF5 format (.h5).
ckpt_path = os.path.join(constants.LOCAL_MODEL_DIR, 'saved_model.h5')
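# With the default save_best_only=False, ModelCheckpoint rewrites this file at
# the end of every epoch; monitor/mode only take effect if save_best_only is
# enabled.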
ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
ckpt_path,
save_weights_only=True,
monitor='loss',
mode='min',
)
diffusion_ft_trainer.fit(
training_dataset, epochs=_EPOCHS.value, callbacks=[ckpt_callback]
)
# Copies the files in constants.LOCAL_MODEL_DIR to _OUTPUT_MODEL_DIR.
fileutils.upload_local_dir_to_gcs(
constants.LOCAL_MODEL_DIR, _OUTPUT_MODEL_DIR.value
)
return