in senteval/engine.py [0:0]
def eval(self, name):
    """Run one evaluation task, or a list of tasks.

    Parameters
    ----------
    name : str or list of str
        A task name from ``self.list_tasks``, or a list of such names.

    Returns
    -------
    dict
        For a single task: the result dict produced by the task's ``run``.
        For a list: a dict mapping each task name to its result dict.
    """
    # A list of names is evaluated one task at a time, recursively.
    if isinstance(name, list):
        self.results = {x: self.eval(x) for x in name}
        return self.results

    tpath = self.params.task_path
    assert name in self.list_tasks, str(name) + ' not in ' + str(self.list_tasks)

    seed = self.params.seed

    # Data-driven dispatch: task name -> (constructor, data sub-path, extra
    # constructor kwargs).  This replaces a 25-branch if/elif ladder and the
    # string-based builtin eval() that was used to resolve the STS1x classes.
    # Original SentEval tasks
    task_specs = {
        'CR': (CREval, 'downstream/CR', {}),
        'MR': (MREval, 'downstream/MR', {}),
        'MPQA': (MPQAEval, 'downstream/MPQA', {}),
        'SUBJ': (SUBJEval, 'downstream/SUBJ', {}),
        'SST2': (SSTEval, 'downstream/SST/binary', {'nclasses': 2}),
        'SST5': (SSTEval, 'downstream/SST/fine', {'nclasses': 5}),
        'TREC': (TRECEval, 'downstream/TREC', {}),
        'MRPC': (MRPCEval, 'downstream/MRPC', {}),
        'SICKRelatedness': (SICKRelatednessEval, 'downstream/SICK', {}),
        'STSBenchmark': (STSBenchmarkEval, 'downstream/STS/STSBenchmark', {}),
        'SICKEntailment': (SICKEntailmentEval, 'downstream/SICK', {}),
        'SNLI': (SNLIEval, 'downstream/SNLI', {}),
        'STS12': (STS12Eval, 'downstream/STS/STS12-en-test', {}),
        'STS13': (STS13Eval, 'downstream/STS/STS13-en-test', {}),
        'STS14': (STS14Eval, 'downstream/STS/STS14-en-test', {}),
        'STS15': (STS15Eval, 'downstream/STS/STS15-en-test', {}),
        'STS16': (STS16Eval, 'downstream/STS/STS16-en-test', {}),
        'ImageCaptionRetrieval': (ImageCaptionRetrievalEval, 'downstream/COCO', {}),
        # Probing Tasks
        'Length': (LengthEval, 'probing', {}),
        'WordContent': (WordContentEval, 'probing', {}),
        'Depth': (DepthEval, 'probing', {}),
        'TopConstituents': (TopConstituentsEval, 'probing', {}),
        'BigramShift': (BigramShiftEval, 'probing', {}),
        'Tense': (TenseEval, 'probing', {}),
        'SubjNumber': (SubjNumberEval, 'probing', {}),
        'ObjNumber': (ObjNumberEval, 'probing', {}),
        'OddManOut': (OddManOutEval, 'probing', {}),
        'CoordinationInversion': (CoordinationInversionEval, 'probing', {}),
    }

    spec = task_specs.get(name)
    if spec is not None:
        constructor, subpath, kwargs = spec
        self.evaluation = constructor(tpath + '/' + subpath, seed=seed, **kwargs)
    # NOTE: if a name passes the assert but has no spec (list_tasks out of
    # sync with the table), the previous self.evaluation is reused — this
    # mirrors the original if/elif fall-through behavior.

    self.params.current_task = name
    self.evaluation.do_prepare(self.params, self.prepare)
    self.results = self.evaluation.run(self.params, self.batcher)
    return self.results