def unmask_multiple()

in expanded_checklist/checklist/text_generation.py [0:0]


    def unmask_multiple(self, texts, beam_size=500, candidates=None, metric='avg', **kwargs):
        """Unmask several texts and rank fill-ins jointly across all of them.

        Each text is unmasked independently via ``self.unmask``. The results
        are keyed by the fill-in words produced for the text(s) with the most
        masks. A text with fewer masks contributes each of its results to
        every key whose leading words equal that result's words.

        Args:
            texts: Iterable of masked texts. It is materialized once, so
                generators are accepted.
            beam_size: Forwarded to ``self.unmask``.
            candidates: Forwarded to ``self.unmask``.
            metric: How per-text scores are combined:
                ``'avg'``: the sum of the contributed scores divided by the
                number of texts. A text that contributed nothing to a key
                therefore counts as 0.
                ``'min'``: the minimum contributed score. A key that was not
                contributed to by every text is scored ``-999999``.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            A list of ``(words, example, score)`` sorted by descending score
            (ties keep first-seen order). ``words`` is a list, and ``example``
            is element ``[1]`` of the last ``self.unmask`` result that
            contributed to the key (presumably the filled-in text; TODO
            confirm against ``unmask``). Returns ``[]`` when ``texts`` is
            empty or no text yields any result.

        Raises:
            ValueError: If ``metric`` is neither ``'avg'`` nor ``'min'``.
        """
        # Validate before the (potentially expensive) unmask calls. An
        # unknown metric used to silently produce an empty ranking.
        if metric not in ('avg', 'min'):
            raise ValueError("metric must be 'avg' or 'min', got %r" % (metric,))
        texts = list(texts)
        n_texts = len(texts)

        # Each result is indexed as (words, example, ..., score): element 0
        # holds the fill-in words, 1 the example, and the last the score.
        rets = [self.unmask(text, beam_size, candidates) for text in texts]
        # A text with no results contributes nothing, but it still counts
        # in n_texts (the 'avg' divisor and the 'min' completeness check).
        rets = [ret for ret in rets if ret]
        if not rets:
            return []

        def n_masks(ret):
            # All results of one text share the same number of fill-in words.
            return len(ret[0][0])

        longest = max(n_masks(ret) for ret in rets)
        # Longest texts first: they define the full set of keys, which
        # shorter texts only ever update (sort is stable).
        rets.sort(key=n_masks, reverse=True)

        scores = {}
        count = collections.defaultdict(int)
        examples = {}
        # Maps prefix length -> {prefix: [full keys with that prefix]}. It is
        # built lazily, which is only safe because every full-length key has
        # been inserted before the first shorter text is processed.
        prefix_index = {}

        def matching_keys(words):
            if len(words) == longest:
                return [words]
            size = len(words)
            if size not in prefix_index:
                index = collections.defaultdict(list)
                for key in scores:
                    index[key[:size]].append(key)
                prefix_index[size] = index
            return prefix_index[size].get(words, [])

        for ret in rets:
            for result in ret:
                words, example, score = tuple(result[0]), result[1], result[-1]
                for key in matching_keys(words):
                    count[key] += 1
                    examples[key] = example
                    if metric == 'avg':
                        scores[key] = scores.get(key, 0.) + score
                    else:
                        scores[key] = min(scores[key], score) if key in scores else score

        if metric == 'min':
            # Only keys supported by every text keep their minimum score.
            for key in scores:
                if count[key] != n_texts:
                    scores[key] = -999999
        else:
            for key in scores:
                scores[key] = scores[key] / n_texts

        ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
        return [(list(key), examples[key], value) for key, value in ranked]