in src/sal/search/diverse_verifier_tree_search.py [0:0]
def dvts(examples, config: Config, llm: LLM, prm: PRM):
    """Run diverse verifier tree search (DVTS) over a batch of problems.

    Args:
        examples: Batch mapping; only ``examples["problem"]`` (a sequence of
            problem prompts) is read.
        config: Search configuration, forwarded to ``_dvts``. Its
            ``agg_strategy`` selects how each beam's ``best_scores`` are
            aggregated when picking the prediction.
        llm: The ``LLM`` instance, forwarded to ``_dvts``.
        prm: The ``PRM`` instance, forwarded to ``_dvts``.

    Returns:
        A dict of parallel lists with one entry per problem, in input order:
        ``"completions"`` (the text of every beam for that problem),
        ``"pred"`` (text of the beam with the highest aggregated score; the
        first such beam wins on ties), ``"scores"`` (each beam's
        ``best_scores``) and ``"completion_tokens"`` (always ``-1``; token
        counts are not tracked here).

    Raises:
        ValueError: If ``_dvts`` returned no beam for some problem, so no
            prediction can be selected.
    """
    problems = examples["problem"]
    beam_results = _dvts(problems, config, llm, prm)

    # Group beams by the prompt they were generated for. Matching is by prompt
    # text, so identical problems in one batch end up in a single group.
    # NOTE(review): duplicate problems would each receive the merged beams of
    # all duplicates -- confirm this is intended / cannot happen upstream.
    beams_by_prompt = defaultdict(list)
    for beam in beam_results:
        beams_by_prompt[beam.prompt].append(beam)

    results = {"completions": [], "pred": [], "completion_tokens": [], "scores": []}
    for problem in problems:
        beams = beams_by_prompt[problem]
        if not beams:
            # np.argmax would otherwise fail with an opaque "empty sequence" error.
            raise ValueError(f"DVTS produced no beams for problem: {problem!r}")

        agg_scores = [
            aggregate_scores(b.best_scores, config.agg_strategy) for b in beams
        ]
        # np.argmax returns the first index of the maximum, so ties favour
        # the earliest beam.
        best = beams[np.argmax(agg_scores)]

        results["completions"].append([b.current_text for b in beams])
        results["pred"].append(best.current_text)
        results["scores"].append([b.best_scores for b in beams])
        # Placeholder: completion token counts are not computed for DVTS.
        results["completion_tokens"].append(-1)

    # TODO: construct and store the tree
    return results