def _train_batched_samples()

in optimum/habana/trl/trainer/ddpo_trainer.py


    def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
        """
        Adapted from https://github.com/huggingface/trl/blob/v0.7.8/trl/trainer/ddpo_trainer.py#L508
        - Reduce recompilations by keeping tensor indexing constant across loop iterations (roll the tensors instead of indexing with the loop variable)
        - Add `mark_step()` to support lazy mode
        """
        info = defaultdict(list)

        for _i, sample in enumerate(batched_samples):
            if self.config.train_cfg:
                # concat negative prompts to sample prompts to avoid two forward passes
                embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
            else:
                embeds = sample["prompt_embeds"]

            latents = sample["latents"]
            timesteps = sample["timesteps"]
            next_latents = sample["next_latents"]
            log_probs = sample["log_probs"]

            for j in range(self.num_train_timesteps):
                with self.accelerator.accumulate(self.sd_pipeline.unet):
                    # Reduce recompilations by keeping the slicing index constant across loop iterations
                    latent = latents[:, 0]
                    timestep = timesteps[:, 0]
                    next_latent = next_latents[:, 0]
                    log_prob = log_probs[:, 0]
                    latents = torch.roll(latents, shifts=-1, dims=1)
                    timesteps = torch.roll(timesteps, shifts=-1, dims=1)
                    next_latents = torch.roll(next_latents, shifts=-1, dims=1)
                    log_probs = torch.roll(log_probs, shifts=-1, dims=1)
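                    # After the roll, index 0 of each tensor holds the slice for the next timestep iteration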

                    loss, approx_kl, clipfrac = self.calculate_loss(
                        latent,
                        timestep,
                        next_latent,
                        log_prob,
                        sample["advantages"],
                        embeds,
                    )

                    info["approx_kl"].append(approx_kl)
                    info["clipfrac"].append(clipfrac)
                    info["loss"].append(loss)
                    self.accelerator.backward(loss)
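                    # In lazy mode, mark_step() flushes the accumulated ops so the backward graph executes on the HPU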
                    if self.use_habana:
                        self.htcore.mark_step()

                    if self.accelerator.sync_gradients:
                        trainable_layers = (
                            self.trainable_layers.parameters()
                            if not isinstance(self.trainable_layers, list)
                            else self.trainable_layers
                        )
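                        # Prefer Habana's fused gradient-norm clipping when enabled; otherwise fall back to Accelerate's clip_grad_norm_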
                        if self.gaudi_config.use_fused_clip_norm:
                            self.FusedNorm.clip_norm(trainable_layers)
                        else:
                            self.accelerator.clip_grad_norm_(
                                trainable_layers,
                                self.config.train_max_grad_norm,
                            )
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                    if self.use_habana:
                        self.htcore.mark_step()

                # Checks if the accelerator has performed an optimization step behind the scenes
                if self.accelerator.sync_gradients:
                    # log training-related stuff
                    info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                    info = self.accelerator.reduce(info, reduction="mean")
                    info.update({"epoch": epoch, "inner_epoch": inner_epoch})
                    self.accelerator.log(info, step=global_step)
                    global_step += 1
                    info = defaultdict(list)

        return global_step
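
For illustration, the recompilation-avoidance pattern used in the timestep loop can be reproduced in isolation: instead of slicing the per-timestep tensors with the changing loop index (a different indexing pattern every iteration), the tensors are rolled so that the current timestep always sits at index 0. The sketch below is a hypothetical, trainer-independent demonstration of that pattern; `iterate_timesteps` and the toy shapes are illustrative only and not part of the trainer.

    import torch

    def iterate_timesteps(latents, num_train_timesteps):
        # Yield one per-timestep slice per iteration while always reading index 0.
        # `latents` is assumed to have shape (batch, num_train_timesteps, ...).
        for _ in range(num_train_timesteps):
            current = latents[:, 0]                           # constant index: same pattern every iteration
            latents = torch.roll(latents, shifts=-1, dims=1)  # advance the next timestep into position 0
            yield current

    # Toy check: the rolled reads visit the timesteps in the same order as direct indexing
    batch = torch.arange(4 * 5 * 8, dtype=torch.float32).reshape(4, 5, 8)
    for step, latent in enumerate(iterate_timesteps(batch, num_train_timesteps=5)):
        assert torch.equal(latent, batch[:, step])

Indexing with `latents[:, j]` would return the same values, but each value of the loop index produces a different slicing pattern in the traced graph; reading index 0 after a roll keeps the computation identical across iterations, which is what avoids the recompilations in lazy mode.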