in src/fmeval/data_loaders/util.py
def get_dataset(config: DataConfig, num_records: Optional[int] = None) -> ray.data.Dataset:
    """
    Util method to load a Ray dataset using an input DataConfig.

    :param config: Input DataConfig
    :param num_records: The number of records to sample from the dataset
    :return: The loaded (and optionally sampled) Ray dataset
    """
    # The following setup is necessary to instruct Ray to preserve the
    # order of records in the datasets
    ctx = ray.data.DataContext.get_current()
    ctx.execution_options.preserve_order = True
    with timed_block(f"Loading dataset {config.dataset_name}", logger):
        data_source = get_data_source(config.dataset_uri)
        data_loader_config = _get_data_loader_config(data_source, config)
        data_loader = _get_data_loader(config.dataset_mime_type)
        data = data_loader.load_dataset(data_loader_config)
        count = data.count()
        util.require(count > 0, "Data has to have at least one record")
        if num_records and num_records > 0:  # pragma: no branch
            # TODO update sampling logic - current logic is biased towards the first MAX_ROWS_TO_TAKE rows
            num_records = min(num_records, count)
            # We convert to a pandas DataFrame, sample with pandas, and then convert back to a Ray dataset
            # because pandas can sample deterministically. This is a temporary workaround until Ray resolves
            # https://github.com/ray-project/ray/issues/40406
            if count > MAX_ROWS_TO_TAKE:
                # If the dataset has more than MAX_ROWS_TO_TAKE rows, we take the first MAX_ROWS_TO_TAKE rows
                # and sample from those to maintain deterministic behaviour. Using take_batch with a batch size
                # of MAX_ROWS_TO_TAKE caps the amount of data pulled into the driver node, avoiding
                # out-of-memory failures there.
                pandas_df = data.take_batch(batch_size=MAX_ROWS_TO_TAKE, batch_format="pandas")
            else:
                pandas_df = data.to_pandas()
            sampled_df = pandas_df.sample(num_records, random_state=SEED)
            data = ray.data.from_pandas(sampled_df)
        data = data.repartition(get_num_actors() * PARTITION_MULTIPLIER).materialize()
        return data
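
# Minimal usage sketch for get_dataset (not part of util.py). The import paths and the
# DataConfig fields other than dataset_name, dataset_uri, and dataset_mime_type
# (e.g. model_input_location) are assumptions about the surrounding library, not
# something shown in this file.
from fmeval.data_loaders.data_config import DataConfig
from fmeval.data_loaders.util import get_dataset

config = DataConfig(
    dataset_name="my_dataset",
    dataset_uri="s3://my-bucket/my_dataset.jsonl",
    dataset_mime_type="application/jsonlines",      # assumed MIME type string for JSON Lines
    model_input_location="question",                # assumed JMESPath field for the model input column
)

full_ds = get_dataset(config)                       # load every record
sampled_ds = get_dataset(config, num_records=100)   # deterministic 100-record sample
print(sampled_ds.count())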
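
# Standalone sketch of the deterministic-sampling workaround used above (not part of
# util.py): pull at most MAX_ROWS_TO_TAKE rows into a pandas DataFrame with take_batch,
# sample with a fixed random_state, then convert back into a Ray dataset. The
# MAX_ROWS_TO_TAKE and SEED values below are illustrative placeholders, not the
# module's actual constants.
import ray.data

MAX_ROWS_TO_TAKE = 100_000
SEED = 1234

ds = ray.data.range(250_000)  # toy dataset standing in for the loaded data
# take_batch caps how much data is pulled into the driver process
pandas_df = ds.take_batch(batch_size=MAX_ROWS_TO_TAKE, batch_format="pandas")
sampled_df = pandas_df.sample(100, random_state=SEED)  # deterministic given the seed
deterministic_sample = ray.data.from_pandas(sampled_df)
print(deterministic_sample.count())  # 100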