in expanded_checklist/checklist/text_generation.py [0:0]
def unmask(self, text_with_mask, beam_size=10, candidates=None):
    if self.url is not None:
        # Delegate to a remote unmask server when one is configured.
        params = {'text': text_with_mask, 'beam_size': beam_size, 'candidates': candidates}
        r = requests.post(url='%s/unmask' % self.url, data={'params': json.dumps(params)})
        return [tuple(x) for x in json.loads(r.text)]
    tokenizer = self.tokenizer
    model = self.model
    encoded = np.array(tokenizer.encode(self.prefix_sentence + text_with_mask, add_special_tokens=True))
    cands = []
    if candidates is not None:
        # Consider each candidate both with and without the leading-space marker.
        candidates = candidates + [self.space_prefix + x for x in candidates]
        cands = tokenizer.convert_tokens_to_ids(candidates)
        if self.allow_word_pieces:
            cands_with_space = list(set(cands))
        else:
            cands_with_space = list(set(cands).intersection(self.with_space_set))
    input_ids = torch.tensor(encoded)
    # Beam entries are (token ids chosen so far, cumulative score).
    current_beam = [([], 0)]
    masked = (input_ids == self.tokenizer.mask_token_id).numpy().nonzero()[0]
    # Fill the masked positions left to right with beam search.
    while len(current_beam[0][0]) != masked.shape[0]:
        current_beam = current_beam[:beam_size]
        size = len(current_beam[0][0])
        to_pred = []
        new_beam = []
        # Build one model input per beam entry, with the tokens chosen so far filled in.
        for i, current in enumerate(current_beam):
            idxs = current[0]
            c = encoded.copy()
            c[masked[:len(idxs)]] = idxs
            to_pred.append(c)
        to_pred = torch.tensor(to_pred, device=self.device)
        with torch.no_grad():
            outputs = model(to_pred)[0]
        for i, current in enumerate(current_beam):
            prev = int(to_pred[i][masked[size] - 1])
            forbid = False
            # If word pieces are not allowed and the previous token is not a special
            # character, restrict predictions to tokens that start a new word
            # (i.e. tokens carrying the leading-space marker).
            if not self.allow_word_pieces and prev not in self.special_chars:
                forbid = True
            if candidates is not None:
                # Score only the user-supplied candidates at this mask position.
                cands_to_use = cands_with_space if forbid else cands
                scores = [outputs[i, masked[size], j] for j in cands_to_use]
                new = [(current[0] + [int(x[0])], float(x[1]) + current[1]) for x in zip(cands_to_use, scores)]
            else:
                # Otherwise take the top predictions from the model's full vocabulary.
                if forbid:
                    v, top_preds = torch.topk(outputs[i, masked[size], self.with_space], beam_size + 10)
                    top_preds = self.with_space[top_preds]
                else:
                    v, top_preds = torch.topk(outputs[i, masked[size]], beam_size + 10)
                new = [(current[0] + [int(x[0])], float(x[1]) + current[1]) for x in zip(top_preds, v)]
            new_beam.extend(new)
        current_beam = sorted(new_beam, key=lambda x: x[1], reverse=True)
    ret = []
    cop = encoded.copy()
    for idxs, score in current_beam:
        words = [tokenizer.decode([i]).strip() for i in idxs]
        cop[masked] = idxs
        # Drop the prefix sentence and the special tokens before decoding the full text.
        text = tokenizer.decode(cop[1 + self.prefix_len:-1])
        # Normalize the cumulative score by the number of masked positions.
        ret.append((words, text, score / masked.shape[0]))
    ret = sorted(ret, key=lambda x: x[2], reverse=True)
    return ret
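
# Hedged usage sketch (illustration only, not part of the original file). It assumes
# this method lives on checklist's TextGenerator-style wrapper, that `self.tokenizer`
# and `self.model` come from a Hugging Face masked language model, and that the mask
# placeholder in the input matches `tokenizer.mask_token` (e.g. '[MASK]' for BERT-style
# models, '<mask>' for RoBERTa-style models). The constructor call and the names `tg`
# and `suggestions` below are assumptions made for the example.
#
# tg = TextGenerator()                                   # assumed constructor
# masked_text = 'This is a %s movie.' % tg.tokenizer.mask_token
# suggestions = tg.unmask(masked_text, beam_size=5)
# for words, text, avg_score in suggestions:
#     # Each entry is (predicted words, completed text, score averaged over masks),
#     # matching the tuples built in `ret` above.
#     print(words, text, avg_score)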