def _check_dataset_can_fill_batch()

in optimum/graphcore/trainer.py


    def _check_dataset_can_fill_batch(self, dataset: torch.utils.data.Dataset, for_inference: bool = False) -> None:
        replication_factor = (
            self.ipu_config.inference_replication_factor if for_inference else self.ipu_config.replication_factor
        )
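        # Gradient accumulation only applies to training, so a single step is assumed for inference.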
        gradient_accumulation_steps = 1 if for_inference else self.ipu_config.gradient_accumulation_steps
        device_iterations = (
            self.ipu_config.inference_device_iterations if for_inference else self.ipu_config.device_iterations
        )
        micro_batch_size = (
            self.args.per_device_eval_batch_size if for_inference else self.args.per_device_train_batch_size
        )
        global_batch_size = micro_batch_size * replication_factor * gradient_accumulation_steps * device_iterations
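        # Illustrative example (hypothetical values): micro_batch_size=2, replication_factor=4,
        # gradient_accumulation_steps=16 and device_iterations=4 give global_batch_size = 2 * 4 * 16 * 4 = 512.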

        try:
            len(dataset)
        except Exception:
            # If the length of the dataset cannot be determined (e.g. an IterableDataset without __len__), skip the check
            return
        if len(dataset) < global_batch_size:
            mode_str = "inference_" if for_inference else ""
            logger.warning(
                f"The provided dataset is of length {len(dataset)}, but the total dataset batch size is {global_batch_size}. "
                f"This batch size is calculated as:\n"
                f"  per_device_{'eval' if for_inference else 'train'}_batch_size={micro_batch_size}\n"
                f"* {mode_str}{replication_factor=}\n"
                f"* {mode_str}{gradient_accumulation_steps=}\n"
                f"* {mode_str}{device_iterations=}\n"
                "Please disregard this warning if you believe the dataset is reporting an incorrect length, such as 1."
            )
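
For reference, below is a minimal standalone sketch of the same check, using hypothetical configuration values in place of the real `IPUConfig` and `TrainingArguments` attributes (the numbers are illustrative, not library defaults):

    import logging

    logger = logging.getLogger(__name__)

    # Hypothetical values; in the trainer they come from self.ipu_config and self.args.
    micro_batch_size = 2              # per_device_train_batch_size (or per_device_eval_batch_size for inference)
    replication_factor = 4            # ipu_config.replication_factor
    gradient_accumulation_steps = 16  # set to 1 when for_inference=True
    device_iterations = 4             # ipu_config.device_iterations

    # One global batch consumes this many samples per step.
    global_batch_size = micro_batch_size * replication_factor * gradient_accumulation_steps * device_iterations  # 512

    dataset = list(range(100))  # stand-in dataset that is too small to fill one global batch

    if len(dataset) < global_batch_size:
        logger.warning(
            f"The provided dataset is of length {len(dataset)}, "
            f"but the total batch size is {global_batch_size}."
        )

Here the warning fires because 100 < 512; as in the method above, a dataset whose length cannot be determined would simply skip the check.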