in src/open-r1-multimodal/src/open_r1/grpo.py [0:0]
def main(script_args, training_args, model_args):
    # Get reward functions
    reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
    print("reward_funcs:", reward_funcs)
    # Load the dataset
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    # Format into conversation
    def make_conversation(example):
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }
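    # Illustrative output of make_conversation, assuming
    # example["problem"] == "What is 2+2?":
    #   {"prompt": [{"role": "system", "content": SYSTEM_PROMPT},
    #               {"role": "user", "content": "What is 2+2?"}]}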
    QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."

    # Format multimodal examples: an image placeholder plus the templated question
    def make_conversation_image(example):
        return {
            "prompt": [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
                    ],
                },
            ],
        }
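    # Illustrative prompt produced by make_conversation_image: a single user
    # turn whose content pairs an image placeholder with the templated question.
    # The {"type": "image"} entry carries no pixels here; the processor is
    # expected to splice in the actual image when the chat template is applied.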
if "image" in dataset[script_args.dataset_train_split].features:
print("has image in dataset")
dataset = dataset.map(make_conversation_image) # Utilize multiprocessing for faster mapping
# dataset = dataset.remove_columns(["original_question", "original_answer"])
else:
print("no image in dataset")
dataset = dataset.map(make_conversation)
dataset = dataset.remove_columns("messages")
    trainer_cls = Qwen2VLGRPOTrainer

    # Initialize the GRPO trainer
    trainer = trainer_cls(
        model=model_args.model_name_or_path,
        reward_funcs=reward_funcs,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
        peft_config=get_peft_config(model_args),
        attn_implementation=model_args.attn_implementation,
        max_pixels=script_args.max_pixels,
        min_pixels=script_args.min_pixels,
        torch_dtype=model_args.torch_dtype,
    )
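    # max_pixels / min_pixels are assumed to be forwarded to the Qwen2-VL image
    # processor, which accepts these limits to bound the resolution (and hence
    # the number of visual tokens) of each input image.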
    # Train, then save the final model (and optionally push it to the Hub)
    trainer.train()
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
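
# A hypothetical launch command; flag names are inferred from the dataclass
# attributes used above (TRL-style parsers usually map fields to CLI flags
# one-to-one) and are not confirmed against this repo's parser:
#
#   torchrun --nproc_per_node=8 src/open_r1/grpo.py \
#       --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
#       --dataset_name <your-dataset> \
#       --reward_funcs accuracy format \
#       --output_dir ./checkpoints/grpo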