in torchbenchmark/util/framework/transformers/text_classification/dataset.py [0:0]
from datasets import load_dataset


def prep_dataset(hf_args):
    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or specify a GLUE benchmark task (the dataset will be downloaded automatically from the datasets Hub).
    # For CSV/JSON files, this script uses the column called 'label' as the labels, and the columns called
    # 'sentence1' and 'sentence2' as the sentence pair if they exist; otherwise it uses the first two columns
    # not named 'label', provided at least two such columns exist.
    # If the CSVs/JSONs contain only one non-label column, the script does single-sentence classification on
    # that column. You can easily tweak this behavior (see below).
    # In distributed training, the load_dataset function guarantees that only one local process can
    # concurrently download the dataset.
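    # Illustrative only (hypothetical file, not shipped with the benchmark): a train.csv
    # in the two-sentence layout described above would look like
    #
    #     sentence1,sentence2,label
    #     "A man is playing a guitar.","Someone plays an instrument.",1
    #     "The cat is sleeping.","A dog runs through the park.",0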
    if hf_args.task_name is not None:
        # Downloading and loading a dataset from the hub.
        raw_datasets = load_dataset("glue", hf_args.task_name)
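        # The result is a datasets.DatasetDict keyed by split; for most GLUE tasks that is
        # train/validation/test (MNLI instead has validation_matched/validation_mismatched).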
    else:
        # Loading a dataset from your local files.
        # CSV/JSON training and evaluation files are needed.
        data_files = {"train": hf_args.train_file, "validation": hf_args.validation_file}
        # Get the test dataset: you can provide your own CSV/JSON test file (see below)
        # when you use `do_predict` without specifying a GLUE benchmark task.
        if hf_args.do_predict:
            if hf_args.test_file is not None:
                train_extension = hf_args.train_file.split(".")[-1]
                test_extension = hf_args.test_file.split(".")[-1]
                assert (
                    test_extension == train_extension
                ), "`test_file` should have the same extension (csv or json) as `train_file`."
                data_files["test"] = hf_args.test_file
            else:
                raise ValueError("Need either a GLUE task or a test file for `do_predict`.")
        # for key in data_files.keys():
        #     logger.info(f"load a local file for {key}: {data_files[key]}")
        if hf_args.train_file.endswith(".csv"):
            # Loading a dataset from local csv files
            raw_datasets = load_dataset("csv", data_files=data_files, cache_dir=hf_args.cache_dir)
        else:
            # Loading a dataset from local json files
            raw_datasets = load_dataset("json", data_files=data_files, cache_dir=hf_args.cache_dir)
    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.
    return raw_datasets
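
# Minimal usage sketch (illustrative, not part of the benchmark): prep_dataset only
# needs an object exposing the attributes read above, so a SimpleNamespace stands in
# here for the benchmark's parsed hf_args. Running it downloads GLUE/MRPC from the Hub.
if __name__ == "__main__":
    from types import SimpleNamespace

    hf_args = SimpleNamespace(
        task_name="mrpc",   # any GLUE task name; set to None to load local CSV/JSON files
        train_file=None,
        validation_file=None,
        test_file=None,
        do_predict=False,
        cache_dir=None,
    )
    raw_datasets = prep_dataset(hf_args)
    print(raw_datasets)     # DatasetDict with train/validation/test splits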