in optimum/habana/transformers/generation/utils.py
def finalize_beams(initial_ids, beam_trace, model_config, length_penalty):
    """Reconstruct the final padded output sequences from a recorded beam-search trace."""
    # NOTE: `num_selection` (candidates recorded per beam and step) and `beam_scorer` are
    # not defined in this excerpt; they are assumed to come from the surrounding scope.
    beam_trace_idx, beam_trace_scores, beam_trace_indices, beam_trace_tokens = beam_trace
    bs = initial_ids.shape[0]
    num_beams = beam_trace_scores.shape[1] // (num_selection * bs)

    # Keep only the steps that were actually recorded.
    beam_trace_idx = beam_trace_idx.item()
    beam_trace_scores = beam_trace_scores[:beam_trace_idx, :]
    beam_trace_indices = beam_trace_indices[:beam_trace_idx, :]
    beam_trace_tokens = beam_trace_tokens[:beam_trace_idx, :]

    # (score, parent_beam, token_id, is_finished)
    root = (float("-inf"), None, None, False)
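    # Each beam tuple points back to its parent, so the trace forms a tree rooted at
    # this sentinel; resolve_beam() walks the parent pointers to recover the tokens.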

    def resolve_beam(beam):
        rest = []
        while beam != root:
            score, prev, tok, is_finished = beam
            rest.append(tok)
            beam = prev
        rest.reverse()
        return rest

    prev_beams = [[root] * num_beams] * bs
    best = [[] for _ in range(bs)]

    def beam_score(beam):
        # Sort key: rank finished beams above unfinished ones, then by length-penalized score.
        return (beam[3], beam[0])

    for step, (scores, indices, tokens) in enumerate(
        zip(beam_trace_scores, beam_trace_indices, beam_trace_tokens)
    ):
        cur_beams = [[] for _ in range(bs)]
        for idx, (s, i, t) in enumerate(zip(scores, indices, tokens)):
            batch = idx // (num_beams * num_selection)
            idx = idx % (num_beams * num_selection)
            b_len = 1 + step
            # Apply the length penalty so scores are comparable across different lengths.
            b_score = s.item() / (b_len**length_penalty)
            b_tok = t.item()
            is_finished = b_tok == model_config.eos_token_id
            if len(cur_beams[batch]) >= num_beams:
                continue
            beam = (b_score, prev_beams[batch][i], b_tok, is_finished)
            if not is_finished:
                cur_beams[batch].append(beam)
            # A beam becomes a candidate result when it emits EOS or the trace ends.
            if is_finished or (step + 1 == beam_trace_idx):
                if len(best[batch]) < num_beams:
                    best[batch].append(beam)
                    best[batch] = sorted(best[batch], key=lambda x: beam_score(x))
                elif beam_score(best[batch][0]) < beam_score(beam):
                    # Replace the current worst kept hypothesis and re-sort.
                    best[batch][0] = beam
                    best[batch] = sorted(best[batch], key=lambda x: beam_score(x))
        prev_beams = cur_beams

    def expand_if_needed(tensor, new_size, value, dim=-1):
        # Right-pad `tensor` along `dim` with `value` until it reaches `new_size`.
        orig_len = tensor.shape[dim]
        padding_len = new_size - orig_len
        import torch.nn.functional as F

        if padding_len > 0:
            if dim == -1:
                return F.pad(tensor, (0, padding_len), value=value)
            elif dim == -2:
                return F.pad(tensor, (0, 0, 0, padding_len), value=value)
            else:
                assert False, f"Unsupported dim value: {dim}"
        return tensor

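    # From here on: for each batch element, take the top hypotheses, resolve their
    # token ids via the parent pointers, prepend the original input ids, right-pad
    # everything to a common length with pad_token_id, and stack into `input_ids`.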
    results = []
    for i, beam_hyp in enumerate(best):
        sorted_hyps = sorted(beam_hyp, key=lambda x: beam_score(x))
        res = []
        for j in range(beam_scorer.num_beam_hyps_to_keep):
            # Pop the highest-scoring remaining hypothesis (the list is sorted ascending).
            best_hyp_tuple = sorted_hyps.pop()
            resolve = resolve_beam(best_hyp_tuple)
            res.append(torch.cat((initial_ids[i], torch.tensor(resolve))))
        results.append(res)

    max_length = max([n.shape[-1] for m in results for n in m])
    return_res = []
    for i, res in enumerate(results):
        for j in range(beam_scorer.num_beam_hyps_to_keep):
            return_res.append(expand_if_needed(res[j], max_length, model_config.pad_token_id))
    input_ids = torch.stack(return_res)
    return input_ids
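
For illustration, here is a minimal, self-contained sketch of the same finalization idea (not taken from utils.py; the token ids, scores and pad id below are made up): each beam is a (score, parent, token, finished) tuple, the token sequence is recovered by walking parent pointers back to a sentinel root, and the hypotheses are right-padded to a common length before being stacked.

import torch
import torch.nn.functional as F

root = (float("-inf"), None, None, False)

# Two toy hypotheses sharing a common prefix: [5, 7, 2] and [5, 2].
a1 = (-0.1, root, 5, False)
a2 = (-0.3, a1, 7, False)
a3 = (-0.4, a2, 2, True)   # first hypothesis, ends on the (made-up) EOS id 2
b2 = (-0.6, a1, 2, True)   # second hypothesis, branches off after the first token

def resolve(beam):
    # Walk parent pointers back to the sentinel root, then reverse the collected tokens.
    toks = []
    while beam != root:
        _, parent, tok, _ = beam
        toks.append(tok)
        beam = parent
    return list(reversed(toks))

hyps = [torch.tensor(resolve(b)) for b in (a3, b2)]
max_length = max(h.shape[-1] for h in hyps)
pad_id = 0  # hypothetical pad_token_id
padded = [F.pad(h, (0, max_length - h.shape[-1]), value=pad_id) for h in hyps]
print(torch.stack(padded))
# tensor([[5, 7, 2],
#         [5, 2, 0]])

The full function above additionally applies the length penalty when scoring each beam and keeps only the top num_beam_hyps_to_keep hypotheses per batch element.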