def unmask()

in expanded_checklist/checklist/text_generation.py


    def unmask(self, text_with_mask, beam_size=10, candidates=None):
        # If a remote model server is configured, delegate the request to it and
        # return the parsed (words, text, score) tuples.
        if self.url is not None:
            params = {'text': text_with_mask, 'beam_size': beam_size, 'candidates': candidates}
            r = requests.post(url='%s/unmask' % self.url, data={'params': json.dumps(params)})
            return [tuple(x) for x in json.loads(r.text)]
        tokenizer = self.tokenizer
        model = self.model
        # Encode the prefix sentence plus the masked text.
        encoded = np.array(tokenizer.encode(self.prefix_sentence + text_with_mask, add_special_tokens=True))
        cands = []
        if candidates is not None:
            # Consider each candidate both with and without the leading-space prefix,
            # since BPE-style vocabularies store these as distinct tokens.
            candidates = candidates + [self.space_prefix + x for x in candidates]
            cands = tokenizer.convert_tokens_to_ids(candidates)
            if self.allow_word_pieces:
                cands_with_space = list(set(cands))
            else:
                # Keep only candidates that start a new word (carry the space prefix).
                cands_with_space = list(set(cands).intersection(self.with_space_set))
        input_ids = torch.tensor(encoded)
        # Each beam entry is (token ids chosen so far, cumulative score).
        current_beam = [([], 0)]
        # Positions of the mask tokens, filled in left to right.
        masked = (input_ids == self.tokenizer.mask_token_id).numpy().nonzero()[0]
        # Beam search: one iteration per mask position.
        while len(current_beam[0][0]) != masked.shape[0]:
            current_beam = current_beam[:beam_size]
            size = len(current_beam[0][0])
            to_pred = []
            new_beam = []
            # Build one input per beam entry, with the masks filled in so far.
            for i, current in enumerate(current_beam):
                idxs = current[0]
                c = encoded.copy()
                c[masked[:len(idxs)]] = idxs
                to_pred.append(c)
            # Score the whole batch of partially filled inputs in one forward pass.
            to_pred = torch.tensor(np.array(to_pred), device=self.device)
            with torch.no_grad():
                outputs = model(to_pred)[0]
            for i, current in enumerate(current_beam):
                prev = int(to_pred[i][masked[size] - 1])
                # Only allow word-piece tokens (no leading space) when the previous
                # token is a special character; otherwise restrict predictions to
                # tokens that start a new word.
                forbid = False
                if not self.allow_word_pieces and prev not in self.special_chars:
                    forbid = True
                if candidates is not None:
                    # Score only the supplied candidates at the current mask position.
                    cands_to_use = cands_with_space if forbid else cands
                    scores = [outputs[i, masked[size], j] for j in cands_to_use]
                    new = [(current[0] + [int(x[0])], float(x[1]) + current[1]) for x in zip(cands_to_use, scores)]
                else:
                    # Otherwise take the top predictions over the (possibly restricted) vocabulary.
                    if forbid:
                        v, top_preds = torch.topk(outputs[i, masked[size], self.with_space], beam_size + 10)
                        top_preds = self.with_space[top_preds]
                    else:
                        v, top_preds = torch.topk(outputs[i, masked[size]], beam_size + 10)
                    new = [(current[0] + [int(x[0])], float(x[1]) + current[1]) for x in zip(top_preds, v)]
                new_beam.extend(new)
            # Keep the highest-scoring expansions; the beam is trimmed to beam_size
            # at the top of the next iteration.
            current_beam = sorted(new_beam, key=lambda x: x[1], reverse=True)
        ret = []
        cop = encoded.copy()
        for idxs, score in current_beam:
            # Decode each filled-in token individually, then decode the full text
            # without the special tokens and the prefix sentence.
            words = [str(tokenizer.decode([i])).strip() for i in idxs]
            cop[masked] = idxs
            text = tokenizer.decode(cop[1 + self.prefix_len:-1])
            # Normalize the score by the number of mask positions.
            ret.append((words, text, score / masked.shape[0]))
        ret = sorted(ret, key=lambda x: x[2], reverse=True)
        return ret
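
A minimal usage sketch, assuming the enclosing class is constructed the way checklist's TextGenerator is (the class name, its default model, and the use of the tokenizer's own mask token are assumptions, not taken from this file). unmask returns (filled_words, full_text, average_score) tuples sorted by score, highest first:

    # Hypothetical usage; TextGenerator and its constructor defaults are assumptions.
    from expanded_checklist.checklist.text_generation import TextGenerator

    tg = TextGenerator()  # loads a masked language model and its tokenizer

    # The literal mask token must match the tokenizer's (e.g. '[MASK]' for BERT).
    mask = tg.tokenizer.mask_token
    results = tg.unmask('This is a really %s movie.' % mask, beam_size=5)
    for words, text, score in results[:3]:
        print(words, text, score)

    # Restrict the fill-ins to a fixed candidate list.
    results = tg.unmask('This is a really %s movie.' % mask,
                        candidates=['good', 'bad', 'boring'])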