def get_dataset()

in src/fmeval/data_loaders/util.py


def get_dataset(config: DataConfig, num_records: Optional[int] = None) -> ray.data.Dataset:
    """
    Util method to load Ray datasets using an input DataConfig.

    :param config: Input DataConfig
    :param num_records: the number of records to sample from the dataset
    """
    # The following setup is necessary to instruct Ray to preserve the
    # order of records in the datasets
    ctx = ray.data.DataContext.get_current()
    ctx.execution_options.preserve_order = True
    with timed_block(f"Loading dataset {config.dataset_name}", logger):
        data_source = get_data_source(config.dataset_uri)
        data_loader_config = _get_data_loader_config(data_source, config)
        data_loader = _get_data_loader(config.dataset_mime_type)
        data = data_loader.load_dataset(data_loader_config)
        count = data.count()
        util.require(count > 0, "Data has to have at least one record")
        if num_records and num_records > 0:  # pragma: no branch
            # TODO update sampling logic - current logic is biased towards first MAX_ROWS_TO_TAKE rows
            num_records = min(num_records, count)
            # We convert to a Pandas DataFrame, sample with Pandas, and then convert back to a Ray dataset
            # in order to use the Pandas DataFrame's ability to sample deterministically. This is a temporary
            # workaround until Ray resolves this issue: https://github.com/ray-project/ray/issues/40406
            if count > MAX_ROWS_TO_TAKE:
                # If count is larger than MAX_ROWS_TO_TAKE, we take the first MAX_ROWS_TO_TAKE rows and then
                # sample from those to maintain deterministic behaviour. take_batch gives us a Pandas DataFrame
                # of size MAX_ROWS_TO_TAKE, which avoids overwhelming the driver node by pulling too much data
                # onto it at once.
                pandas_df = data.take_batch(batch_size=MAX_ROWS_TO_TAKE, batch_format="pandas")
            else:
                pandas_df = data.to_pandas()
            sampled_df = pandas_df.sample(num_records, random_state=SEED)
            data = ray.data.from_pandas(sampled_df)
        # Repartition so records are spread evenly across the available Ray actors.
        data = data.repartition(get_num_actors() * PARTITION_MULTIPLIER).materialize()
    return data
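
Usage sketch (not from the source file): the import paths and DataConfig fields below are assumptions based on fmeval's public API, and the dataset name and URI are hypothetical.

from fmeval.constants import MIME_TYPE_JSONLINES
from fmeval.data_loaders.data_config import DataConfig
from fmeval.data_loaders.util import get_dataset

config = DataConfig(
    dataset_name="my_dataset",                # hypothetical name
    dataset_uri="s3://my-bucket/data.jsonl",  # hypothetical URI
    dataset_mime_type=MIME_TYPE_JSONLINES,
    model_input_location="question",          # JMESPath to the model-input field
)

# Loads the dataset and deterministically samples 100 records from it.
ds = get_dataset(config, num_records=100)
print(ds.count())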
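
The detour through Pandas exists because DataFrame.sample with a fixed random_state returns the same rows on every run, a guarantee Ray's own sampling does not yet provide (see the linked issue). A minimal standalone sketch of that property, using only pandas:

import pandas as pd

df = pd.DataFrame({"x": range(10)})

# With a fixed random_state, sample() picks the same rows every time;
# this is the deterministic behaviour the workaround above relies on.
first = df.sample(3, random_state=1)
second = df.sample(3, random_state=1)
assert first.equals(second)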