def reduce_to_category()

in project/paperbench/paperbench/rubric/tasks.py [0:0]


def reduce_to_category(node: TaskNode, category: str) -> TaskNode | None:
    """
    Prune the tree rooted at `node` down to the parts labeled `category`.

    A leaf is returned unchanged (the very same object, not a copy) when
    its `task_category` equals `category`, and dropped otherwise. A
    non-leaf node is rebuilt via `replace` with only the children that
    survive pruning, kept in their original order. A non-leaf node whose
    children are all pruned away is dropped as well, unless its own
    `task_category` equals `category`; in that case it is returned with
    an empty `sub_tasks` list.

    Returns the pruned node, or `None` when nothing under `node` is kept.
    The input tree is not modified.
    """
    if node.is_leaf():
        return node if node.task_category == category else None

    # Prune each child recursively and keep only the ones that survive.
    surviving_children = [
        kept
        for child in node.sub_tasks
        if (kept := reduce_to_category(child, category)) is not None
    ]

    # A subtree with nothing left in it is discarded, except when the node
    # itself carries the requested `category`.
    if surviving_children or node.task_category == category:
        return replace(node, sub_tasks=surviving_children)
    return None