in optimum/habana/trl/trainer/ddpo_trainer.py [0:0]
def _train_batched_samples(self, inner_epoch, epoch, global_step, batched_samples):
    """
    Adapted from https://github.com/huggingface/trl/blob/v0.7.8/trl/trainer/ddpo_trainer.py#L508
    - Reduce recompilations by avoiding constant variables in loops
    - Add `mark_step()` to support lazy mode
    """
    info = defaultdict(list)
    for _i, sample in enumerate(batched_samples):
        if self.config.train_cfg:
            # concat negative prompts to sample prompts to avoid two forward passes
            embeds = torch.cat([sample["negative_prompt_embeds"], sample["prompt_embeds"]])
        else:
            embeds = sample["prompt_embeds"]
        latents = sample["latents"]
        timesteps = sample["timesteps"]
        next_latents = sample["next_latents"]
        log_probs = sample["log_probs"]
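        # Per-sample tensors are stacked along dim 1 (one slice per denoising timestep),
        # so the inner loop below walks through the trajectory one timestep at a time.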
        for j in range(self.num_train_timesteps):
            with self.accelerator.accumulate(self.sd_pipeline.unet):
                # Reduce recompilations by avoiding constant variables in loops
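                # Always read column 0 and then roll the tensors along dim 1 so the next
                # timestep moves into column 0; keeping the index constant avoids a new
                # HPU graph compilation at every timestep.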
                latent = latents[:, 0]
                timestep = timesteps[:, 0]
                next_latent = next_latents[:, 0]
                log_prob = log_probs[:, 0]
                latents = torch.roll(latents, shifts=-1, dims=1)
                timesteps = torch.roll(timesteps, shifts=-1, dims=1)
                next_latents = torch.roll(next_latents, shifts=-1, dims=1)
                log_probs = torch.roll(log_probs, shifts=-1, dims=1)
                loss, approx_kl, clipfrac = self.calculate_loss(
                    latent,
                    timestep,
                    next_latent,
                    log_prob,
                    sample["advantages"],
                    embeds,
                )
                info["approx_kl"].append(approx_kl)
                info["clipfrac"].append(clipfrac)
                info["loss"].append(loss)
                self.accelerator.backward(loss)
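                # In lazy mode, mark_step() flushes the accumulated graph so the
                # backward pass actually executes on the HPU at this point.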
                if self.use_habana:
                    self.htcore.mark_step()
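                # Only clip gradients on steps where gradient accumulation has completed
                # and the gradients are synchronized across processes.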
                if self.accelerator.sync_gradients:
                    trainable_layers = (
                        self.trainable_layers.parameters()
                        if not isinstance(self.trainable_layers, list)
                        else self.trainable_layers
                    )
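                    # Prefer Habana's fused clip-norm kernel when enabled in the Gaudi
                    # config; otherwise fall back to Accelerate's gradient clipping.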
                    if self.gaudi_config.use_fused_clip_norm:
                        self.FusedNorm.clip_norm(trainable_layers)
                    else:
                        self.accelerator.clip_grad_norm_(
                            trainable_layers,
                            self.config.train_max_grad_norm,
                        )
                self.optimizer.step()
                self.optimizer.zero_grad()
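                # Flush the optimizer update in lazy mode as well.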
                if self.use_habana:
                    self.htcore.mark_step()
            # Checks if the accelerator has performed an optimization step behind the scenes
            if self.accelerator.sync_gradients:
                # log training-related stuff
                info = {k: torch.mean(torch.stack(v)) for k, v in info.items()}
                info = self.accelerator.reduce(info, reduction="mean")
                info.update({"epoch": epoch, "inner_epoch": inner_epoch})
                self.accelerator.log(info, step=global_step)
                global_step += 1
                info = defaultdict(list)
    return global_step