def train_step()

in community-content/vertex_model_garden/model_oss/keras/train.py


  def train_step(self, inputs):
    images = inputs['images']
    encoded_text = inputs['encoded_text']
    batch_size = tf.shape(images)[0]

    with tf.GradientTape() as tape:
      # Encode the images into the latent space and sample from the
      # resulting distribution.
      latents = self.sample_from_encoder_outputs(
          self.vae(images, training=False)
      )
      # The magic number below is the latent scaling factor used by
      # Stable Diffusion; see the following guide for background:
      # https://keras.io/examples/generative/fine_tune_via_textual_inversion/
      latents = latents * 0.18215

      # Sample noise that we'll add to the latents.
      noise = tf.random.normal(tf.shape(latents))

      # Sample a random timestep for each image.
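      # (tnp refers to tensorflow.experimental.numpy.)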
      timesteps = tnp.random.randint(
          0, self.noise_scheduler.train_timesteps, (batch_size,)
      )

      # Add noise to the latents according to the noise magnitude at each
      # timestep (this is the forward diffusion process).
      noisy_latents = self.noise_scheduler.add_noise(
          tf.cast(latents, noise.dtype), noise, timesteps
      )

      # Get the target for the loss. This depends on the scheduler's
      # prediction type; for epsilon prediction it is just the sampled noise.
      target = noise  # noise_schedule.predict_epsilon == True

      # Predict the noise residual and compute loss.
      # pylint: disable=unnecessary-lambda
      timestep_embedding = tf.map_fn(
          lambda t: self.get_timestep_embedding(t), timesteps, dtype=tf.float32
      )
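      # tf.map_fn yields shape (batch, 1, dim); drop the singleton axis.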
      timestep_embedding = tf.squeeze(timestep_embedding, 1)
      model_pred = self.diffusion_model(
          [noisy_latents, timestep_embedding, encoded_text], training=True
      )
      loss = self.compiled_loss(target, model_pred)
      if self.use_mixed_precision:
        loss = self.optimizer.get_scaled_loss(loss)

    # Update parameters of the diffusion model.
    trainable_vars = self.diffusion_model.trainable_variables
    gradients = tape.gradient(loss, trainable_vars)
    if self.use_mixed_precision:
      gradients = self.optimizer.get_unscaled_gradients(gradients)
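    # Clip each gradient tensor by its own norm (tf.clip_by_norm is
    # per-tensor, not a global-norm clip).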
    gradients = [tf.clip_by_norm(g, self.max_grad_norm) for g in gradients]
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))

    return {m.name: m.result() for m in self.metrics}
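
The train_step above leans on several helpers defined elsewhere in train.py. The sketches that follow are illustrative reconstructions under stated assumptions, not the file's exact code; any name or hyperparameter that does not appear above is an assumption.

sample_from_encoder_outputs draws a latent via the reparameterization trick. A minimal sketch, assuming the VAE encoder returns the mean and log-variance of the latent Gaussian stacked along the channel axis (as in the KerasCV fine-tuning examples):

  def sample_from_encoder_outputs(self, outputs):
    # Split the encoder output into the mean and log-variance of the
    # latent distribution, stacked along the last (channel) axis.
    mean, logvar = tf.split(outputs, 2, axis=-1)
    logvar = tf.clip_by_value(logvar, -30.0, 20.0)
    std = tf.exp(0.5 * logvar)
    # Reparameterization trick: z = mean + std * eps, with eps ~ N(0, I).
    eps = tf.random.normal(tf.shape(mean), dtype=mean.dtype)
    return mean + std * eps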
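
noise_scheduler.add_noise implements the forward diffusion step x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise. A hedged sketch, assuming the scheduler precomputes per-timestep lookup tables (the table names here are assumptions):

  def add_noise(self, original_samples, noise, timesteps):
    # Forward process q(x_t | x_0): blend the clean latents with noise
    # according to the cumulative alpha schedule at each timestep.
    sqrt_alpha_prod = tf.gather(self.sqrt_alphas_cumprod, timesteps)
    sqrt_one_minus = tf.gather(self.sqrt_one_minus_alphas_cumprod, timesteps)
    # Broadcast the per-example scalars over (height, width, channels).
    sqrt_alpha_prod = tf.reshape(sqrt_alpha_prod, (-1, 1, 1, 1))
    sqrt_one_minus = tf.reshape(sqrt_one_minus, (-1, 1, 1, 1))
    return sqrt_alpha_prod * original_samples + sqrt_one_minus * noise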
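
get_timestep_embedding produces the sinusoidal embedding the diffusion U-Net consumes; its (1, dim) output is why train_step squeezes axis 1 after tf.map_fn. A sketch, with dim=320 assumed to match the Stable Diffusion v1 U-Net width:

  def get_timestep_embedding(self, timestep, dim=320, max_period=10000):
    # Standard sinusoidal embedding over log-spaced frequencies.
    half = dim // 2
    log_max_period = tf.math.log(tf.cast(max_period, tf.float32))
    freqs = tf.math.exp(
        -log_max_period * tf.range(half, dtype=tf.float32) / half
    )
    args = tf.cast(timestep, tf.float32) * freqs
    embedding = tf.concat([tf.math.cos(args), tf.math.sin(args)], axis=0)
    # Shape (1, dim), so tf.map_fn yields (batch, 1, dim) before the squeeze.
    return tf.reshape(embedding, [1, -1])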
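
The get_scaled_loss and get_unscaled_gradients calls come from tf.keras.mixed_precision.LossScaleOptimizer, which scales the loss up before backprop so float16 gradients do not underflow. A sketch of the compile-time setup under which use_mixed_precision would hold (the trainer variable, optimizer choice, and learning rate are illustrative):

  import tensorflow as tf

  # Run compute in float16 while keeping variables in float32.
  tf.keras.mixed_precision.set_global_policy('mixed_float16')

  optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
  # The wrapper provides get_scaled_loss() and get_unscaled_gradients().
  optimizer = tf.keras.mixed_precision.LossScaleOptimizer(optimizer)
  trainer.compile(optimizer=optimizer, loss='mse')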