in src/run_paraphrase.py [0:0]
def read_data(dataset_name, roberta, train=False, kshot_seed=None):
    """Load one split of a paraphrase-identification dataset.

    Args:
        dataset_name: One of ``'qqp'``, ``'mrpc'``, ``'paws-wiki'`` or
            ``'paws-qqp'``.
        roberta: Passed through unchanged to the dataset reader
            (``read_qqp``, ``read_mrpc`` or ``read_paws``); its meaning is
            defined by those readers.
        train: Only consulted when ``kshot_seed`` is given. ``True`` selects
            the k-shot train file, ``False`` the k-shot dev file.
        kshot_seed: Value substituted into the dataset's train/dev filename
            pattern via ``str.format``. When ``None`` (the default) the
            dataset's fixed test file is read instead. Any other value,
            including ``0``, selects the k-shot files.

    Returns:
        Whatever the dataset's reader returns for the chosen file.

    Raises:
        NotImplementedError: If ``dataset_name`` is not a supported dataset.
    """
    # Resolve the reader and file locations for the requested dataset.
    if dataset_name == 'qqp':
        reader = read_qqp
        train_pattern, dev_pattern, test_file = (
            QQP_TRAIN_PATTERN, QQP_DEV_PATTERN, QQP_TEST_FILE)
    elif dataset_name == 'mrpc':
        reader = read_mrpc
        train_pattern, dev_pattern, test_file = (
            MRPC_TRAIN_PATTERN, MRPC_DEV_PATTERN, MRPC_TEST_FILE)
    elif dataset_name == 'paws-wiki':
        reader = read_paws
        train_pattern, dev_pattern, test_file = (
            PAWS_WIKI_TRAIN_PATTERN, PAWS_WIKI_DEV_PATTERN,
            PAWS_WIKI_TEST_FILE)
    elif dataset_name == 'paws-qqp':
        reader = read_paws
        train_pattern, dev_pattern, test_file = (
            PAWS_QQP_TRAIN_PATTERN, PAWS_QQP_DEV_PATTERN,
            PAWS_QQP_TEST_FILE)
    else:
        raise NotImplementedError(
            'Unsupported dataset: {!r}'.format(dataset_name))

    # Compare against None explicitly so a seed of 0 is not mistaken for
    # "no seed" and silently routed to the test file.
    if kshot_seed is None:
        return reader(test_file, roberta)
    pattern = train_pattern if train else dev_pattern
    return reader(pattern.format(kshot_seed), roberta)