def convert_arrow_table_to_numpy_dict()

in petastorm/arrow_reader_worker.py [0:0]


def convert_arrow_table_to_numpy_dict(result_table, schema):
    """Convert the columns of a PyArrow table into a dict of NumPy arrays.

    Each column is first converted with ``to_pandas`` and then post-processed
    according to its Arrow type:

    * string columns are cast to a NumPy unicode array;
    * list columns are stacked row by row into a matrix and, when the matching
      schema field declares a shape with more than one dimension, reshaped to
      ``(number_of_rows,) + shape``;
    * every other column is returned as produced by ``to_pandas``.

    Args:
        result_table: PyArrow Table to convert.
        schema: Schema object exposing ``fields[column_name].shape``. It is only
            consulted for list columns.

    Returns:
        dict: Dictionary mapping column names to NumPy arrays.

    Raises:
        RuntimeError: If stacking or reshaping a list column fails with a
            ``ValueError``. The typical cause is rows of different lengths
            (a null row is reported as ``None`` in the message).
    """
    result_dict = dict()
    for column_name in result_table.column_names:
        column = result_table.column(column_name)

        # `to_pandas` works slower when called on the entire column rather than directly on a chunk,
        # so take the single-chunk shortcut whenever possible. The check is done on *this* column:
        # looking at another column's chunk count could make us convert only the first chunk of a
        # multi-chunk column and silently drop the remaining rows.
        if column.num_chunks == 1:
            column_as_pandas = column.chunks[0].to_pandas()
        else:
            column_as_pandas = column.to_pandas()

        # pyarrow < 0.15.0 would always return a numpy array. Starting 0.15 we get pandas series, hence we
        # convert it into numpy array
        if isinstance(column_as_pandas, pd.Series):
            column_as_numpy = column_as_pandas.values
        else:
            column_as_numpy = column_as_pandas

        if pa.types.is_string(column.type):
            # 'U' is the unicode dtype code. It is equivalent to the `np.unicode_` alias, which was
            # removed in NumPy 2.0, and works on every NumPy version.
            result_dict[column_name] = column_as_numpy.astype('U')
        elif pa.types.is_list(column.type):
            # Assuming all lists are of the same length, hence we can collate them into a matrix
            list_of_lists = column_as_numpy
            try:
                col_data = np.vstack(list_of_lists.tolist())
                shape = schema.fields[column_name].shape
                if len(shape) > 1:
                    col_data = col_data.reshape((len(list_of_lists),) + shape)
                result_dict[column_name] = col_data

            except ValueError:
                # Null rows show up as None and have no length; report them as such instead of
                # failing with an AttributeError while building the error message.
                lengths = ', '.join('None' if value is None else str(len(value))
                                    for value in list_of_lists)
                raise RuntimeError('Length of all values in column \'{}\' are expected to be the same length. '
                                   'Got the following set of lengths: \'{}\''
                                   .format(column_name, lengths))
        else:
            result_dict[column_name] = column_as_numpy

    return result_dict