in src/sal/search/beam_search.py [0:0]
def _beam_search(batch_of_prompts, config: Config, llm: LLM, prm: PRM) -> list[Beam]:
    """Run step-wise beam search guided by a process reward model.

    For up to ``config.num_iterations`` rounds, every active beam is extended
    by one generation step (steps are delimited by the ``"\\n\\n"`` stop
    string, except in the final round where no stop string is set), the
    resulting texts are scored with ``prm.score``, finished beams are moved to
    the completed pool, and all remaining beams except the top
    ``config.n // config.beam_width`` (by aggregated score) are marked pruned.
    Survivors are duplicated via ``copy.deepcopy`` at the start of the next
    round so that exactly ``config.n`` beams are generated from.

    NOTE(review): beams for *all* prompts are pooled into one list that is then
    cut/padded to ``config.n`` entries, so this looks like it expects a single
    prompt per call -- confirm against the caller.

    Args:
        batch_of_prompts: Iterable of prompts; ``config.n`` beams are created
            for each one.
        config: Search settings. Reads ``n``, ``beam_width``,
            ``num_iterations``, ``lookahead``, ``temperature``, ``top_p``,
            ``max_tokens``, ``system_prompt``, ``custom_chat_template``,
            ``agg_strategy``, ``filter_duplicates`` and ``sort_completed``.
        llm: Generator; supplies the tokenizer and is passed to
            ``generate_k_steps``.
        prm: Reward model; ``prm.score(prompts, completions)`` is called once
            per iteration.

    Returns:
        ``config.n`` completed beams. If ``config.sort_completed`` is set they
        are the highest-scoring ones in descending aggregated-score order;
        otherwise the first ``config.n`` in completion order. When fewer than
        ``config.n`` beams completed, deep copies are appended to pad the list.

    Raises:
        ValueError: If there are no active beams to duplicate (e.g. an empty
            prompt batch with ``config.n > 0``), if the active beam count
            cannot be brought to ``config.n``, or if no beam completed while
            ``config.n > 0``.
    """
    sampling_params = SamplingParams(
        temperature=config.temperature,
        max_tokens=config.max_tokens,
        top_p=config.top_p,
        stop=["\n\n"],
        include_stop_str_in_output=True,
        n=1,
    )

    # One fresh Beam (with its own mutable lists) per (prompt, index) pair.
    beams: list[Beam] = [
        Beam(
            prompt=prompt,
            index=beam_index,
            current_text="",
            next_texts=None,
            lookahead_texts=None,
            pruned=False,
            completed=False,  # Flag used to track completion
            stop_reasons=None,
            history=[],
            best_scores=[],
            all_scores=[],
            previous_text=None,
            completion_tokens=0,
        )
        for prompt in batch_of_prompts
        for beam_index in range(config.n)
    ]

    completed_beams: list[Beam] = []
    active_beams = beams

    for i in tqdm(range(config.num_iterations), desc="Beam search iterations"):
        is_last_iteration = i == config.num_iterations - 1

        # Drop beams pruned in the previous iteration (no-op on the first one).
        active_beams = [b for b in active_beams if not b.pruned]

        # Duplicate active beams to ensure that we have config.n beams per iteration
        if len(active_beams) != config.n:
            if not active_beams:
                # Without this guard the division below raises ZeroDivisionError.
                raise ValueError(
                    f"Expected {config.n} active beams, but got 0; "
                    "there is nothing to duplicate"
                )
            repeats = (config.n // len(active_beams)) + 1
            logger.debug(
                f"Extending active_beams with {repeats} repetitions to reach size {config.n}"
            )
            # Deep copies so duplicated beams evolve independently.
            active_beams = [
                copy.deepcopy(b) for b in (active_beams * repeats)[: config.n]
            ]
            if len(active_beams) != config.n:
                raise ValueError(
                    f"Expected {config.n} active beams, but got {len(active_beams)}"
                )

        if is_last_iteration:
            # Last iteration, generate to EOS (no stop string)
            sampling_params = SamplingParams(
                temperature=config.temperature,
                max_tokens=config.max_tokens,
                top_p=config.top_p,
                n=1,
            )

        convs = [
            build_conv(b.prompt, b.current_text, config.system_prompt)
            for b in active_beams
        ]
        # First iteration opens a new generation turn; later ones continue the
        # partially written final message.
        continue_final_message = i > 0
        add_generation_prompt = i == 0

        tokenizer = llm.get_tokenizer()
        if config.custom_chat_template is not None:
            tokenizer.chat_template = config.custom_chat_template
        templated_convs = tokenizer.apply_chat_template(
            convs,
            add_generation_prompt=add_generation_prompt,
            continue_final_message=continue_final_message,
            tokenize=False,
        )
        lookahead = 0 if is_last_iteration else config.lookahead
        gen_results = generate_k_steps(
            templated_convs, lookahead, llm, sampling_params, 1
        )

        prompts, completions = [], []
        for beam, gen_result in zip(active_beams, gen_results, strict=True):
            beam.next_texts = gen_result.next_texts
            beam.stop_reasons = gen_result.stop_reasons
            beam.lookahead_texts = gen_result.lookahead_texts
            beam.completion_tokens += gen_result.completion_tokens
            beam.current_text += beam.next_texts[0]
            beam.history.append(beam.next_texts[0])

            if (
                beam.stop_reasons[0] in ("EOS", "length")
                or beam.next_texts[0] == ""
            ):
                beam.completed = True
                completed_beams.append(beam)
            # Completed beams are still scored so that all_scores is populated.
            prompts.append(beam.prompt)
            completions.append([beam.current_text])

        scores = prm.score(prompts, completions)

        agg_scores = [
            [aggregate_scores(s, config.agg_strategy) for s in score]
            for score in scores
        ]

        for beam, score in zip(active_beams, scores, strict=True):
            beam.all_scores = score[0]

        # Keep only the beams (and their scores) that have not completed.
        agg_scores = [
            agg_scores[idx] for idx, b in enumerate(active_beams) if not b.completed
        ]
        active_beams = [b for b in active_beams if not b.completed]

        # Early stopping if all beams are completed
        if not active_beams:
            break

        # Filter duplicate active beams
        if config.filter_duplicates:
            # Map each distinct text to the index of its first occurrence;
            # dict insertion order keeps the original beam order.
            first_index_by_text: dict = {}
            for pos, b in enumerate(active_beams):
                if b.current_text not in first_index_by_text:
                    first_index_by_text[b.current_text] = pos
            kept_positions = list(first_index_by_text.values())
            active_beams = [active_beams[pos] for pos in kept_positions]
            agg_scores = [agg_scores[pos] for pos in kept_positions]

        # Get indices for top (config.n / config.beam_width) completions.
        # Note: if config.n // config.beam_width is 0 the slice is [-0:], which
        # selects every index, so no beam is pruned in that case.
        top_indices = np.argsort(np.array(agg_scores).flatten())[
            -(config.n // config.beam_width) :
        ]
        # Set lookup instead of scanning the array once per beam.
        top_index_set = set(top_indices.tolist())

        for idx, beam in enumerate(active_beams):
            if idx not in top_index_set:
                beam.pruned = True

    # Filter completed beams for those with top config.n scores
    if config.sort_completed:
        completed_beams = sorted(
            completed_beams,
            key=lambda b: aggregate_scores(b.all_scores, config.agg_strategy),
            reverse=True,
        )[: config.n]
    else:
        completed_beams = completed_beams[: config.n]

    if len(completed_beams) != config.n:
        if not completed_beams:
            # Without this guard the division below raises ZeroDivisionError.
            raise ValueError(
                f"Expected {config.n} completed beams, but got 0; "
                "there is nothing to duplicate"
            )
        # If we don't have enough completed_beams, duplicate until we reach config.n
        repeats = (config.n // len(completed_beams)) + 1
        logger.debug(
            f"Extending completed_beams with {repeats} repetitions to reach size {config.n}"
        )
        completed_beams = [
            copy.deepcopy(b) for b in (completed_beams * repeats)[: config.n]
        ]

    return completed_beams