def configure_model_roles()

in yourbench/main.py [0:0]


def configure_model_roles(models: list[dict]) -> dict:
    """Configure which models to use for each pipeline stage.

    Args:
        models: Model configuration dicts; each entry must provide a
            ``"model_name"`` key.

    Returns:
        A dict mapping a stage name to the list of model names assigned to it.

        * Empty ``models`` -> ``{}``.
        * Exactly one model -> that model is assigned to every stage below,
          without prompting.
        * Several models -> the user is prompted once per stage for
          comma-separated 1-based indices. Repeated indices are collapsed
          (first occurrence wins). A stage for which no valid index was
          entered is left out of the result.
    """
    if not models:
        return {}

    # (stage key, human-readable description). This is the single source of
    # truth for both the single-model shortcut and the interactive path.
    stages = (
        ("ingestion", "Document parsing & conversion"),
        ("summarization", "Document summarization"),
        ("single_shot_question_generation", "Single-hop questions"),
        ("multi_hop_question_generation", "Multi-hop questions"),
    )

    if len(models) == 1:
        # Single model - use it for every stage listed above (chunking is not
        # one of the stages assigned here).
        model_name = models[0]["model_name"]
        return {stage: [model_name] for stage, _ in stages}

    console.print("\n[bold cyan]Model Role Assignment[/bold cyan]")
    console.print("Assign models to pipeline stages:")

    # Show available models with the 1-based indices the user will type.
    table = Table(title="Available Models")
    table.add_column("Index", style="cyan")
    table.add_column("Model", style="green")
    for i, model in enumerate(models, 1):
        table.add_row(str(i), model["model_name"])
    console.print(table)

    roles = {}
    for stage, desc in stages:
        console.print(f"\n[yellow]{stage}[/yellow]: {desc}")
        indices = Prompt.ask("Model indices (comma-separated, e.g., 1,2)", default="1")
        selected = _parse_model_indices(indices, models)
        # Stages with no valid selection are omitted rather than set to [].
        if selected:
            roles[stage] = selected

    return roles


def _parse_model_indices(raw: str, models: list[dict]) -> list[str]:
    """Convert a comma-separated string of 1-based indices into model names.

    Invalid or out-of-range tokens are logged and skipped. Empty tokens (e.g.
    from a trailing comma) are ignored silently. Duplicate selections are
    dropped, keeping the order of first occurrence.
    """
    selected: list[str] = []
    for token in raw.split(","):
        token = token.strip()
        if not token:
            # Tolerate stray commas such as "1,2," or "1,,2".
            continue
        try:
            i = int(token) - 1
        except ValueError:
            logger.warning(f"Invalid model index '{token}' - expected a number")
            continue
        if not 0 <= i < len(models):
            logger.warning(f"Model index {token} is out of range (1-{len(models)})")
            continue
        name = models[i]["model_name"]
        # Avoid assigning the same model to a stage twice (e.g. input "1,1").
        if name not in selected:
            selected.append(name)
    return selected