in trl/trainer/grpo_config.py [0:0]
def __post_init__(self):
    """Resolve derived GRPO defaults and validate the batch/generation settings.

    Side effects on ``self``:
      * ``bf16`` defaults to ``not fp16`` when left as ``None``.
      * ``steps_per_generation`` defaults to ``gradient_accumulation_steps``.
      * ``generation_batch_size`` defaults to
        ``per_device_train_batch_size * world_size * steps_per_generation``.
      * ``steps_per_generation`` is then re-derived from ``generation_batch_size`` so the two always agree.

    Raises:
        ValueError: if ``generation_batch_size`` and ``steps_per_generation`` are both set; if
            ``generation_batch_size`` is not positive or not a multiple of the global train batch size; if
            ``num_generations`` is below 2 or does not evenly divide ``generation_batch_size``; if evaluation
            is enabled and the global eval batch size is not a positive multiple of ``num_generations``; or if
            ``delta`` is combined with ``use_liger_loss``.
    """
    # Resolved before the parent hook runs — presumably so the parent sees a concrete value; TODO confirm.
    self.bf16 = not (self.fp16) if self.bf16 is None else self.bf16

    super().__post_init__()

    num_processes = self.world_size

    def _valid_num_generations(batch_size):
        # All values >= 2 that evenly divide ``batch_size``. Only used to build error messages, so the
        # O(batch_size) scan is never paid on the success path.
        return [n_gen for n_gen in range(2, batch_size + 1) if batch_size % n_gen == 0]

    # generation_batch_size == per_device_train_batch_size * num_processes * steps_per_generation, so the
    # two knobs describe the same quantity and at most one of them may be set explicitly.
    if self.generation_batch_size is not None and self.steps_per_generation is not None:
        raise ValueError(
            "'generation_batch_size' and 'steps_per_generation' can not be both configured at the same time"
        )

    if self.steps_per_generation is None:
        self.steps_per_generation = self.gradient_accumulation_steps

    global_train_batch_size = self.per_device_train_batch_size * num_processes
    if self.generation_batch_size is None:
        self.generation_batch_size = global_train_batch_size * self.steps_per_generation

    # A zero/negative value would otherwise slip past the modulo check below and surface later as a
    # confusing "valid values ... are: []" error.
    if self.generation_batch_size <= 0:
        raise ValueError(f"generation_batch_size ({self.generation_batch_size}) must be a positive integer.")
    if self.generation_batch_size % global_train_batch_size != 0:
        raise ValueError(
            f"generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size "
            f"({global_train_batch_size})."
        )
    # Re-derive so it is consistent when the user supplied generation_batch_size directly.
    self.steps_per_generation = self.generation_batch_size // global_train_batch_size

    if self.num_generations < 2:
        raise ValueError(
            "GRPO requires at least 2 generations per prompt to calculate the advantages. You provided "
            f"{self.num_generations}, which is less than the minimum required."
        )

    # Each generation batch must split into whole groups of ``num_generations`` completions. With
    # num_generations >= 2 and a positive batch size, divisibility is equivalent to membership in
    # ``_valid_num_generations``.
    if self.generation_batch_size % self.num_generations != 0:
        possible_values = _valid_num_generations(self.generation_batch_size)
        raise ValueError(
            f"The effective train batch size ({num_processes} x {self.per_device_train_batch_size} x "
            f"{self.steps_per_generation}) must be evenly divisible by the number of generations per "
            f"prompt ({self.num_generations}). Given the current effective train batch size, the valid values for "
            f"the number of generations are: {possible_values}."
        )

    if self.eval_strategy != "no":
        global_eval_batch_size = self.per_device_eval_batch_size * num_processes
        # ``<= 0`` keeps the original behavior: a non-positive size has no valid divisors and is rejected.
        if global_eval_batch_size <= 0 or global_eval_batch_size % self.num_generations != 0:
            possible_values = _valid_num_generations(global_eval_batch_size)
            raise ValueError(
                f"The global eval batch size ({num_processes} x {self.per_device_eval_batch_size}) must be "
                f"evenly divisible by the number of generations per prompt ({self.num_generations}). Given the "
                "current global eval batch size, the valid values for the number of generations are: "
                f"{possible_values}."
            )

    if self.delta is not None and self.use_liger_loss:
        raise ValueError("Liger loss does not support two-sided GRPO loss yet.")