in expanded_checklist/checklist/text_generation.py [0:0]
def unmask_multiple(self, texts, beam_size=500, candidates=None, metric='avg', **kwargs):
    """Unmask several texts and merge their candidate fills into one ranking.

    Every text is run through ``self.unmask(text, beam_size, candidates)``.
    Each returned candidate is read as ``(fill, example, ..., score)``:
    ``fill`` is a sequence used as the merge key, ``example`` is handed back
    with it, and the last item is the score.
    NOTE(review): presumably ``fill`` holds one word per mask and ``example``
    is the filled-in text -- confirm against ``unmask``.

    Texts may yield fills of different lengths (the length of a text's fills
    is taken from its first candidate). Fills of the maximum length are the
    ranking keys; a shorter fill contributes to every key it is a prefix of.

    Args:
        texts: iterable-with-len of inputs accepted by ``self.unmask``.
        beam_size: forwarded to ``self.unmask``.
        candidates: forwarded to ``self.unmask``.
        metric: ``'avg'`` sums the scores per key and divides by
            ``len(texts)``; ``'min'`` keeps the lowest score per key and
            demotes keys that were not matched exactly ``len(texts)`` times
            to ``-999999``. Any other value accumulates no scores and
            therefore yields an empty list.
        **kwargs: accepted for call compatibility and ignored.

    Returns:
        A list of ``(fill_as_list, example, score)`` sorted by descending
        score. ``example`` comes from the last candidate that contributed to
        the key. Returns ``[]`` when ``texts`` is empty or no text produced
        any candidate.
    """
    per_text = [self.unmask(text, beam_size, candidates) for text in texts]

    # Texts without any candidate cannot define a fill length; drop them from
    # the merge but keep len(texts) as denominator / required match count.
    non_empty = [ret for ret in per_text if ret]
    if not non_empty:
        return []

    if metric == 'avg':
        scores = collections.defaultdict(lambda: 0.)
    else:
        # Large sentinel so the first real score always wins the min().
        scores = collections.defaultdict(lambda: 999999999)
    count = collections.defaultdict(lambda: 0.)
    examples = {}

    longest = max(len(ret[0][0]) for ret in non_empty)
    # Longest fills first: shorter fills are matched as prefixes of keys that
    # must already be present in `scores`.
    ordered = sorted(non_empty, key=lambda ret: len(ret[0][0]), reverse=True)

    for ret in ordered:
        for cand in ret:
            fill = tuple(cand[0])
            if len(fill) != longest:
                keys = [k for k in scores if tuple(k[:len(fill)]) == fill]
            else:
                keys = [fill]
            for key in keys:
                count[key] += 1
                examples[key] = cand[1]
                if metric == 'avg':
                    scores[key] += cand[-1]
                elif metric == 'min':
                    scores[key] = min(scores[key], cand[-1])

    if metric == 'min':
        # A key only keeps its min score if it was matched once per text.
        for key in count:
            if count[key] != len(texts):
                scores[key] = -999999
    else:
        for key in scores:
            scores[key] = scores[key] / len(texts)

    ranked = sorted(scores.items(), key=lambda item: item[1], reverse=True)
    return [(list(key), examples[key], score) for key, score in ranked]