# course/en/chapter13/grpo_format.py (87 lines of code) (raw):
import marimo
# Version string stored with the notebook file (presumably the marimo release
# that generated it -- TODO confirm).
__generated_with = "0.10.6"
# Notebook application object; the functions below are registered as cells
# through its `cell` decorator, and it is run from the `__main__` guard.
app = marimo.App(width="medium")
@app.cell(hide_code=True)
def _():
    import marimo as mo

    # Introductory text for the notebook. The literal is kept flush-left so
    # the markdown source is exactly what gets handed to `mo.md`.
    _intro_text = """
## Structure Format Reward Function
This example demonstrates a reward function that evaluates whether
completions follow a specific structure format. Either a `<think>...</think><answer>...</answer>`
or a `<code>...</code><explanation>...</explanation>` format.
Use the buttons to select which structure format to reward.
"""
    # Last expression of the cell: this is what the cell renders.
    mo.md(_intro_text)
    # Export `mo` so the other cells can take it as a parameter.
    return (mo,)
@app.cell(hide_code=True)
def _(mo):
    # Settings for the selector that picks which tag structure is rewarded;
    # "think-answer" is pre-selected.
    _radio_settings = {
        "options": ["think-answer", "code-explanation"],
        "value": "think-answer",
        "label": "Format to reward",
    }
    format_buttons = mo.ui.radio(**_radio_settings)
    # Last expression of the cell: render the radio group itself.
    format_buttons
    # Export the widget so the results cell can read `format_buttons.value`.
    return (format_buttons,)
@app.cell(hide_code=True)
def _(mo, format_buttons):
    import plotly.express as px
    import re

    # Sample completions covering complete, partial and missing structure for
    # both selectable formats, so either choice produces a spread of rewards.
    completions = [
        # Complete think-answer structure
        "<think>Let me solve this step by step</think><answer>42</answer>",
        # No tags at all
        "The answer is 15 without any special format",
        # Complete code-explanation structure
        "<code>print('Hello world')</code><explanation>This prints a greeting</explanation>",
        "<code>def add(a, b): return a + b</code><explanation>A function to add numbers</explanation>",
        # Code block only, explanation missing
        "<code>for i in range(10): print(i)</code>",
        # Tags from both formats mixed together
        "<think>I should use a loop</think><code>while True: pass</code>",
    ]

    # Create shortened versions for display
    short_completions = [c[:30] + "..." if len(c) > 30 else c for c in completions]

    def format_reward(completions, format_type="think-answer", **kwargs):
        """
        Reward completions that follow the desired format structure.

        Args:
            completions: list of completion strings to evaluate
            format_type: which format structure to reward, either
                "think-answer" or "code-explanation"; any other value is
                treated as "think-answer"
            **kwargs: accepted for call compatibility and ignored

        Returns:
            A tuple ``(rewards, details, categories)`` of three lists, each
            with one entry per completion:
              * 1.0 - the full ``<first>...</first><second>...</second>``
                structure occurs somewhere in the completion
              * 0.5 - only the opening tag of the first element is present
              * 0.2 - the opening tag of the second element is present
              * 0.0 - none of the format's tags are present
        """
        # Define patterns for different formats
        patterns = {
            "think-answer": r"<think>.*?</think>\s*<answer>.*?</answer>",
            "code-explanation": r"<code>.*?</code>\s*<explanation>.*?</explanation>",
        }

        # Normalise unknown values once, so the regex, the tag checks and the
        # detail messages below all describe the same format (previously only
        # the regex fell back, and a None value crashed on `.split`).
        if format_type not in patterns:
            format_type = "think-answer"

        # Loop-invariant work: compile the pattern (DOTALL lets tag contents
        # span several lines) and derive the opening tags from the format name.
        pattern = re.compile(patterns[format_type], re.DOTALL)
        tag_names = format_type.split("-")
        first_tag_name = tag_names[0]
        first_opening_tag = f"<{first_tag_name}>"
        opening_tags = [f"<{tag}>" for tag in tag_names]

        rewards = []
        details = []
        categories = []
        for completion in completions:
            if pattern.search(completion):
                # Complete structure found
                rewards.append(1.0)
                details.append(f"Correct {format_type} format")
                categories.append("Exact Format Match")
            elif first_opening_tag in completion:
                # Partial match - has the opening tag of the format
                rewards.append(0.5)
                details.append(f"Has {first_tag_name} tag but incomplete")
                categories.append("Partial Format Match")
            elif any(tag in completion for tag in opening_tags):
                # Contains at least one of the required tags
                rewards.append(0.2)
                details.append("Has some required tags but wrong format")
                categories.append("Some Tags Present")
            else:
                # No match at all
                rewards.append(0.0)
                details.append("Incorrect format")
                categories.append("No Format Match")
        return rewards, details, categories

    # Calculate rewards for the format currently selected in the radio group
    rewards, details, categories = format_reward(
        completions=completions, format_type=format_buttons.value
    )

    # One row per completion, shared by the table and the chart
    results = [
        {
            "Completion": text,
            "Reward": reward,
            "Detail": detail,
            "Category": category,
        }
        for text, reward, detail, category in zip(
            short_completions, rewards, details, categories
        )
    ]

    # Create a bar chart comparing rewards by completion
    fig = px.bar(
        results,
        x="Completion",
        y="Reward",
        color="Category",
        title=f"Format Rewards by Completion ({format_buttons.value})",
        hover_data=["Detail"],
    )

    # A cell renders only its last expression, so the heading, table and chart
    # are stacked into a single output instead of being separate statements
    # (which silently dropped the heading and the table).
    mo.vstack(
        [
            mo.md(f"### Results for {format_buttons.value} format"),
            mo.ui.table(results),
            mo.ui.plotly(fig),
        ]
    )
# Script entry point: run the notebook app when this file is executed directly.
if __name__ == "__main__":
    app.run()