in vision/m4/training/trainer.py [0:0]
def _do_batch(self, batch, curr_opt_step, dataset_name=None, dataset_idx=None, validation=False):
    # Use the effective max_num_images for this batch, i.e. if the largest number of images in this
    # batch is 3, pixel_values and the image attention mask are truncated accordingly.
    # Same for max_height and max_width.
    effective_max_num_images = max(batch["num_images"])
    if effective_max_num_images > 0:
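        # The unpadded size of each image is recovered from pixel_attention_mask (images sit at the
        # top-left of the padded canvas): summing the first column gives the true height, summing
        # the first row gives the true width.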
        images_heights = batch["pixel_attention_mask"][:, :, :, 0].sum(dim=-1)
        images_widths = batch["pixel_attention_mask"][:, :, 0].sum(dim=-1)
        effective_max_height = images_heights.max().int()
        effective_max_width = images_widths.max().int()
        batch["pixel_values"] = batch["pixel_values"][
            :, :effective_max_num_images, :, :effective_max_height, :effective_max_width
        ]
        batch["pixel_attention_mask"] = batch["pixel_attention_mask"][
            :, :effective_max_num_images, :effective_max_height, :effective_max_width
        ]
    else:
        # Sanity check: if there are no images, these keys should not be in `batch` in the first place.
        batch.pop("pixel_values", None)
        batch.pop("pixel_attention_mask", None)
    effective_max_num_tokens = max(batch["attention_mask"].sum(dim=-1))
    batch["input_ids"] = batch["input_ids"][:, :effective_max_num_tokens]
    if "labels" in batch:
        batch["labels"] = batch["labels"][:, :effective_max_num_tokens]
    batch["attention_mask"] = batch["attention_mask"][:, :effective_max_num_tokens]
    batch = accelerate.utils.operations.send_to_device(batch, self.accelerator.device)
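    # Per-batch statistics returned for logging: image/text token counts, padding, and the
    # image-to-text-token ratio.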
    num_images = batch["num_images"].sum()
    num_image_tokens = (batch["input_ids"] == self.image_token_id).sum()
    num_text_tokens = batch["num_text_tokens"].sum()
    total_tokens = batch["attention_mask"].numel()
    num_padding = total_tokens - batch["attention_mask"].sum()
    if "pixel_values" in batch:
        pixel_values_sum = batch["pixel_values"].sum()
    else:
        pixel_values_sum = torch.tensor(0.0, device=self.accelerator.device)
    image_to_text_ratio = torch.div(num_images, num_text_tokens)
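    # Forward pass: when the batch carries no labels, input_ids double as labels and the model's
    # returned `loss` is the mean per-token cross-entropy.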
    vl_output = self.vl_model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        pixel_values=batch["pixel_values"] if "pixel_values" in batch else None,
        pixel_attention_mask=batch["pixel_attention_mask"] if "pixel_attention_mask" in batch else None,
        labels=batch["labels"] if "labels" in batch else batch["input_ids"],
    )
    per_token_loss = vl_output.loss
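    # Validation only reports the loss and batch statistics; no backward pass or optimizer step is run.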
    if validation:
        return (
            per_token_loss,
            num_images,
            num_text_tokens,
            image_to_text_ratio,
            num_padding,
            pixel_values_sum,
        )
    else:
        if self.hparams.loss_weights_per_dataset is not None:
            per_token_loss *= self.hparams.loss_weights_per_dataset[dataset_idx]
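        # Optional PaLM-style z-loss: penalize the squared log-partition log Z = logsumexp(logits),
        # averaged over non-padding, non-image-token positions, to keep logits from drifting.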
        if self.optim_param.z_loss > 0.0:
            logits = vl_output.logits
            attention_mask = batch["attention_mask"] * (1 - (batch["input_ids"] == self.image_token_id).long())
            log_z = torch.logsumexp(logits, dim=-1) * attention_mask
            z_loss = log_z**2
            z_loss = z_loss.sum() / attention_mask.sum()
            combined_loss = per_token_loss + self.optim_param.z_loss * z_loss
        else:
            z_loss = torch.tensor(0.0, device=self.accelerator.device)
            combined_loss = per_token_loss
        sync_gradients = self.accelerator.sync_gradients
        deepspeed = hasattr(self.accelerator, "deepspeed_engine_wrapped")
        # accelerate's deepspeed `backward` also calls `engine.step`, which is a problem if we want to
        # inspect things before the step, so call only `engine.backward` here and then call
        # `engine.step` alongside `optim.step` further down.
        if deepspeed:
            self.accelerator.deepspeed_engine_wrapped.engine.backward(combined_loss)
        else:
            self.accelerator.backward(combined_loss)
        if sync_gradients:
            self.accelerator.clip_grad_norm_(self.vl_model.parameters(), self.hparams.grad_clip)
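        # Periodically log gradient/parameter statistics gathered from the deepspeed engine.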
        if (
            len(self.hparams.train_logging_grad_param_deepspeed) != 0
            and (curr_opt_step + 1) % self.hparams.train_logging_grad_param_deepspeed_opt_steps == 0
        ):
            self._log_deepspeed_training_stats(curr_opt_step=curr_opt_step)
        if deepspeed:
            self.accelerator.deepspeed_engine_wrapped.engine.step()
        self.vl_optim.step()
        # 1. `sync_gradients` is used for this dirty trick: https://github.com/huggingface/m4/pull/386
        # 2. Since we don't `accelerator.prepare` the lr scheduler, we have to skip it manually when the
        #    optimizer step was skipped (otherwise accelerate would do that for us).
        if sync_gradients and not self.accelerator.optimizer_step_was_skipped:
            self.vl_scheduler.step()
        self.vl_optim.zero_grad(set_to_none=True)
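        # Estimate the theoretical TFLOPs this batch costs per GPU, used for throughput logging.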
        tflops_per_batch_per_gpu = self.vl_model.get_model_tflops_per_batch_per_gpu(
            hparams=self.hparams,
            data_param=getattr(self.data_param, dataset_name),
            tokenizer=self.tokenizer,
            max_num_images=effective_max_num_images,
            max_num_tokens=effective_max_num_tokens,
        ).to(self.accelerator.device)
        # Reset batch
        return (
            per_token_loss,
            z_loss,
            num_images,
            num_image_tokens,
            num_text_tokens,
            image_to_text_ratio,
            num_padding,
            pixel_values_sum,
            tflops_per_batch_per_gpu,
        )