def main()

in src/open-r1-multimodal/src/open_r1/grpo_rec.py


def main(script_args, training_args, model_args):
    training_args.max_completion_length = script_args.max_output_token_length
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]

    # Load the dataset
    dataset = LazySupervisedDataset(script_args.dataset_name, script_args)

    # Create the balanced batch sampler over image and text samples
    # (a sketch of BalancedBatchSampler follows the function)
    image_indices, text_indices = dataset.get_indices()
    batch_sampler = BalancedBatchSampler(
        image_indices=image_indices,
        text_indices=text_indices,
        batch_size=training_args.per_device_train_batch_size
    )

    # Create a standard DataLoader (note: main() never uses this loader directly;
    # the trainer below receives batch_sampler and data_collator and builds its own)
    dataloader = DataLoader(
        dataset,
        batch_sampler=batch_sampler,
        collate_fn=collate_fn,  # requires a custom collate function (see the sketch after the function)
        num_workers=training_args.dataloader_num_workers
    )

    trainer_cls = Qwen2VLGRPOTrainer
    # Initialize the GRPO trainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        data_collator=collate_fn,
        batch_sampler=batch_sampler,
        train_dataset=dataset,
        eval_dataset=None,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
        torch_dtype=model_args.torch_dtype,
    )

    # Train the model
    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
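
The registry lookup near the top of main() maps each name in script_args.reward_funcs to a reward callable. The registry itself is outside this excerpt; a minimal sketch, assuming string completions and a format-style reward (both assumptions, since grpo_rec.py defines its own rewards):

import re

# Hypothetical registry mapping CLI names to reward callables; the real
# reward_funcs_registry in grpo_rec.py defines its own reward functions.
def format_reward(completions, **kwargs):
    # Score 1.0 when a completion follows a <think>...</think><answer>...</answer>
    # template, else 0.0 (assumes completions is a list of strings).
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    return [1.0 if re.fullmatch(pattern, c, re.DOTALL) else 0.0 for c in completions]

reward_funcs_registry = {"format": format_reward}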
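
BalancedBatchSampler is defined elsewhere in the file; a minimal sketch of the batch-sampler protocol it relies on, assuming each batch draws half its indices from the image pool and half from the text pool (the ratio is an illustrative choice, and batch_size >= 2 is assumed):

import random
from torch.utils.data import Sampler

class BalancedBatchSampler(Sampler):
    """Yields index batches mixing image and text samples (illustrative)."""

    def __init__(self, image_indices, text_indices, batch_size):
        self.image_indices = list(image_indices)
        self.text_indices = list(text_indices)
        self.image_share = batch_size // 2               # indices per batch from images
        self.text_share = batch_size - self.image_share  # indices per batch from text

    def __iter__(self):
        images, texts = self.image_indices[:], self.text_indices[:]
        random.shuffle(images)
        random.shuffle(texts)
        for i in range(len(self)):
            batch = (images[i * self.image_share:(i + 1) * self.image_share]
                     + texts[i * self.text_share:(i + 1) * self.text_share])
            random.shuffle(batch)  # avoid a fixed image-then-text order within a batch
            yield batch

    def __len__(self):
        # Number of full batches both index pools can supply.
        return min(len(self.image_indices) // self.image_share,
                   len(self.text_indices) // self.text_share)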
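
The custom collate_fn passed to both the DataLoader and the trainer is also outside this excerpt. A minimal sketch, assuming each dataset item is a dict (e.g. prompt, image, solution) and that Qwen2VLGRPOTrainer handles tokenization itself, so collation only needs to group items into per-key lists:

def collate_fn(examples):
    # Turn a list of per-sample dicts into a dict of per-key lists, e.g.
    # [{"prompt": p1, ...}, {"prompt": p2, ...}] -> {"prompt": [p1, p2], ...}.
    return {key: [example[key] for example in examples] for key in examples[0]}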