in optimum/graphcore/trainer.py [0:0]
def _check_dataset_can_fill_batch(self, dataset: torch.utils.data.Dataset, for_inference: bool = False) -> None:
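    """
    Warn if the provided dataset is too small to fill a single global batch,
    i.e. micro batch size * replication factor * gradient accumulation steps
    * device iterations for the selected (training or inference) mode.
    """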
    replication_factor = (
        self.ipu_config.inference_replication_factor if for_inference else self.ipu_config.replication_factor
    )
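    # Gradient accumulation only applies to training; inference effectively uses a factor of 1.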
    gradient_accumulation_steps = 1 if for_inference else self.ipu_config.gradient_accumulation_steps
    device_iterations = (
        self.ipu_config.inference_device_iterations if for_inference else self.ipu_config.device_iterations
    )
    micro_batch_size = (
        self.args.per_device_eval_batch_size if for_inference else self.args.per_device_train_batch_size
    )
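    # Number of samples required to fill a single global batch across all replicas and device iterations.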
    global_batch_size = micro_batch_size * replication_factor * gradient_accumulation_steps * device_iterations
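    # Datasets without a defined length (e.g. iterable-style datasets) cannot be checked.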
    try:
        len(dataset)
    except Exception:
        # If the length of the dataset cannot be determined, skip the check.
        return
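    # Warn when the dataset cannot fill even one global batch.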
    if len(dataset) < global_batch_size:
        mode_str = "inference_" if for_inference else ""
        logger.warning(
            f"The provided dataset is of length {len(dataset)}, but the total dataset batch size is {global_batch_size}. "
            f"This batch size is calculated as:\n"
            f"  per_device_{'eval' if for_inference else 'train'}_batch_size={micro_batch_size}\n"
            f"* {mode_str}{replication_factor=}\n"
            f"* {mode_str}{gradient_accumulation_steps=}\n"
            f"* {mode_str}{device_iterations=}\n"
            "Please disregard this warning if you believe the dataset is reporting an incorrect length, such as 1."
        )
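
# For reference, a minimal standalone sketch (not part of trainer.py) of the global batch
# size arithmetic checked above. The values below are hypothetical examples, not
# optimum-graphcore defaults.
example_micro_batch_size = 4              # e.g. args.per_device_train_batch_size
example_replication_factor = 2            # e.g. ipu_config.replication_factor
example_gradient_accumulation_steps = 16  # e.g. ipu_config.gradient_accumulation_steps
example_device_iterations = 4             # e.g. ipu_config.device_iterations
example_global_batch_size = (
    example_micro_batch_size
    * example_replication_factor
    * example_gradient_accumulation_steps
    * example_device_iterations
)
assert example_global_batch_size == 512  # a dataset with fewer than 512 samples would trigger the warning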