in src/open-r1-multimodal/src/open_r1/grpo_rec.py [0:0]
def main(script_args, training_args, model_args):
    training_args.max_completion_length = script_args.max_output_token_length
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]

    # Load the dataset
    dataset = LazySupervisedDataset(script_args.dataset_name, script_args)
    # Create a balanced batch sampler over image and text samples
    image_indices, text_indices = dataset.get_indices()
    batch_sampler = BalancedBatchSampler(
        image_indices=image_indices,
        text_indices=text_indices,
        batch_size=training_args.per_device_train_batch_size,
    )
    # Create a standard DataLoader (note: it is not passed to the trainer
    # below, which receives the sampler and dataset directly)
    dataloader = DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        collate_fn=collate_fn,  # requires a custom collate function
        num_workers=training_args.dataloader_num_workers,
    )
    trainer_cls = Qwen2VLGRPOTrainer

    # Initialize the GRPO trainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        data_collator=collate_fn,
        batch_sampler=batch_sampler,
        train_dataset=dataset,
        eval_dataset=None,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
        torch_dtype=model_args.torch_dtype,
    )
    # Train the model
    trainer.train()

    # Save, and optionally push to the Hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
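
The reward_funcs_registry consulted at the top of main is defined earlier in grpo_rec.py and not shown in this excerpt. A minimal sketch of the pattern, with assumed entry names and an assumed string-completion signature (the real file registers its own reward functions, e.g. an accuracy reward for the REC task):

import re

def format_reward(completions, **kwargs):
    # Hypothetical reward: 1.0 when a completion follows the expected
    # <think>...</think><answer>...</answer> layout, else 0.0.
    # The completion format (plain strings) is an assumption.
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    return [1.0 if re.fullmatch(pattern, c, re.DOTALL) else 0.0 for c in completions]

reward_funcs_registry = {
    "format": format_reward,
    # "accuracy": accuracy_reward,  # assumed second entry, defined elsewhere
}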
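BalancedBatchSampler is likewise defined elsewhere in the file. A minimal sketch, assuming it reshuffles both pools each epoch and draws half of every batch from the image samples and half from the text samples (the real implementation may use a different ratio):

import random
from torch.utils.data import Sampler

class BalancedBatchSampler(Sampler):
    """Yield index batches that mix image and text samples (sketch only)."""

    def __init__(self, image_indices, text_indices, batch_size):
        self.image_indices = list(image_indices)
        self.text_indices = list(text_indices)
        self.batch_size = batch_size

    def __iter__(self):
        # Shuffle both pools, then fill each batch half-and-half
        # (assumption: a 50/50 split per batch).
        images = random.sample(self.image_indices, len(self.image_indices))
        texts = random.sample(self.text_indices, len(self.text_indices))
        half = max(self.batch_size // 2, 1)
        n_batches = min(len(images) // half, len(texts) // half)
        for b in range(n_batches):
            batch = images[b * half:(b + 1) * half] + texts[b * half:(b + 1) * half]
            random.shuffle(batch)
            yield batch

    def __len__(self):
        half = max(self.batch_size // 2, 1)
        return min(len(self.image_indices) // half, len(self.text_indices) // half)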
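The custom collate_fn flagged in the DataLoader comment is also not shown. With a lazy dataset that returns per-sample dicts, a pass-through collate is a common choice for GRPO-style trainers, since prompt templating and tokenization happen inside the trainer; a sketch under that assumption:

def collate_fn(batch):
    # Pass-through collate: keep each sample dict from LazySupervisedDataset
    # as-is rather than stacking tensors, so the trainer can do its own
    # templating and tokenization. Assumption: no stacking is needed here.
    return batch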