in eval/PER_src/seq_alignment.py [0:0]
def beam_search(score_preds, nKeep, blankLabel):
    """Run a CTC-style prefix beam search over a matrix of per-frame label scores.

    Each hypothesis (prefix) is stored as a string of comma-prefixed label ids.
    For example, ``",3,7"`` is the label sequence ``[3, 7]`` and ``""`` is the
    empty prefix. For every prefix two scores are tracked:

    * ``pb``: mass of paths that end in ``blankLabel``.
    * ``pnb``: mass of paths that end in a non-blank label.

    Keeping them separate lets a repeated label either collapse onto the same
    prefix or start a new symbol, depending on whether a blank separated them.

    Scores are combined by multiplication and addition. ``score_preds`` is
    therefore presumably meant to hold per-frame probabilities, not
    log-probabilities (TODO confirm against callers).

    NOTE(review): returned scores are products over all frames and can underflow
    to 0.0 for long inputs. A log-space variant would avoid this.

    Args:
        score_preds: 2-D array indexed as ``score_preds[t, c]``, with shape
            ``(T, P)``: ``T`` frames and ``P`` labels, blank included.
        nKeep: Beam width. This is the number of prefixes kept after every
            frame and the maximum number of hypotheses returned.
        blankLabel: Column index of the blank label.

    Returns:
        A list of at most ``nKeep`` ``(score, labels)`` tuples sorted by
        descending score. Ties are broken by descending prefix string.
        ``labels`` is the decoded label sequence as a list of ints. When
        ``T == 0`` the result is ``[(1, [])]`` (or ``[]`` if ``nKeep <= 0``).
    """
    T, P = score_preds.shape

    # Blank / non-blank ending scores of each prefix at the previous frame.
    pb_prev = {"": 1}
    pnb_prev = {"": 0}
    beams = [""]
    # Seeded so that an input with zero frames returns the empty prefix
    # instead of failing on an unbound name after the loop.
    allPreds = [(1, "")]

    for t in range(T):
        pb_t = {}
        pnb_t = {}
        nextBeams = set()
        p_blank = score_preds[t, blankLabel]

        for b in beams:
            # b may already have been created as an extension of an earlier
            # beam in this frame; keep whatever it accumulated.
            pb_t.setdefault(b, 0)
            pnb_t.setdefault(b, 0)
            last = int(b.rsplit(",", 1)[-1]) if b else None

            if last is not None:
                # Repeating the last label without a blank in between
                # collapses onto the same prefix.
                pnb_t[b] += pnb_prev[b] * score_preds[t, last]
            # Emitting a blank keeps the prefix unchanged, whatever the path
            # ended in. Extensions never add blank mass, so pb_t[b] is 0 here.
            pb_t[b] += (pnb_prev[b] + pb_prev[b]) * p_blank
            nextBeams.add(b)

            for c in range(P):
                if c == blankLabel:
                    continue
                b_ = b + "," + str(c)
                pb_t.setdefault(b_, 0)
                pnb_t.setdefault(b_, 0)
                if c == last:
                    # The same label again only starts a new symbol if the
                    # previous path ended in a blank.
                    pnb_t[b_] += pb_prev[b] * score_preds[t, c]
                else:
                    pnb_t[b_] += (pb_prev[b] + pnb_prev[b]) * score_preds[t, c]
                nextBeams.add(b_)

        allPreds = sorted(
            ((pb_t[b] + pnb_t[b], b) for b in nextBeams), reverse=True
        )
        beams = [b for _, b in allPreds[:nKeep]]
        # pb_t / pnb_t are fresh dicts every frame and are never mutated after
        # this point, so rebinding is enough; no copy is needed.
        pb_prev, pnb_prev = pb_t, pnb_t

    return [
        (score, [int(label) for label in prefix.split(",") if label])
        for score, prefix in allPreds[:nKeep]
    ]