in project/paperbench/paperbench/rubric/tasks.py [0:0]
def reduce_to_category(node: TaskNode, category: str) -> TaskNode | None:
    """
    Prune `node` down to the parts of the tree that belong to `category`.

    - Leaf: returned as-is (the same object) when its `task_category`
      equals `category`; otherwise dropped (`None`).
    - Non-leaf: each child is pruned recursively and only the surviving
      children are kept, in their original order. The node is rebuilt
      with `replace(node, sub_tasks=[...survivors...])`. If no child
      survives, the node is dropped, unless its own `task_category`
      equals `category`. In that case it is kept with an empty
      `sub_tasks` list.

    Returns the pruned tree, or `None` when nothing labeled `category`
    remains anywhere under `node`.
    """
    if node.is_leaf():
        return node if node.task_category == category else None

    surviving_children = [
        pruned
        for child in node.sub_tasks
        if (pruned := reduce_to_category(child, category)) is not None
    ]

    # An empty survivor list means this subtree holds no `category` work,
    # unless the node itself carries the `category` label.
    if surviving_children or node.task_category == category:
        return replace(node, sub_tasks=surviving_children)
    return None