def beam_search()

in eval/PER_src/seq_alignment.py [0:0]


def beam_search(score_preds, nKeep, blankLabel):
    """Return the ``nKeep`` best label sequences found by a prefix beam search.

    The search tracks, for every candidate prefix, the score of the
    alignments ending in the blank label and of those ending in a
    non-blank label (the bookkeeping used by CTC-style prefix beam search).

    Args:
        score_preds: 2-D array-like exposing ``.shape == (T, P)`` and
            ``[t, c]`` indexing, with T time steps and P labels. Values are
            multiplied together, so they are presumably linear-domain
            probabilities rather than log-probabilities (assumption; confirm
            against the caller).
        nKeep: Number of prefixes kept after each time step, and the maximum
            number of results returned.
        blankLabel: Index (in ``range(P)``) of the blank label.

    Returns:
        A list of ``(score, labels)`` tuples sorted by decreasing score,
        where ``labels`` is a list of ints without blanks. When
        ``score_preds`` has no time steps, the empty sequence with score 1
        is returned (truncated to ``nKeep``).
    """
    T, P = score_preds.shape

    # A prefix is encoded as ",c1,c2,..." ("" is the empty prefix). The
    # string form is kept so that ties between equal scores are broken
    # exactly as before by the tuple sort below.
    beams = [""]
    # Score of each surviving prefix at the previous step, split by whether
    # the alignment ended in a blank (pb) or in a non-blank label (pnb).
    pb_prev = {"": 1}
    pnb_prev = {"": 0}

    # Seeded with the empty prefix so an input with zero time steps yields a
    # well-defined result instead of referencing an unbound name.
    allPreds = [(1, "")]

    labels = [c for c in range(P) if c != blankLabel]

    for t in range(T):
        pb_t = {}
        pnb_t = {}
        nextBeams = set()

        for b in beams:
            pb_t.setdefault(b, 0)
            pnb_t.setdefault(b, 0)

            # Last emitted label of the prefix; parsed once per beam.
            last = int(b.rsplit(",", 1)[-1]) if b else None

            # Prefix unchanged: either the last label is repeated without an
            # intervening blank, or a blank is emitted.
            if last is not None:
                pnb_t[b] += pnb_prev[b] * score_preds[t, last]
            pb_t[b] = (pnb_prev[b] + pb_prev[b]) * score_preds[t, blankLabel]
            nextBeams.add(b)

            # Prefix extended by one non-blank label.
            for c in labels:
                b_ = b + "," + str(c)
                pb_t.setdefault(b_, 0)
                pnb_t.setdefault(b_, 0)

                if c == last:
                    # A repeated label only yields a new symbol when a blank
                    # separated the two occurrences.
                    pnb_t[b_] += pb_prev[b] * score_preds[t, c]
                else:
                    pnb_t[b_] += (pb_prev[b] + pnb_prev[b]) * score_preds[t, c]
                nextBeams.add(b_)

        allPreds = [(pb_t[b] + pnb_t[b], b) for b in nextBeams]
        allPreds.sort(reverse=True)

        beams = [prefix for _, prefix in allPreds[:nKeep]]
        # The dicts are rebuilt from scratch on every step, so rebinding is
        # sufficient; no copy is needed.
        pb_prev, pnb_prev = pb_t, pnb_t

    return [
        (score, [int(y) for y in prefix.split(",") if y])
        for score, prefix in allPreds[:nKeep]
    ]