in vision/m4/training/trainer.py [0:0]
def setup_batch_size_related_configs(self):
    """
    batch_size-related configs are processed here.
    All this work is done in this method because it requires knowing the value of num_processes.
    """
    hparams = self.hparams
    if hparams.global_batch_size_ramp_up.start is not None:
        # case 1. global batch size ramp up
        # 1a. ramp up constraints
        if (
            hparams.global_batch_size_ramp_up.finish is None
            or hparams.global_batch_size_ramp_up.increment is None
            or hparams.global_batch_size_ramp_up.samples is None
        ):
            raise ValueError(
                "When using batch size ramp up, the hparam config entries global_batch_size_ramp_up.start,"
                " global_batch_size_ramp_up.finish, global_batch_size_ramp_up.increment and"
                " global_batch_size_ramp_up.samples all have to be defined."
            )
        # range checks
        ramp_up_range = hparams.global_batch_size_ramp_up.finish - hparams.global_batch_size_ramp_up.start
        if ramp_up_range < hparams.global_batch_size_ramp_up.increment:
            raise ValueError(
                f"{hparams.global_batch_size_ramp_up.finish=} has to be larger than"
                f" {hparams.global_batch_size_ramp_up.start=} by at least"
                f" {hparams.global_batch_size_ramp_up.increment=}."
            )
        if not (ramp_up_range / hparams.global_batch_size_ramp_up.increment).is_integer():
            raise ValueError(
                f"({hparams.global_batch_size_ramp_up.finish=} -"
                f" {hparams.global_batch_size_ramp_up.start=}) /"
                f" {hparams.global_batch_size_ramp_up.increment=} has to be a whole number"
            )
        if not (
            hparams.global_batch_size_ramp_up.increment
            / (hparams.batch_size_per_gpu * self.accelerator.num_processes)
        ).is_integer():
            raise ValueError(
                f"{hparams.global_batch_size_ramp_up.increment=} has to be a multiple of"
                f" {hparams.batch_size_per_gpu * self.accelerator.num_processes=}"
            )
        if self.accelerator.is_main_process:
            logger.info(
                "Will ramp up global batch size from"
                f" {hparams.global_batch_size_ramp_up.start} to"
                f" {hparams.global_batch_size_ramp_up.finish}, in increments of"
                f" {hparams.global_batch_size_ramp_up.increment} every"
                f" {hparams.global_batch_size_ramp_up.samples} samples."
            )
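        # Worked example (illustrative values, not project defaults): with batch_size_per_gpu=4
        # and num_processes=8, a ramp up of start=64, finish=256, increment=32 passes the checks
        # above: (256 - 64) / 32 == 6 is a whole number, and 32 is a multiple of 4 * 8 == 32.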
        # 1b. hparam.grad_acc_size constraints and derivations
        if hparams.grad_acc_size is not None and hparams.grad_acc_size > 1:
            raise ValueError("When using batch size ramp up hparam.grad_acc_size must be None or 1.")
        hparams.grad_acc_size = hparams.global_batch_size_ramp_up.start / (
            hparams.batch_size_per_gpu * self.accelerator.num_processes
        )
        if not hparams.grad_acc_size.is_integer():
            raise ValueError(
                f"{hparams.global_batch_size_ramp_up.start=} has to be a multiple of"
                f" {hparams.batch_size_per_gpu * self.accelerator.num_processes=}"
            )
        hparams.grad_acc_size = int(hparams.grad_acc_size)
        logger.info(f"Derived {hparams.grad_acc_size=}")
        # 1c. hparams.global_batch_size constraints and derivation
        if hparams.global_batch_size is not None:
            raise ValueError("When using batch size ramp up hparam.global_batch_size must be None.")
        # in the first run we start at global_batch_size == global_batch_size_ramp_up.start
        hparams.global_batch_size = hparams.global_batch_size_ramp_up.start
    else:
        # case 2. fixed global batch size or fixed grad_acc_size
        # 2a. constraints
        if hparams.grad_acc_size > 1:
            # when global_batch_size is used grad_acc_size will be derived automatically from global_batch_size and n_gpus
            if hparams.global_batch_size is not None and hparams.global_batch_size > 1:
                raise ValueError("set either hparams.grad_acc_size>1 or hparams.global_batch_size>1, but not both")
        if hparams.global_batch_size is not None and hparams.global_batch_size > 1:
            # 2b. have global_batch_size, need to derive grad_acc_size
            hparams.grad_acc_size = hparams.global_batch_size / (
                hparams.batch_size_per_gpu * self.accelerator.num_processes
            )
            if not hparams.grad_acc_size.is_integer():
                raise ValueError(
                    f"The derived {hparams.grad_acc_size=} is not an integer,"
                    f" {hparams.global_batch_size=} / ({hparams.batch_size_per_gpu=} *"
                    f" {self.accelerator.num_processes=})"
                )
            hparams.grad_acc_size = int(hparams.grad_acc_size)
            logger.info(f"Derived {hparams.grad_acc_size=}")
        else:
            # 2c. have grad_acc_size, need to derive global_batch_size
            hparams.global_batch_size = (
                hparams.batch_size_per_gpu * hparams.grad_acc_size * self.accelerator.num_processes
            )
            logger.info(f"Derived {hparams.global_batch_size=}")