in src/sal/search/beam_search.py [0:0]
def beam_search(examples, config: Config, llm: LLM, prm: PRM):
    """Run beam search over a batch of problems and collect per-problem results.

    Args:
        examples: Batch mapping; only ``examples["problem"]`` (the sequence of
            problem prompts) is read.
        config: Search configuration, passed through to ``_beam_search``;
            ``config.agg_strategy`` is also passed to ``aggregate_scores`` to
            reduce each beam's ``all_scores`` to a single value.
        llm: Passed through to ``_beam_search``.
        prm: Passed through to ``_beam_search``.

    Returns:
        A dict of column lists with one entry per problem, in the order of
        ``examples["problem"]``:

        - ``"completions"``: each beam's ``current_text``.
        - ``"pred"``: the completion whose aggregated score is highest.
        - ``"completion_tokens"``: each beam's ``completion_tokens``.
        - ``"scores"``: each beam's ``all_scores``.

    Raises:
        ValueError: If ``_beam_search`` returned no beam for some problem.
    """
    problems = examples["problem"]
    beam_results = _beam_search(problems, config, llm, prm)

    # Group the returned beams by the prompt they were generated for, so each
    # problem's beams can be looked up directly in the loop below.
    grouped_beams = defaultdict(list)
    for beam in beam_results:
        grouped_beams[beam.prompt].append(beam)
    # NOTE(review): beams are keyed by prompt text, so if the same problem
    # string appears more than once in the batch, every occurrence receives
    # the beams of all occurrences. Confirm whether duplicates can occur.

    results = {"completions": [], "pred": [], "completion_tokens": [], "scores": []}
    for p in problems:
        beams = grouped_beams[p]
        if not beams:
            # np.argmax on an empty sequence raises an opaque numpy error;
            # fail with a message that identifies the offending problem.
            raise ValueError(f"beam search returned no beams for problem: {p!r}")
        completions = [b.current_text for b in beams]
        agg_scores = [
            aggregate_scores(b.all_scores, config.agg_strategy) for b in beams
        ]
        # np.argmax returns the first index among tied maxima, so the earliest
        # beam wins a tie.
        pred = completions[np.argmax(agg_scores)]
        results["completions"].append(completions)
        results["scores"].append([b.all_scores for b in beams])
        results["pred"].append(pred)
        results["completion_tokens"].append([b.completion_tokens for b in beams])
    return results