def setup_batch_size_related_configs()

in vision/m4/training/trainer.py [0:0]


    def setup_batch_size_related_configs(self):
        """
        batch_size-related configs are processed here.

        All this work is done here because it requires knowing the value of num_processes
        """
        hparams = self.hparams

        if hparams.global_batch_size_ramp_up.start is not None:
            # case 1. global batch size ramp up

            # 1a. ramp up constraints
            if (
                hparams.global_batch_size_ramp_up.finish is None
                or hparams.global_batch_size_ramp_up.increment is None
                or hparams.global_batch_size_ramp_up.samples is None
            ):
                raise ValueError(
                    "When using batch size ramp up hparam config entries global_batch_size_ramp_up.start,"
                    " global_batch_size_ramp_up.finish, global_batch_size_ramp_up.increment and"
                    " global_batch_size_ramp_up.samples have to be defined."
                )

            # range checks
            ramp_up_range = hparams.global_batch_size_ramp_up.finish - hparams.global_batch_size_ramp_up.start
            if ramp_up_range < hparams.global_batch_size_ramp_up.increment:
                raise ValueError(
                    f"{hparams.global_batch_size_ramp_up.finish=} has to be larger than"
                    f" {hparams.global_batch_size_ramp_up.start=} by at least"
                    f" {hparams.global_batch_size_ramp_up.increment=}."
                )

            if not (ramp_up_range / hparams.global_batch_size_ramp_up.increment).is_integer():
                raise ValueError(
                    f"({hparams.global_batch_size_ramp_up.finish=} -"
                    f" {hparams.global_batch_size_ramp_up.start=}) /"
                    f" {hparams.global_batch_size_ramp_up.increment=} has to be a whole number"
                )

            if not (
                hparams.global_batch_size_ramp_up.increment
                / (hparams.batch_size_per_gpu * self.accelerator.num_processes)
            ).is_integer():
                raise ValueError(
                    f"{hparams.global_batch_size_ramp_up.increment=} has to be a multiple of"
                    f" {hparams.batch_size_per_gpu * self.accelerator.num_processes=}"
                )

            if self.accelerator.is_main_process:
                logger.info(
                    "Will ramp up global batch size from"
                    f" {hparams.global_batch_size_ramp_up.start} to"
                    f" {hparams.global_batch_size_ramp_up.finish}, in increments of"
                    f" {hparams.global_batch_size_ramp_up.increment} every"
                    f" {hparams.global_batch_size_ramp_up.samples} samples."
                )

            # 1b. hparam.grad_acc_size constraints and derivations
            if hparams.grad_acc_size is not None and hparams.grad_acc_size > 1:
                raise ValueError("When using batch size ramp up, hparams.grad_acc_size must be None or 1.")

            hparams.grad_acc_size = hparams.global_batch_size_ramp_up.start / (
                hparams.batch_size_per_gpu * self.accelerator.num_processes
            )
            if not hparams.grad_acc_size.is_integer():
                raise ValueError(
                    f"{hparams.global_batch_size_ramp_up.start=} has to be a multiple of"
                    f" {hparams.batch_size_per_gpu * self.accelerator.num_processes}"
                )
            hparams.grad_acc_size = int(hparams.grad_acc_size)
            logger.info(f"Derived {hparams.grad_acc_size=}")

            # 1c. hparams.global_batch_size constraints and derivation
            if hparams.global_batch_size is not None:
                raise ValueError("When using batch size ramp up hparam.global_batch_size must be None.")
            # in the first run we start at global_batch_size == global_batch_size_ramp_up.start
            hparams.global_batch_size = hparams.global_batch_size_ramp_up.start
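            # Illustrative (hypothetical) numbers: with global_batch_size_ramp_up.start=64,
            # finish=256, increment=64, samples=1_000_000, batch_size_per_gpu=4 and
            # num_processes=8, the run starts at grad_acc_size = 64 / (4 * 8) = 2 and
            # global_batch_size = 64, then grows by 64 every 1M samples until it reaches 256.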

        else:
            # case 2. fixed global batch size or fixed grad_acc_size

            # 2a. constraints
            if hparams.grad_acc_size > 1:
                # when global_batch_size is used, grad_acc_size will be derived automatically from global_batch_size and num_processes
                if hparams.global_batch_size is not None and hparams.global_batch_size > 1:
                    raise ValueError("set either hparams.grad_acc_size>1 or hparams.global_batch_size>1, but not both")

            if hparams.global_batch_size is not None and hparams.global_batch_size > 1:
                # 2b. have global_batch_size need to derive grad_acc_size

                hparams.grad_acc_size = hparams.global_batch_size / (
                    hparams.batch_size_per_gpu * self.accelerator.num_processes
                )
                if not hparams.grad_acc_size.is_integer():
                    raise ValueError(
                        f"The derived {hparams.grad_acc_size=} is not an integer,"
                        f" {hparams.global_batch_size=} / ({hparams.batch_size_per_gpu=} *"
                        f" {self.accelerator.num_processes=})"
                    )
                hparams.grad_acc_size = int(hparams.grad_acc_size)
                logger.info(f"Derived {hparams.grad_acc_size=}")
            else:
                # 2c. have grad_acc_size need to derive global_batch_size

                hparams.global_batch_size = (
                    hparams.batch_size_per_gpu * hparams.grad_acc_size * self.accelerator.num_processes
                )
                logger.info(f"Derived {hparams.global_batch_size=}")