benchmarks/utils.py (47 lines of code) (raw):
import functools
import timeit

import numpy as np

import datasets
from datasets.arrow_writer import ArrowWriter
from datasets.features.features import _ArrayXD
def get_duration(func):
    """Decorate ``func`` so that calling it returns its run time instead of its result.

    The wrapper forwards all positional and keyword arguments to ``func``,
    discards ``func``'s return value, and returns the elapsed time measured
    with ``timeit.default_timer`` (seconds, as a float).

    Args:
        func: The callable to time.

    Returns:
        A wrapper carrying ``func``'s metadata (``__name__``, ``__doc__``,
        ``__qualname__``, ``__wrapped__``, ...).
    """

    # functools.wraps copies all of func's metadata, not just __name__.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        starttime = timeit.default_timer()
        func(*args, **kwargs)  # result intentionally discarded; only timing matters
        return timeit.default_timer() - starttime

    return wrapper
def generate_examples(features: dict, num_examples=100, seq_shapes=None):
    """Build ``num_examples`` random examples matching ``features``.

    Each column is filled according to its feature type:

    * ``_ArrayXD``: ``np.random.rand`` data of the feature's ``shape``, cast
      to its ``dtype``.
    * ``datasets.Value`` with dtype ``"string"`` or ``"large_string"``: a
      fixed sentence. Any other ``Value`` dtype: one random integer in
      ``[0, 10)`` cast to that dtype and converted to a Python scalar.
    * ``datasets.Sequence`` (possibly nested): unwrapped to its innermost
      feature; ``np.random.rand`` data whose shape is ``seq_shapes[column]``,
      cast to the innermost feature's ``dtype``.

    Args:
        features: Mapping of column name to feature object (e.g. a
            ``datasets.Features``).
        num_examples: Number of examples to generate.
        seq_shapes: Mapping of column name to array shape; needed for every
            ``Sequence`` column. ``None`` means an empty mapping.

    Returns:
        A list of ``(index, example)`` tuples, where ``example`` maps each
        column name to its generated value.

    Raises:
        KeyError: If a ``Sequence`` column has no entry in ``seq_shapes``.
        TypeError: If a column uses a feature type not listed above.
    """
    seq_shapes = seq_shapes or {}
    dummy_data = []
    for i in range(num_examples):
        example = {}
        for k, v in features.items():
            if isinstance(v, _ArrayXD):
                data = np.random.rand(*v.shape).astype(v.dtype)
            elif isinstance(v, datasets.Value):
                if v.dtype in ("string", "large_string"):
                    data = "The small grey turtle was surprisingly fast when challenged."
                else:
                    data = np.random.randint(10, size=1).astype(v.dtype).item()
            elif isinstance(v, datasets.Sequence):
                # Unwrap nested Sequences; the innermost feature supplies the dtype.
                inner = v
                while isinstance(inner, datasets.Sequence):
                    inner = inner.feature
                try:
                    shape = seq_shapes[k]
                except KeyError:
                    raise KeyError(f"`seq_shapes` has no shape for Sequence column {k!r}") from None
                data = np.random.rand(*shape).astype(inner.dtype)
            else:
                # Fail loudly instead of leaving `data` unbound or silently
                # reusing the previous column's value.
                raise TypeError(f"Unsupported feature type for column {k!r}: {type(v).__name__}")
            example[k] = data
        dummy_data.append((i, example))
    return dummy_data
def generate_example_dataset(dataset_path, features, num_examples=100, seq_shapes=None):
    """Write random examples to an Arrow file and load them back as a ``Dataset``.

    Examples are produced by ``generate_examples``. Each one is encoded with
    ``features.encode_example`` and written to ``dataset_path`` through an
    ``ArrowWriter``. The file is then loaded with ``datasets.Dataset.from_file``.

    Args:
        dataset_path: Destination path for the Arrow file.
        features: Feature schema; must provide ``encode_example`` (e.g. a
            ``datasets.Features``).
        num_examples: Number of examples to generate and write.
        seq_shapes: Per-column shapes for ``Sequence`` features, forwarded to
            ``generate_examples``.

    Returns:
        The ``datasets.Dataset`` read back from ``dataset_path``.

    Raises:
        ValueError: If the writer reports a different number of written
            examples than ``num_examples``.
    """
    dummy_data = generate_examples(features, num_examples=num_examples, seq_shapes=seq_shapes)

    with ArrowWriter(features=features, path=dataset_path) as writer:
        for _, record in dummy_data:
            writer.write(features.encode_example(record))
        # finalize() returns (num_examples, num_bytes); only the count is checked.
        num_final_examples, _ = writer.finalize()

    if num_final_examples != num_examples:
        raise ValueError(
            f"Error writing the dataset, wrote {num_final_examples} examples but should have written {num_examples}."
        )

    return datasets.Dataset.from_file(filename=dataset_path, info=datasets.DatasetInfo(features=features))