in optimum/habana/transformers/trainer.py
def get_batch_samples_transformers(self, epoch_iterator, num_batches, device):
    batch_samples = []
    num_items_in_batch = None

    # Drain up to num_batches micro-batches from the iterator (the gradient accumulation window);
    # stop early if the epoch iterator is exhausted.
    for _ in range(num_batches):
        try:
            batch_samples.append(next(epoch_iterator))
        except StopIteration:
            break

    # Only count label tokens when the result will actually be consumed downstream.
    count_num_items_in_batch = (
        len(batch_samples) > 0
        and "labels" in batch_samples[0]
        and (
            # num_items_in_batch is passed to model forward
            # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/trainer.py#L3757
            self.model_accepts_loss_kwargs
            # num_items_in_batch is passed to compute_loss_func
            # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/trainer.py#L3773
            or self.compute_loss_func is not None
            # num_items_in_batch is also verified if (self.model_accepts_loss_kwargs or self.compute_loss_func)
            # https://github.com/huggingface/transformers/blob/v4.49.0/src/transformers/trainer.py#L3790
        )
    )
    if count_num_items_in_batch:
        # For now we don't support object detection
        try:
            # Count only the label positions that contribute to the loss (ignore index -100).
            num_items_in_batch = torch.cat([batch["labels"] for batch in batch_samples]).ne(-100).sum()
        except (TypeError, AttributeError, RuntimeError):
            pass

    if num_items_in_batch is not None:
        if self.args.average_tokens_across_devices:
            # Sum the per-process counts so every device normalizes the loss by the same total.
            num_items_in_batch = self.accelerator.gather(num_items_in_batch).sum()
        num_items_in_batch = num_items_in_batch.to(device)

    return batch_samples, num_items_in_batch
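
For illustration, a self-contained sketch of the token-count step above; the toy label tensors are invented for the example, and -100 is assumed to be the ignore index, following the usual Hugging Face labeling convention:

import torch

# Two accumulated micro-batches; -100 marks padded / ignored label positions (assumed convention).
batch_samples = [
    {"labels": torch.tensor([[5, 9, -100, -100]])},
    {"labels": torch.tensor([[7, -100, -100, -100]])},
]
# Same expression as in the method: count only the positions that contribute to the loss.
num_items_in_batch = torch.cat([batch["labels"] for batch in batch_samples]).ne(-100).sum()
print(num_items_in_batch)  # tensor(3)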