in submission_code/tools.py [0:0]
def beam_search(src, model, pad_token, bos_id, end_token, max_len=10, k=5):
    """Decode ``src`` with beam search and return the surviving hypotheses.

    Every hypothesis starts with ``bos_id``. At each step, every unfinished
    hypothesis is extended with its ``k`` most likely next tokens, and the
    ``k`` best hypotheses overall (by summed log-probability, without length
    normalisation) are kept. A hypothesis whose last token is ``end_token``
    is finished and carried over unchanged.

    Args:
        src: Tensor of source token ids; it is reshaped to a one-row batch.
        model: Module called with ``input_ids``, ``attention_mask``,
            ``decoder_input_ids`` and, from the second call on,
            ``encoder_outputs``. Its output must be indexable, with item 0
            holding logits indexed as ``[batch, position, vocab]``.
        pad_token: Token id excluded from the source attention mask.
        bos_id: Token id that starts every hypothesis.
        end_token: Token id that marks a hypothesis as finished.
        max_len: Maximum number of decoding steps.
        k: Beam width; must be at least 1.

    Returns:
        A list of ``(tokens, score)`` tuples sorted by descending score,
        where ``tokens`` is a list of token ids beginning with ``bos_id`` and
        ``score`` is the summed log-probability of the generated tokens.
        At most ``k`` hypotheses are returned.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k!r}")

    device = next(model.parameters()).device
    src = src.view(1, -1).to(device)
    src_mask = (src != pad_token).to(device)

    outputs = None
    beam = [([bos_id], 0)]
    for _ in range(max_len):
        # Nothing left to expand: further steps could not change the beam.
        if all(tokens[-1] == end_token for tokens, _score in beam):
            break

        candidates = []
        for tokens, score in beam:
            if tokens[-1] == end_token:
                candidates.append((tokens, score))
                continue

            model_kwargs = {
                "input_ids": src,
                "attention_mask": src_mask,
                "decoder_input_ids": (
                    torch.tensor(tokens).view(1, -1).long().to(device)
                ),
            }
            if outputs is not None:
                # NOTE(review): items 1 and -1 of the previous output are fed
                # back as encoder outputs, presumably so the encoder is not
                # re-run; which tensors these are depends on the model's
                # output layout -- confirm against the model in use.
                model_kwargs["encoder_outputs"] = (outputs[1], outputs[-1])

            # Inference only: without no_grad the .numpy() call below fails
            # for tensors that require grad, and autograd graphs pile up.
            with torch.no_grad():
                outputs = model(**model_kwargs)
                last_logits = outputs[0].cpu()[0, -1, :]
                log_probs = torch.log_softmax(last_logits, dim=-1).numpy()

            # argpartition needs kth < size, so clamp to the vocabulary size.
            width = min(k, log_probs.shape[0])
            top_tokens = np.argpartition(-log_probs, width - 1)[:width]
            for tok in top_tokens:
                # Store plain Python scalars so hypotheses do not mix int
                # with numpy.int64 (and float with numpy.float32).
                candidates.append(
                    (tokens + [int(tok)], score + float(log_probs[tok]))
                )

        # Stable descending sort; slicing copes with fewer than k candidates.
        beam = sorted(candidates, key=lambda item: item[1], reverse=True)[:k]

    return beam