"""Train Keras Stable Diffusion.
Most the codes below are from
https://keras.io/examples/generative/finetune_stable_diffusion/.
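
Example invocation (the flag values here are illustrative, not defaults):

  python train.py \
    --input_csv_path=gs://<bucket_name>/<data_dir>/data.csv \
    --output_model_dir=gs://<bucket_name>/<model_dir> \
    --epochs=1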
"""
import os
from absl import app
from absl import flags
from absl import logging
import keras_cv
# pylint: disable=g-importing-member
from keras_cv.models.stable_diffusion.clip_tokenizer import SimpleTokenizer
from keras_cv.models.stable_diffusion.diffusion_model import DiffusionModel
from keras_cv.models.stable_diffusion.image_encoder import ImageEncoder
from keras_cv.models.stable_diffusion.noise_scheduler import NoiseScheduler
from keras_cv.models.stable_diffusion.text_encoder import TextEncoder
import numpy as np
# Pylint in the Docker builds could not find pandas.
# pylint: disable=import-error
import pandas as pd
import tensorflow as tf
from tensorflow import keras
import tensorflow.experimental.numpy as tnp
from util import constants
from util import fileutils

_INPUT_CSV_PATH = flags.DEFINE_string(
'input_csv_path',
None,
'The input csv path.',
required=True,
)
_USE_MP = flags.DEFINE_bool(
'use_mp',
True,
'Enable mixed-precision training if the underlying GPU has tensor cores.',
)
_EPOCHS = flags.DEFINE_integer('epochs', 1, 'The number of epochs.')
_OUTPUT_MODEL_DIR = flags.DEFINE_string(
'output_model_dir',
None,
'The output model dir.',
required=True,
)
# These hyperparameter defaults come from this Hugging Face tutorial:
# https://huggingface.co/docs/diffusers/training/text2image
_LEARNING_RATE = flags.DEFINE_float(
'learning_rate', 1e-5, 'The learning rate parameter for AdamW optimizer.'
)
_BETA_1 = flags.DEFINE_float(
'beta_1', 0.9, 'The beta_1 parameter for AdamW optimizer.'
)
_BETA_2 = flags.DEFINE_float(
'beta_2', 0.999, 'The beta_2 parameter for AdamW optimizer.'
)
_WEIGHT_DECAY = flags.DEFINE_float(
'weight_decay', 1e-2, 'The weight decay parameter for AdamW optimizer.'
)
_EPSILON = flags.DEFINE_float(
'epsilon', 1e-08, 'The epsilon parameter for AdamW optimizer.'
)

RESOLUTION = int(os.environ.get('RESOLUTION', 512))
# The padding token and maximum prompt length are specific to the text encoder.
# If you're using a different text encoder, be sure to change them accordingly.
PADDING_TOKEN = 49407
MAX_PROMPT_LENGTH = 77
AUTO = tf.data.AUTOTUNE
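# Position IDs (0..MAX_PROMPT_LENGTH - 1) passed to the text encoder together
# with the token IDs.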
POS_IDS = tf.convert_to_tensor([list(range(MAX_PROMPT_LENGTH))], dtype=tf.int32)
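
# Augmentation pipeline: center-crop to the target resolution, randomly flip,
# and rescale pixel values from [0, 255] to [-1, 1].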
augmenter = keras.Sequential(
layers=[
keras_cv.layers.CenterCrop(RESOLUTION, RESOLUTION),
keras_cv.layers.RandomFlip(),
tf.keras.layers.Rescaling(scale=1.0 / 127.5, offset=-1),
]
)

text_encoder = TextEncoder(MAX_PROMPT_LENGTH)


def process_image(image_path, tokenized_text):
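  """Reads and decodes an image, resizes it, and passes the tokens through."""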
image = tf.io.read_file(image_path)
image = tf.io.decode_png(image, 3)
image = tf.image.resize(image, (RESOLUTION, RESOLUTION))
  return image, tokenized_text


def apply_augmentation(image_batch, token_batch):
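  """Applies the augmentation pipeline to a batch of images."""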
  return augmenter(image_batch), token_batch


def run_text_encoder(image_batch, token_batch):
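  """Precomputes the text-encoder embeddings for a batch of token IDs."""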
return (
image_batch,
token_batch,
text_encoder([token_batch, POS_IDS], training=False),
  )


def prepare_dict(image_batch, token_batch, encoded_text_batch):
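  """Packs a batch into the dict format consumed by `Trainer.train_step`."""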
return {
'images': image_batch,
'tokens': token_batch,
'encoded_text': encoded_text_batch,
  }


def prepare_dataset(image_paths, tokenized_texts, batch_size=1):
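  """Builds a shuffled, batched, and augmented tf.data training pipeline."""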
dataset = tf.data.Dataset.from_tensor_slices((image_paths, tokenized_texts))
dataset = dataset.shuffle(batch_size * 10)
dataset = dataset.map(process_image, num_parallel_calls=AUTO).batch(
batch_size
)
dataset = dataset.map(apply_augmentation, num_parallel_calls=AUTO)
dataset = dataset.map(run_text_encoder, num_parallel_calls=AUTO)
dataset = dataset.map(prepare_dict, num_parallel_calls=AUTO)
  return dataset.prefetch(AUTO)


def prepare_training_dataset(dataset_csv):
"""Prepares training datasets."""
if dataset_csv.startswith(constants.GCS_URI_PREFIX):
if not os.path.exists(constants.LOCAL_DATA_DIR):
os.makedirs(constants.LOCAL_DATA_DIR)
logging.info(
'Start to download data from %s to %s.',
os.path.dirname(dataset_csv),
constants.LOCAL_DATA_DIR,
)
fileutils.download_gcs_dir_to_local(
os.path.dirname(dataset_csv), constants.LOCAL_DATA_DIR
)
data_frame = pd.read_csv(
os.path.join(constants.LOCAL_DATA_DIR, os.path.basename(dataset_csv))
)
data_frame['image_path'] = data_frame['image_path'].apply(
lambda x: os.path.join(constants.LOCAL_DATA_DIR, x)
)
else:
    # Keep the following code for experiments with
    # https://keras.io/examples/generative/finetune_stable_diffusion/.
data_path = tf.keras.utils.get_file(origin=dataset_csv, untar=True)
data_frame = pd.read_csv(os.path.join(data_path, 'data.csv'))
data_frame['image_path'] = data_frame['image_path'].apply(
lambda x: os.path.join(data_path, x)
)
# Load the tokenizer.
tokenizer = SimpleTokenizer()
  # Tokenizes a caption and pads it to MAX_PROMPT_LENGTH with PADDING_TOKEN.
def process_text(caption):
tokens = tokenizer.encode(caption)
tokens = tokens + [PADDING_TOKEN] * (MAX_PROMPT_LENGTH - len(tokens))
return np.array(tokens)
# Collate the tokenized captions into an array.
tokenized_texts = np.empty((len(data_frame), MAX_PROMPT_LENGTH))
all_captions = list(data_frame['caption'].values)
for i, caption in enumerate(all_captions):
tokenized_texts[i] = process_text(caption)
# Prepare the dataset.
training_dataset = prepare_dataset(
np.array(data_frame['image_path']), tokenized_texts, batch_size=4
)
  return training_dataset


class Trainer(tf.keras.Model):
"""The trainer for Keras Stable Diffusion."""
# Reference:
  # https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py

  def __init__(
self,
diffusion_model,
vae,
noise_scheduler,
use_mixed_precision=False,
max_grad_norm=1.0,
**kwargs,
):
super().__init__(**kwargs)
self.diffusion_model = diffusion_model
self.vae = vae
self.noise_scheduler = noise_scheduler
self.max_grad_norm = max_grad_norm
self.use_mixed_precision = use_mixed_precision
    self.vae.trainable = False

  def train_step(self, inputs):
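    """Runs one fine-tuning step of the diffusion model on a batch."""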
images = inputs['images']
encoded_text = inputs['encoded_text']
batch_size = tf.shape(images)[0]
with tf.GradientTape() as tape:
# Project image into the latent space and sample from it.
latents = self.sample_from_encoder_outputs(
self.vae(images, training=False)
)
# Know more about the magic number here:
# https://keras.io/examples/generative/fine_tune_via_textual_inversion/
latents = latents * 0.18215
# Sample noise that we'll add to the latents.
noise = tf.random.normal(tf.shape(latents))
# Sample a random timestep for each image.
timesteps = tnp.random.randint(
0, self.noise_scheduler.train_timesteps, (batch_size,)
)
# Add noise to the latents according to the noise magnitude at each
# timestep (this is the forward diffusion process).
noisy_latents = self.noise_scheduler.add_noise(
tf.cast(latents, noise.dtype), noise, timesteps
)
      # Get the target for the loss, which depends on the prediction type.
      # For now it is just the sampled noise.
target = noise # noise_schedule.predict_epsilon == True
# Predict the noise residual and compute loss.
# pylint: disable=unnecessary-lambda
timestep_embedding = tf.map_fn(
lambda t: self.get_timestep_embedding(t), timesteps, dtype=tf.float32
)
timestep_embedding = tf.squeeze(timestep_embedding, 1)
model_pred = self.diffusion_model(
[noisy_latents, timestep_embedding, encoded_text], training=True
)
loss = self.compiled_loss(target, model_pred)
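      # With mixed precision, scale the loss so that small float16 gradients
      # do not underflow; they are unscaled again before clipping.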
if self.use_mixed_precision:
loss = self.optimizer.get_scaled_loss(loss)
# Update parameters of the diffusion model.
trainable_vars = self.diffusion_model.trainable_variables
gradients = tape.gradient(loss, trainable_vars)
if self.use_mixed_precision:
gradients = self.optimizer.get_unscaled_gradients(gradients)
gradients = [tf.clip_by_norm(g, self.max_grad_norm) for g in gradients]
self.optimizer.apply_gradients(zip(gradients, trainable_vars))
    return {m.name: m.result() for m in self.metrics}

  def get_timestep_embedding(self, timestep, dim=320, max_period=10000):
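    """Builds a sinusoidal embedding for a single diffusion timestep.

    Concatenates cos(timestep * freqs) and sin(timestep * freqs), where
    freqs[i] = exp(-log(max_period) * i / half) for i in [0, half), giving
    an embedding of shape `(1, dim)`.
    """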
half = dim // 2
    log_max_period = tf.math.log(tf.cast(max_period, tf.float32))
    # Pylint in the Docker builds cannot infer the operand type of the
    # unary `-` below.
    # pylint: disable=invalid-unary-operand-type
    freqs = tf.math.exp(
        -log_max_period * tf.range(0, half, dtype=tf.float32) / half
    )
args = tf.convert_to_tensor([timestep], dtype=tf.float32) * freqs
embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], 0)
embedding = tf.reshape(embedding, [1, -1])
    return embedding

  def sample_from_encoder_outputs(self, outputs):
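    """Samples a latent via the reparameterization trick (mean + std * eps)."""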
mean, logvar = tf.split(outputs, 2, axis=-1)
logvar = tf.clip_by_value(logvar, -30.0, 20.0)
std = tf.exp(0.5 * logvar)
sample = tf.random.normal(tf.shape(mean), dtype=mean.dtype)
    return mean + std * sample

  def save_weights(
self, filepath, overwrite=True, save_format=None, options=None
):
# Overriding this method will allow us to use the `ModelCheckpoint`
# callback directly with this trainer class. In this case, it will
# only checkpoint the `diffusion_model` since that's what we're training
# during fine-tuning.
self.diffusion_model.save_weights(
filepath=filepath,
overwrite=overwrite,
save_format=save_format,
options=options,
    )


def main(_) -> None:
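  """Fine-tunes Stable Diffusion and uploads the trained weights to GCS."""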
  # _INPUT_CSV_PATH and _OUTPUT_MODEL_DIR should have the format
  # gs://<bucket_name>/<object_name>.
  if _INPUT_CSV_PATH.value:
    if not _INPUT_CSV_PATH.value.startswith(constants.GCS_URI_PREFIX):
      raise ValueError(
          'The input csv path must be a GCS path like '
          'gs://<bucket_name>/<object_name>.'
      )
  if _OUTPUT_MODEL_DIR.value:
    if not _OUTPUT_MODEL_DIR.value.startswith(constants.GCS_URI_PREFIX):
      raise ValueError(
          'The output model dir must be a GCS path like '
          'gs://<bucket_name>/<object_name>.'
      )
if _USE_MP.value:
keras.mixed_precision.set_global_policy('mixed_float16')
image_encoder = ImageEncoder(RESOLUTION, RESOLUTION)
diffusion_ft_trainer = Trainer(
diffusion_model=DiffusionModel(RESOLUTION, RESOLUTION, MAX_PROMPT_LENGTH),
      # Remove the top layer from the encoder; that layer cuts off the
      # variance and returns only the scaled mean, while sampling below
      # needs both the mean and the log-variance.
vae=tf.keras.Model(
image_encoder.input,
image_encoder.layers[-2].output,
),
noise_scheduler=NoiseScheduler(),
use_mixed_precision=_USE_MP.value,
)
optimizer = tf.keras.optimizers.experimental.AdamW(
learning_rate=_LEARNING_RATE.value,
weight_decay=_WEIGHT_DECAY.value,
beta_1=_BETA_1.value,
beta_2=_BETA_2.value,
epsilon=_EPSILON.value,
)
diffusion_ft_trainer.compile(optimizer=optimizer, loss='mse')
training_dataset = prepare_training_dataset(_INPUT_CSV_PATH.value)
  # Note: gcsfuse does not work for Keras. We save the trained models locally
  # first, and then copy them to GCS.
if not os.path.exists(constants.LOCAL_MODEL_DIR):
os.makedirs(constants.LOCAL_MODEL_DIR)
  # A `.h5` filepath makes Keras save the checkpoint weights in HDF5 format.
ckpt_path = os.path.join(constants.LOCAL_MODEL_DIR, 'saved_model.h5')
ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
ckpt_path,
save_weights_only=True,
monitor='loss',
mode='min',
)
diffusion_ft_trainer.fit(
training_dataset, epochs=_EPOCHS.value, callbacks=[ckpt_callback]
)
# Copies the files in constants.LOCAL_MODEL_DIR to output_model_dir.
fileutils.upload_local_dir_to_gcs(
constants.LOCAL_MODEL_DIR, _OUTPUT_MODEL_DIR.value
)
  return


if __name__ == '__main__':
app.run(main)