in torchbenchmark/models/speech_transformer/speech_transformer/transformer/decoder.py [0:0]
def recognize_beam(self, encoder_outputs, char_list, args):
"""Beam search, decode one utterence now.
Args:
encoder_outputs: T x H
char_list: list of character
args: args.beam
Returns:
nbest_hyps:
"""
# search params
beam = args.beam_size
nbest = args.nbest
if args.decode_max_len == 0:
maxlen = encoder_outputs.size(0)
else:
maxlen = args.decode_max_len
    encoder_outputs = encoder_outputs.unsqueeze(0)  # T x H -> 1 x T x H (batch of one)
# prepare sos
ys = torch.ones(1, 1).fill_(self.sos_id).type_as(encoder_outputs).long()
    # yseq: 1 x T tensor of token ids decoded so far
hyp = {'score': 0.0, 'yseq': ys}
hyps = [hyp]
ended_hyps = []
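    # Main loop: at each step, expand every active hypothesis by its `beam`
    # most likely next tokens, then keep only the `beam` best candidates overall.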
for i in range(maxlen):
hyps_best_kept = []
for hyp in hyps:
            ys = hyp['yseq']  # 1 x (i+1), the partial output sequence so far
# -- Prepare masks
            non_pad_mask = torch.ones_like(ys).float().unsqueeze(-1)  # 1 x (i+1) x 1
            slf_attn_mask = get_subsequent_mask(ys)  # causal mask: attend only to current and earlier positions
# -- Forward
dec_output = self.dropout(
self.tgt_word_emb(ys) * self.x_logit_scale +
self.positional_encoding(ys))
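            # Note: the entire prefix is re-embedded and re-decoded at every
            # step; this implementation keeps no incremental decoding cache.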
for dec_layer in self.layer_stack:
dec_output, _, _ = dec_layer(
dec_output, encoder_outputs,
non_pad_mask=non_pad_mask,
slf_attn_mask=slf_attn_mask,
dec_enc_attn_mask=None)
            seq_logit = self.tgt_word_prj(dec_output[:, -1])  # next-token logits from the last position
local_scores = F.log_softmax(seq_logit, dim=1)
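            # log-probabilities are additive, so a hypothesis score is the
            # running sum of its chosen tokens' log-probs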
# topk scores
local_best_scores, local_best_ids = torch.topk(
local_scores, beam, dim=1)
for j in range(beam):
new_hyp = {}
new_hyp['score'] = hyp['score'] + local_best_scores[0, j]
                new_hyp['yseq'] = torch.ones(1, ys.size(1) + 1).type_as(encoder_outputs).long()
                new_hyp['yseq'][:, :ys.size(1)] = hyp['yseq']
                new_hyp['yseq'][:, ys.size(1)] = int(local_best_ids[0, j])
                # up to beam * beam candidates accumulate here before pruning
                hyps_best_kept.append(new_hyp)
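            # prune: keep only the `beam` highest-scoring candidates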
hyps_best_kept = sorted(hyps_best_kept,
key=lambda x: x['score'],
reverse=True)[:beam]
# end for hyp in hyps
hyps = hyps_best_kept
        # append <eos> on the final step so that at least some hypotheses end
if i == maxlen - 1:
for hyp in hyps:
hyp['yseq'] = torch.cat([hyp['yseq'],
torch.ones(1, 1).fill_(self.eos_id).type_as(encoder_outputs).long()], dim=1)
        # move ended hypotheses to the final list and drop them from the active
        # set (note: the number of active hyps can fall below beam here)
remained_hyps = []
for hyp in hyps:
if hyp['yseq'][0, -1] == self.eos_id:
ended_hyps.append(hyp)
else:
remained_hyps.append(hyp)
hyps = remained_hyps
        if len(hyps) == 0:
            print('no hypothesis. Finish decoding.')
            break
# end for i in range(maxlen)
    nbest_hyps = sorted(ended_hyps,
                        key=lambda x: x['score'],
                        reverse=True)[:min(len(ended_hyps), nbest)]
    # compatible with LAS implementation
for hyp in nbest_hyps:
hyp['yseq'] = hyp['yseq'][0].cpu().numpy().tolist()
return nbest_hyps
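
# --- Illustrative usage sketch (not part of the original file) ---
# A minimal, hedged example of how recognize_beam might be driven end to end.
# The Namespace fields mirror the attributes read above (beam_size, nbest,
# decode_max_len); `decoder`, `encoder_outputs`, and `char_list` are assumed
# to come from a trained speech_transformer model and are hypothetical names.
def _example_decode(decoder, encoder_outputs, char_list):
    import argparse
    args = argparse.Namespace(beam_size=5, nbest=1, decode_max_len=0)
    nbest_hyps = decoder.recognize_beam(encoder_outputs, char_list, args)
    best = nbest_hyps[0]
    # 'yseq' is returned as a plain list of ids; drop <sos>/<eos> at the ends
    return ''.join(char_list[idx] for idx in best['yseq'][1:-1])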