in expanded_checklist/checklist/text_generation.py [0:0]
def filter_options(self, texts, word, options, threshold=5):
    """Return the options for ``word`` that fit in every given context.

    Every text that contains ``word`` as a whole word (regex word
    boundaries) has each occurrence replaced by the tokenizer's mask token.
    It is then scored with ``self.unmask``, restricted to ``options`` plus
    ``word`` itself. The first fill made up only of ``word`` and/or the
    tokenizer's unknown token gives the reference score. Every fill that
    contains neither of those tokens is kept for that text when
    ``reference_score - fill_score < threshold``. (Presumably higher
    ``unmask`` scores are better, e.g. log-probabilities -- confirm against
    ``unmask``.)

    Args:
        texts: One string, or an iterable of strings, used as contexts.
        word: The word whose replacements are being filtered.
        options: Candidate replacement words; the caller's sequence is
            not modified.
        threshold: Largest allowed score drop relative to ``word``.

    Returns:
        The fills kept for ``texts[0]``, as ``(tokens, extra, score_drop)``
        tuples in ``unmask`` order. ``extra`` is the second element of the
        ``unmask`` entry, passed through unchanged. Only fills whose first
        token was kept for every text that contains ``word`` are returned.
        Returns ``[]`` when ``texts[0]`` does not contain ``word``, since
        there is then no reference context to report.

    Raises:
        IndexError: If, for some text, ``unmask`` returns no fill made up
            only of ``word`` and/or the unknown token, so no reference
            score exists.
    """
    if isinstance(texts, str):
        texts = [texts]
    else:
        texts = list(texts)

    mask_token = self.tokenizer.mask_token
    # Fills made only of these tokens represent the original word (the unk
    # token presumably stands in for a word the tokenizer cannot emit as a
    # single piece -- TODO confirm against unmask).
    placeholders = (self.tokenizer.unk_token, word)
    # The word itself must be a candidate so its reference score exists;
    # building a new list also leaves the caller's ``options`` untouched.
    candidates = list(options) + [word]
    pattern = re.compile(r'\b%s\b' % re.escape(word))

    surviving = set(candidates)  # first tokens that pass in every context
    reference_fills = []
    for index, text in enumerate(texts):
        # A callable replacement inserts the mask token literally instead
        # of parsing it as a re.sub template (backslashes, group refs).
        masked = pattern.sub(lambda _match: mask_token, text)
        if masked == text:
            continue  # ``word`` does not occur in this context
        fills = self.unmask(masked, beam_size=100, candidates=candidates)

        references = [
            fill for fill in fills
            if all(tok in placeholders for tok in fill[0])
        ]
        if not references:
            raise IndexError(
                'unmask returned no fill consisting only of %r or the '
                'unknown token for text %r' % (word, text))
        reference_score = references[0][-1]

        kept = [
            (fill[0], fill[1], reference_score - fill[2])
            for fill in fills
            if all(tok not in placeholders for tok in fill[0])
            and reference_score - fill[2] < threshold
        ]
        if index == 0:
            reference_fills = kept
        surviving &= {fill[0][0] for fill in kept}
    return [fill for fill in reference_fills if fill[0][0] in surviving]