in community-content/vertex_model_garden/model_oss/keras/train.py [0:0]
def train_step(self, inputs):
  """Runs one optimization step of the diffusion model on a batch.

  The images are passed through `self.vae` (called with `training=False`) and
  sampled into latents, noise is added at a random timestep per image, and
  `self.diffusion_model` is trained to predict that noise. Only
  `self.diffusion_model.trainable_variables` receive gradient updates.

  Args:
    inputs: Mapping with an `'images'` entry (fed to `self.vae`) and an
      `'encoded_text'` entry (passed to `self.diffusion_model` together with
      the noisy latents and the timestep embedding).

  Returns:
    A dict mapping the name of each metric in `self.metrics` to its current
    result.
  """
  images = inputs['images']
  encoded_text = inputs['encoded_text']
  batch_size = tf.shape(images)[0]

  with tf.GradientTape() as tape:
    # Project image into the latent space and sample from it.
    latents = self.sample_from_encoder_outputs(
        self.vae(images, training=False)
    )
    # Know more about the magic number here:
    # https://keras.io/examples/generative/fine_tune_via_textual_inversion/
    latents = latents * 0.18215

    # Sample noise that we'll add to the latents.
    noise = tf.random.normal(tf.shape(latents))

    # Sample a random timestep for each image.
    timesteps = tnp.random.randint(
        0, self.noise_scheduler.train_timesteps, (batch_size,)
    )

    # Add noise to the latents according to the noise magnitude at each
    # timestep (this is the forward diffusion process).
    noisy_latents = self.noise_scheduler.add_noise(
        tf.cast(latents, noise.dtype), noise, timesteps
    )

    # Get the target for loss depending on the prediction type
    # just the sampled noise for now.
    target = noise  # noise_schedule.predict_epsilon == True

    # Predict the noise residual and compute loss.
    # `fn_output_signature` replaces the deprecated `dtype` argument of
    # `tf.map_fn`; the two are equivalent.
    # pylint: disable=unnecessary-lambda
    timestep_embedding = tf.map_fn(
        lambda t: self.get_timestep_embedding(t),
        timesteps,
        fn_output_signature=tf.float32,
    )
    # Drop the size-1 axis at position 1 of the stacked embeddings.
    timestep_embedding = tf.squeeze(timestep_embedding, 1)
    model_pred = self.diffusion_model(
        [noisy_latents, timestep_embedding, encoded_text], training=True
    )
    loss = self.compiled_loss(target, model_pred)
    if self.use_mixed_precision:
      # Scale the loss so the gradients can be unscaled after the backward
      # pass below.
      loss = self.optimizer.get_scaled_loss(loss)

  # Update parameters of the diffusion model.
  trainable_vars = self.diffusion_model.trainable_variables
  gradients = tape.gradient(loss, trainable_vars)
  if self.use_mixed_precision:
    gradients = self.optimizer.get_unscaled_gradients(gradients)
  # `tape.gradient` yields None for variables that did not contribute to the
  # loss, and `tf.clip_by_norm` cannot convert None to a tensor, so such
  # entries are passed through unclipped.
  # NOTE(review): Keras optimizers are expected to skip None gradients in
  # `apply_gradients` -- confirm for the optimizer in use.
  gradients = [
      g if g is None else tf.clip_by_norm(g, self.max_grad_norm)
      for g in gradients
  ]
  self.optimizer.apply_gradients(zip(gradients, trainable_vars))

  return {m.name: m.result() for m in self.metrics}