in expanded_checklist/checklist/text_generation.py [0:0]
def replace_word(self, text, word, threshold=5, beam_size=100, candidates=None):
    """Suggest replacements for ``word`` in ``text`` using the masked LM.

    Every whole-word occurrence of ``word`` is replaced by the tokenizer's
    mask token, the masked text is passed to ``self.unmask``, and the
    returned fill-ins are filtered by how far their score falls below the
    score of the original word.

    Args:
        text: The sentence containing ``word``.
        word: The word to replace; matched literally between ``\\b``
            word boundaries.
        threshold: A fill-in is kept only when
            ``original_score - fill_in_score < threshold``.
        beam_size: Forwarded unchanged to ``self.unmask``.
        candidates: Optional list of allowed fill-ins. When given, ``word``
            is prepended before forwarding to ``self.unmask`` (the caller's
            list is not mutated).

    Returns:
        A list of ``(tokens, new_text, score_gap)`` tuples, in the order
        ``self.unmask`` produced them. ``new_text`` is ``text`` with each
        occurrence of ``word`` replaced by ``tokens[0]``. Returns ``[]``
        when ``word`` does not occur in ``text``.

    NOTE(review): this method indexes each ``self.unmask`` result as
    ``(tokens, text, score)`` with ``tokens`` a sequence of strings --
    confirm against ``unmask``.
    """
    pattern = re.compile(r'\b%s\b' % re.escape(word))
    mask_token = self.tokenizer.mask_token
    # A callable replacement is inserted literally; a plain string would be
    # parsed as a regex template (backslash escapes, group references).
    masked = pattern.sub(lambda _match: mask_token, text)
    if masked == text:
        return []
    if candidates is not None:
        candidates = [word] + candidates
    ret = self.unmask(masked, beam_size=beam_size, candidates=candidates)

    excluded = [self.tokenizer.unk_token, word]
    # Fill-ins that contain neither the original word nor the unknown token.
    non_word = [x for x in ret if all(y not in excluded for y in x[0])]
    # Fill-ins made only of the original word / unknown token; the first
    # one supplies the reference score (0 when there is none).
    same_word = [x for x in ret if all(y in excluded for y in x[0])]
    score = same_word[0][-1] if same_word else 0

    kept = []
    for x in non_word:
        gap = score - x[2]
        if gap < threshold:
            kept.append((x, gap))

    try:
        # Substitute into the caller's original text so its casing and
        # spacing are preserved. The proposed token is inserted verbatim,
        # so tokens containing backslashes no longer raise or get mangled.
        return [
            (x[0], pattern.sub(lambda _match, repl=x[0][0]: repl, text), gap)
            for x, gap in kept
        ]
    except (IndexError, TypeError):
        # Best-effort fallback for an empty token list or a non-string
        # token: use the text reported by ``unmask`` instead.
        return [(x[0], x[1], gap) for x, gap in kept]