def finalize_beams()

in optimum/habana/transformers/generation/utils.py
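
Given the prompt ids and the beam trace recorded during generation (step count plus per-step candidate scores, parent-beam indices and tokens), this helper replays the trace to rebuild each hypothesis, keeps the best finished beams per batch element, prepends the prompt ids, pads all hypotheses to a common length and returns them stacked as a single `input_ids` tensor.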


        def finalize_beams(initial_ids, beam_trace, model_config, length_penalty):
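            # Unpack the recorded trace: the number of decoding steps taken, plus the
            # per-step candidate scores, parent-beam indices and tokens. `num_selection`
            # (candidates recorded per beam) and `beam_scorer` are free variables taken
            # from the surrounding scope.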
            beam_trace_idx, beam_trace_scores, beam_trace_indices, beam_trace_tokens = beam_trace
            bs = initial_ids.shape[0]
            num_beams = beam_trace_scores.shape[1] // (num_selection * bs)

            # Keep only the steps that were actually recorded.
            beam_trace_idx = beam_trace_idx.item()
            beam_trace_scores = beam_trace_scores[:beam_trace_idx, :]
            beam_trace_indices = beam_trace_indices[:beam_trace_idx, :]
            beam_trace_tokens = beam_trace_tokens[:beam_trace_idx, :]

            # (score, parent_beam, token_id, is_finished)
            root = (float("-inf"), None, None, False)

            # Follow parent pointers back to the root sentinel to recover this beam's
            # tokens in generation order.
            def resolve_beam(beam):
                rest = []
                while beam != root:
                    score, prev, tok, is_finished = beam
                    rest.append(tok)
                    beam = prev
                rest.reverse()
                return rest

            # Every beam starts from the root sentinel; `best` collects the finished
            # (or final-step) hypotheses for each batch element.
            prev_beams = [[root] * num_beams] * bs
            best = [[] for _ in range(bs)]

            def beam_score(beam):
                # Sort key: finished beams rank above unfinished ones, then by score.
                return (beam[3], beam[0])

            for step, (scores, indices, tokens) in enumerate(
                zip(beam_trace_scores, beam_trace_indices, beam_trace_tokens)
            ):
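                # Candidates kept live for the next step, one list per batch element.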
                cur_beams = [[] for _ in range(bs)]
                for idx, (s, i, t) in enumerate(zip(scores, indices, tokens)):
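                    # Recover the batch element this flat candidate index belongs to.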
                    batch = idx // (num_beams * num_selection)
                    idx = idx % (num_beams * num_selection)
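                    # Length-penalised score: cumulative score / (length so far ** length_penalty).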
                    b_len = 1 + step
                    b_score = s.item() / (b_len**length_penalty)
                    b_tok = t.item()
                    is_finished = b_tok == model_config.eos_token_id
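                    # Keep at most num_beams live candidates per batch element at each step.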
                    if len(cur_beams[batch]) >= num_beams:
                        continue
                    beam = (b_score, prev_beams[batch][i], b_tok, is_finished)
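                    # Unfinished beams stay live for the next step; finished beams (or any
                    # beam at the last recorded step) compete for a slot in `best`.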
                    if not is_finished:
                        cur_beams[batch].append(beam)
                    if is_finished or (step + 1 == beam_trace_idx):
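                        # `best[batch]` is kept sorted worst-first, so index 0 is the
                        # hypothesis to compare against and replace.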
                        if len(best[batch]) < num_beams:
                            best[batch].append(beam)
                            best[batch] = sorted(best[batch], key=lambda x: beam_score(x))
                        elif beam_score(best[batch][0]) < beam_score(beam):
                            best[batch][0] = beam
                            best[batch] = sorted(best[batch], key=lambda x: beam_score(x))
                prev_beams = cur_beams

            def expand_if_needed(tensor, new_size, value, dim=-1):
                # Right-pad `tensor` along `dim` with `value` up to `new_size`.
                import torch.nn.functional as F

                orig_len = tensor.shape[dim]
                padding_len = new_size - orig_len

                if padding_len > 0:
                    if dim == -1:
                        return F.pad(tensor, (0, padding_len), value=value)
                    elif dim == -2:
                        return F.pad(tensor, (0, 0, 0, padding_len), value=value)
                    else:
                        assert False, f"Unsupported dim value: {dim}"
                return tensor

            results = []
            for i, beam_hyp in enumerate(best):
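                # Hypotheses are sorted ascending, so pop() takes the best one first.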
                sorted_hyps = sorted(beam_hyp, key=lambda x: beam_score(x))
                res = []
                for j in range(beam_scorer.num_beam_hyps_to_keep):
                    best_hyp_tuple = sorted_hyps.pop()
                    resolve = resolve_beam(best_hyp_tuple)
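                    # Prepend the original prompt ids to the resolved tokens.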
                    res.append(torch.cat((initial_ids[i], torch.tensor(resolve))))
                results.append(res)

            # Pad every hypothesis to the longest one and stack into a single tensor.
            max_length = max([n.shape[-1] for m in results for n in m])
            return_res = []
            for i, res in enumerate(results):
                for j in range(beam_scorer.num_beam_hyps_to_keep):
                    return_res.append(expand_if_needed(res[j], max_length, model_config.pad_token_id))
            input_ids = torch.stack(return_res)
            return input_ids
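
For reference, here is a minimal self-contained sketch (not the library code; the scores and token ids are made up) of the back-pointer reconstruction that resolve_beam performs: each beam is a (score, parent, token, finished) tuple linked back to a root sentinel, and the token sequence is recovered by walking the parents.

    root = (float("-inf"), None, None, False)
    step1 = (-0.7, root, 42, False)   # token 42 emitted at step 1
    step2 = (-1.3, step1, 7, True)    # token 7 (an EOS-like token) emitted at step 2

    def resolve(beam):
        tokens = []
        while beam != root:
            _, parent, tok, _ = beam
            tokens.append(tok)
            beam = parent
        return list(reversed(tokens))

    print(resolve(step2))  # -> [42, 7]

Similarly, a small sketch of the right-padding that expand_if_needed applies with F.pad for dim=-1 (the token ids, target length and pad value 0 are illustrative):

    import torch
    import torch.nn.functional as F

    seq = torch.tensor([101, 7592, 2088, 102])
    padded = F.pad(seq, (0, 6 - seq.shape[-1]), value=0)
    print(padded)  # tensor([101, 7592, 2088, 102, 0, 0])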