modules/SwissArmyTransformer/sat/data_utils/hf_dataset.py (30 lines of code) (raw):
# -*- encoding: utf-8 -*-
# @File : get_dataset.py
# @Time : 2021/12/14
# @Author : Zhuoyi Yang
# @Contact : yangzhuo18@mails.tsinghua.edu.cn
import os
import datasets
from datasets import load_dataset
from sat.helpers import print_rank0
def parse_huggingface_path(path):
    """Split a dataset spec of the form ``[hf://]name[/config[/split]]``.

    Args:
        path: The spec string. A leading ``hf://`` is optional and stripped.

    Returns:
        A tuple ``(name, config, split)``:
        - ``config`` is ``None`` when it is absent, empty, or ``'*'``.
        - ``split`` defaults to ``'train'`` when it is absent or empty.
        Empty components (e.g. from a trailing slash such as ``glue/mrpc/``)
        count as missing. Components after the third are ignored.
    """
    prefix = 'hf://'
    if path.startswith(prefix):
        path = path[len(prefix):]
    names = path.split('/')
    first_name = names[0]
    # '' (from "name/" or "name//split") and '*' both mean "no specific config".
    second_name = names[1] if len(names) >= 2 and names[1] not in ('', '*') else None
    # An empty split (trailing slash) would be rejected by load_dataset; fall back to 'train'.
    split = names[2] if len(names) >= 3 and names[2] else 'train'
    return first_name, second_name, split
def load_hf_dataset(path, process_fn, columns=None, cache_dir='~/.cache/huggingface/datasets', offline=False, transformer_name = None, rebuild=False):
    """Load a Hugging Face dataset, map ``process_fn`` over it and return it in torch format.

    Args:
        path: Dataset spec ``[hf://]name[/config[/split]]`` (see ``parse_huggingface_path``).
        process_fn: Per-example function passed to ``Dataset.map`` with ``batched=False``.
        columns: Forwarded to ``set_format(type='torch', columns=...)``.
        cache_dir: Cache directory for ``load_dataset`` and for the processed copy.
            A leading ``~`` is expanded. ``None`` is passed through to
            ``load_dataset``, but a real directory is required when
            ``transformer_name`` is set.
        offline: Written to ``datasets.config.HF_DATASETS_OFFLINE``. This is a
            process-global setting.
        transformer_name: When truthy, the processed dataset is saved to (and
            later reloaded from)
            ``<cache_dir>/<name>[_<config>]_<split>_<transformer_name>.data``.
            Use a different value per tokenizer/model. When falsy, no processed
            copy is written.
        rebuild: Ignore an existing processed copy and recompute it; pass this
            after changing ``process_fn``.

    Returns:
        The processed dataset with torch formatting applied.
    """
    dataset_name, sub_name, split = parse_huggingface_path(path)
    # Expand '~' so os.path.exists(), load_from_disk() and save_to_disk() all
    # refer to the same directory that load_dataset() uses.
    if cache_dir is not None:
        cache_dir = os.path.expanduser(cache_dir)
    datasets.config.HF_DATASETS_OFFLINE = int(offline)

    dataset_path = None
    if transformer_name:
        # sub_name is None when no config was given; omit it instead of
        # crashing on str + None. When present, the name format is unchanged.
        parts = [dataset_name, sub_name, split, transformer_name]
        file_name = '_'.join(part for part in parts if part is not None) + '.data'
        dataset_path = os.path.join(cache_dir, file_name)

    if dataset_path and os.path.exists(dataset_path) and not rebuild:
        dataset = datasets.load_from_disk(dataset_path)
    else:
        dataset = load_dataset(dataset_name, sub_name, cache_dir=cache_dir, split=split,
                               download_config=datasets.DownloadConfig(max_retries=20))
        print_rank0(f'> Preprocessing the {dataset_name} by process_fn... Next time will return cached files.\n> Pass "rebuild=True" to load_hf_dataset if change process_fn. Change "transformer_name" for different tokenizers or models.')
        dataset = dataset.map(process_fn, batched=False, load_from_cache_file=True)
        if dataset_path:
            dataset.save_to_disk(dataset_path)
    dataset.set_format(type='torch', columns=columns)
    return dataset