in optimum/habana/trl/trainer/ddpo_trainer.py [0:0]
def step(self, epoch: int, global_step: int):
    """
    Run one outer DDPO epoch: generate samples, score them, compute advantages, then train.

    Adapted from https://github.com/huggingface/trl/blob/v0.7.8/trl/trainer/ddpo_trainer.py#L234
    - Add progress bar to track training epochs
    - Convert bfloat16 rewards to float32 before turning them into numpy arrays

    Args:
        epoch: Index of the current outer epoch. Used for logging, the progress-bar label
            and the checkpoint schedule (`config.save_freq`).
        global_step: Step counter passed to the logger, the image-samples callback and
            `_train_batched_samples`.

    Returns:
        The step counter returned by the last `_train_batched_samples` call. If
        `config.train_num_inner_epochs` is 0, the `global_step` passed in is returned unchanged.

    Raises:
        ValueError: If `accelerator.sync_gradients` is False at the end of an inner epoch.
    """
    batches, prompt_image_data = self._generate_samples(
        iterations=self.config.sample_num_batches_per_epoch,
        batch_size=self.config.sample_batch_size,
    )
    # Merge the list of per-batch dicts into a single dict whose tensors are concatenated
    # along dim 0, i.e. (num_batches_per_epoch * sample_batch_size, ...).
    samples = {key: torch.cat([batch[key] for batch in batches]) for key in batches[0].keys()}

    rewards, rewards_metadata = self.compute_rewards(
        prompt_image_data, is_async=self.config.async_reward_computation
    )
    # Attach each reward and its metadata to the matching prompt/image record.
    for idx, record in enumerate(prompt_image_data):
        record.extend([rewards[idx], rewards_metadata[idx]])

    if self.image_samples_callback is not None:
        self.image_samples_callback(prompt_image_data, global_step, self.accelerator.trackers[0])

    rewards = torch.cat(rewards)
    # numpy has no bfloat16 dtype, so upcast before the host-side conversion.
    rewards = rewards.float() if rewards.dtype == torch.bfloat16 else rewards
    rewards = self.accelerator.gather(rewards).cpu().numpy()

    reward_metrics = {
        "reward": rewards,
        "epoch": epoch,
        "reward_mean": rewards.mean(),
        "reward_std": rewards.std(),
    }
    self.accelerator.log(reward_metrics, step=global_step)

    if not self.config.per_prompt_stat_tracking:
        # Global normalisation over all gathered rewards; epsilon guards a zero std.
        advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-8)
    else:
        # Gather the prompts from every process so stats are tracked per prompt text.
        gathered_prompt_ids = self.accelerator.gather(samples["prompt_ids"]).cpu().numpy()
        prompt_texts = self.sd_pipeline.tokenizer.batch_decode(gathered_prompt_ids, skip_special_tokens=True)
        advantages = self.stat_tracker.update(prompt_texts, rewards)

    # The advantages were computed over the gathered rewards. Split them into one row per
    # process and keep only this process's row.
    per_process_advantages = torch.as_tensor(advantages).reshape(self.accelerator.num_processes, -1)
    samples["advantages"] = per_process_advantages[self.accelerator.process_index].to(self.accelerator.device)
    del samples["prompt_ids"]

    device = self.accelerator.device
    total_batch_size, num_timesteps = samples["timesteps"].shape

    progress = tqdm(
        range(self.config.train_num_inner_epochs),
        desc=f"Epoch {epoch}",
        disable=not self.accelerator.is_main_process,
    )
    for inner_epoch in progress:
        # Shuffle whole samples along the batch dimension.
        batch_order = torch.randperm(total_batch_size, device=device)
        samples = {key: value[batch_order] for key, value in samples.items()}

        # Shuffle the timestep axis independently for each sample. Row i of `time_orders` is a
        # permutation of range(num_timesteps). Indexing with (row_index, time_orders) gives
        # out[i, j] = x[i, time_orders[i, j]]. The same permutation is applied to every
        # per-timestep tensor of sample i, so those tensors stay aligned with each other.
        time_orders = torch.stack(
            [torch.randperm(num_timesteps, device=device) for _ in range(total_batch_size)]
        )
        row_index = torch.arange(total_batch_size, device=device)[:, None]
        for key in ["timesteps", "latents", "next_latents", "log_probs"]:
            samples[key] = samples[key][row_index, time_orders]

        # Re-batch, because the user-defined train_batch_size may differ from sample_batch_size.
        field_names = list(samples.keys())
        chunked_fields = [
            value.reshape(-1, self.config.train_batch_size, *value.shape[1:]) for value in samples.values()
        ]
        samples_batched = [dict(zip(field_names, chunk)) for chunk in zip(*chunked_fields)]

        self.sd_pipeline.unet.train()
        global_step = self._train_batched_samples(inner_epoch, epoch, global_step, samples_batched)

        # Ensure an optimization step happened at the end of the inner epoch.
        if not self.accelerator.sync_gradients:
            raise ValueError(
                "Optimization step should have been performed by this point. Please check calculated gradient accumulation settings."
            )

    if epoch != 0 and epoch % self.config.save_freq == 0 and self.accelerator.is_main_process:
        self.accelerator.save_state()

    return global_step