in submission_code/best_ctx.py [0:0]
def predict(self, text, util, beam_width):
    """Generate command candidates for ``text`` with beam search.

    The context string stored in ``self.ctx`` under ``util`` (an empty
    string when there is no entry) is appended to ``text`` after a single
    space. The result is encoded with ``self.text_tokenizer``, cut to at
    most ``self.config.max_src_len`` ids and wrapped in ``bos_id`` /
    ``eos_id`` before being handed to ``beam_search``.

    Args:
        text: Input text to translate.
        util: Key used to look up extra context in ``self.ctx``.
        beam_width: Passed to ``beam_search`` as ``k``.

    Returns:
        A list of ``(decoded_command, proba)`` tuples, one per item
        yielded by ``beam_search``; ``proba`` is forwarded unchanged
        (presumably the hypothesis score -- confirm against
        ``beam_search``).
    """
    # Append the per-utility context; the separator space is always added,
    # even when no context is registered for ``util``.
    query = text + ' ' + self.ctx.get(util, '')

    # Truncation happens before the bos/eos markers are attached, so the
    # final sequence can be up to ``max_src_len + 2`` ids long.
    src_ids = self.text_tokenizer.encode(query)
    clipped = src_ids[:self.config.max_src_len]
    tokens = torch.tensor([bos_id] + clipped + [eos_id]).long()

    # Inference only: no autograd bookkeeping needed during the search.
    with torch.no_grad():
        hypotheses = beam_search(
            tokens,
            self.model.tr,
            pad_id,
            bos_id,
            eos_id,
            max_len=self.config.max_tgt_len,
            k=beam_width,
        )

    results = []
    for ids, proba in hypotheses:
        # Convert every id to a plain int before decoding.
        command = self.cmd_tokenizer.decode([int(tok) for tok in ids])
        results.append((command, proba))
    return results