in scripts/regenerate_paws_kshot.py [0:0]
def _load_examples(train_path):
    """Read a TSV file and index its rows by the value of their first column.

    Args:
        train_path: Path of the TSV file; its first line is treated as the
            header row.

    Returns:
        A ``(header, id_to_example)`` tuple: the stripped header line and a
        dict mapping each row's first tab-separated field to the stripped
        row. When an id occurs more than once, the last row wins.

    Raises:
        ValueError: If the file is empty, so there is no header line.
    """
    with open(train_path, encoding='utf-8') as f:
        first_line = next(f, None)
        if first_line is None:
            # Fail loudly instead of hitting a NameError later, or silently
            # reusing the header of a previously processed dataset.
            raise ValueError(f'{train_path} is empty; expected a header row')
        header = first_line.strip()
        id_to_example = {line.split('\t')[0]: line.strip() for line in f}
    return header, id_to_example


def main():
    """Rebuild the PAWS k-shot ``train.tsv``/``dev.tsv`` files from id lists.

    For each dataset (``qqp``, ``wiki``) the rows of
    ``data/paws/<dataset>/train.tsv`` are indexed by their first column.
    Then, for every split (``train``, ``dev``) and seed, the ids listed in
    ``paws_kshot_splits/paws-<dataset>/16-<seed>/<split>_ids.tsv`` are looked
    up and written, preceded by the header row, to
    ``data/lm-bff/data/k-shot/paws-<dataset>/16-<seed>/<split>.tsv``.

    All paths are relative to the current working directory. The script can
    be re-run safely: existing output directories are reused and output
    files are overwritten.

    Raises:
        ValueError: If a dataset's ``train.tsv`` is empty.
        KeyError: If an id list names an id that is absent from ``train.tsv``.
        OSError: If an input file is missing or an output cannot be written.
    """
    for dataset in ['qqp', 'wiki']:
        header, id_to_example = _load_examples(
            os.path.join('data/paws', dataset, 'train.tsv'))
        for split in ['train', 'dev']:
            for seed in [13, 21, 42, 87, 100]:
                # NOTE(review): the "16" prefix is presumably the number of
                # shots per class -- confirm against the split generator.
                shot_dir = f'16-{seed}'
                id_list_fn = os.path.join('paws_kshot_splits',
                                          f'paws-{dataset}',
                                          shot_dir,
                                          f'{split}_ids.tsv')
                with open(id_list_fn, encoding='utf-8') as f:
                    cur_out_lines = [header]
                    cur_out_lines.extend(
                        id_to_example[line.strip()] for line in f)
                cur_out_dir = os.path.join('data/lm-bff/data/k-shot',
                                           f'paws-{dataset}',
                                           shot_dir)
                # Create the directory for every split and tolerate an
                # existing one, so the result does not depend on 'train'
                # being processed first and a second run does not crash.
                os.makedirs(cur_out_dir, exist_ok=True)
                out_fn = os.path.join(cur_out_dir, f'{split}.tsv')
                with open(out_fn, 'w', encoding='utf-8') as f:
                    f.writelines(f'{line}\n' for line in cur_out_lines)