in yourbench/main.py [0:0]
def configure_model_roles(models: list[dict]) -> dict:
    """Configure which models to use for each pipeline stage.

    With no models an empty mapping is returned. With exactly one model, that
    model is assigned to every stage without prompting. Otherwise the
    available models are printed as a table and the user is asked, once per
    stage, for a comma-separated list of 1-based model indices (the default
    answer is "1").

    Args:
        models: Model configuration dicts; each must provide a "model_name" key.

    Returns:
        Mapping of stage name to the list of selected model names, in the
        order they were entered and without duplicates. A stage for which no
        valid index was entered is left out of the mapping.
    """
    if not models:
        return {}

    # Single source of truth for the stages, shared by both code paths below.
    stages = [
        ("ingestion", "Document parsing & conversion"),
        ("summarization", "Document summarization"),
        ("single_shot_question_generation", "Single-hop questions"),
        ("multi_hop_question_generation", "Multi-hop questions"),
    ]

    if len(models) == 1:
        # Nothing to choose between: the only model serves every stage.
        # Each stage gets its own list so callers can mutate one safely.
        model_name = models[0]["model_name"]
        return {stage: [model_name] for stage, _ in stages}

    console.print("\n[bold cyan]Model Role Assignment[/bold cyan]")
    console.print("Assign models to pipeline stages:")

    # Show available models with the 1-based indices the prompt expects.
    table = Table(title="Available Models")
    table.add_column("Index", style="cyan")
    table.add_column("Model", style="green")
    for position, model in enumerate(models, 1):
        table.add_row(str(position), model["model_name"])
    console.print(table)

    roles = {}
    for stage, desc in stages:
        console.print(f"\n[yellow]{stage}[/yellow]: {desc}")
        indices = Prompt.ask("Model indices (comma-separated, e.g., 1,2)", default="1")

        selected = []
        for raw_token in indices.split(","):
            token = raw_token.strip()
            try:
                index = int(token) - 1  # user input is 1-based
            except ValueError:
                logger.warning(f"Invalid model index '{token}' - expected a number")
                continue

            if not 0 <= index < len(models):
                logger.warning(f"Model index {token} is out of range (1-{len(models)})")
                continue

            name = models[index]["model_name"]
            # Entering the same index twice must not assign the model twice.
            if name not in selected:
                selected.append(name)

        if selected:
            roles[stage] = selected
        else:
            # Make the omission visible instead of silently dropping the stage.
            logger.warning(f"No valid models selected for stage '{stage}' - leaving it unassigned")

    return roles