def filter_options()

in expanded_checklist/checklist/text_generation.py [0:0]


    def filter_options(self, texts, word, options, threshold=5):
        """Keep the replacement options for ``word`` that score well in every text.

        Each text that contains ``word`` as a whole word has those occurrences
        replaced by the tokenizer's mask token, and ``self.unmask`` is asked to
        score ``options`` plus ``word`` itself for the masked text. An option
        survives for a text when the original word's score minus the option's
        score is below ``threshold``. Texts that do not contain ``word`` are
        skipped and place no constraint on the result.

        Parameters
        ----------
        texts : list or str
            One text or a list of texts. Anything that is not a list is
            treated as a single text.
        word : str
            The word to be replaced; matched with ``\\b`` word boundaries.
        options : list
            Candidate replacements. The list is not mutated.
        threshold : number, optional
            Maximum allowed value of ``score(word) - score(option)``.

        Returns
        -------
        list
            The ``(x[0], x[1], score_difference)`` tuples computed for the
            first text, restricted to options that passed the threshold in
            every text containing ``word``. Empty if the first text does not
            contain ``word``.

        Raises
        ------
        IndexError
            If ``self.unmask`` returns no entry for the original word, so no
            reference score is available.
        """
        if not isinstance(texts, list):
            texts = [texts]
        candidates = options + [word]
        mask_token = self.tokenizer.mask_token
        # Tokens that mark an entry as "the original word" rather than a
        # genuine replacement.
        excluded = (self.tokenizer.unk_token, word)
        pattern = re.compile(r'\b%s\b' % re.escape(word))
        in_all = set(candidates)
        orig_ret = []
        for idx, text in enumerate(texts):
            # A callable replacement inserts the mask token literally; a plain
            # string would be processed as a template (backslash escapes).
            masked = pattern.sub(lambda _match: mask_token, text)
            if masked == text:
                # `word` does not occur in this text; nothing to score.
                continue
            # NOTE(review): entries of `ret` are presumably
            # (tokens, filled_text, score) -- confirm against `unmask`.
            ret = self.unmask(masked, beam_size=100, candidates=candidates)
            non_word = [
                x for x in ret if all(y not in excluded for y in x[0])
            ]
            reference = [x for x in ret if all(y in excluded for y in x[0])]
            if not reference:
                raise IndexError(
                    'unmask returned no entry for the original word %r in %r'
                    % (word, text)
                )
            score = reference[0][-1]
            new_ret = [
                (x[0], x[1], score - x[2])
                for x in non_word
                if score - x[2] < threshold
            ]
            if idx == 0:
                # The first text supplies the tuples that are returned.
                orig_ret = new_ret
            in_all &= {x[0][0] for x in new_ret}
        return [x for x in orig_ret if x[0][0] in in_all]