# petastorm/pyarrow_helpers/batching_table_queue.py
# Copyright (c) 2017-2018 Uber Technologies, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import deque
import pyarrow as pa
class BatchingTableQueue(object):
    def __init__(self, batch_size):
        """The class is a FIFO queue. Arrow tables are added to the queue. When read, rows are regrouped into Arrow
        tables of a fixed size specified during construction of the object. The order of the rows in the output tables
        is the same as the order of the rows in the input tables.

        :param batch_size: number of rows in tables that will be returned by the ``get`` method.
        """
        self._batch_size = batch_size
        # Deque of pyarrow RecordBatch objects, in arrival order.
        self._buffer = deque()
        # Index of the next unread row within the leftmost (head) batch of the buffer.
        self._head_idx = 0
        # Total number of rows of all batches currently held in the buffer. Note that this includes
        # the already-consumed rows of the head batch; the number of unread rows is
        # ``_cumulative_len - _head_idx``.
        self._cumulative_len = 0

    def put(self, table):
        """Adds a table to the queue. All tables added during lifetime of an instance must have the same schema.

        :param table: An instance of a pyarrow table.
        :return: None
        """
        # We store a list of arrow batches. When retrieving, we consume parts or entire batches, until batch_size of
        # rows are acquired.
        record_batches = table.to_batches()
        self._buffer.extend(record_batches)
        self._cumulative_len += sum(batch.num_rows for batch in record_batches)

    def empty(self):
        """Checks whether a full batch can be returned by ``get``.

        Returns ``True`` when the number of unread rows in the internal buffer is less than
        ``batch_size`` (i.e. not enough rows remain to form a complete output table), ``False``
        otherwise. Note that the queue may report "empty" while still holding a partial
        batch of leftover rows.
        """
        return self._head_idx + self._batch_size > self._cumulative_len

    def get(self):
        """Return a table with ``batch_size`` number of rows.

        Must only be called when ``empty()`` is ``False``.

        :return: An instance of an Arrow table with exactly ``batch_size`` rows.
        """
        assert not self.empty()

        # _head_idx points to the next row in the buffer[0] batch to be consumed.
        # Accumulate slices/full batches until result_rows reaches the desired batch_size.
        # Pop left of the deque once all rows there have been exhausted.
        result = []
        result_rows = 0
        while result_rows < self._batch_size and self._cumulative_len > 0:
            head = self._buffer[0]
            # Take at most the remaining rows needed; a slice past the end of the batch is truncated.
            piece = head[self._head_idx:self._head_idx + self._batch_size - result_rows]
            self._head_idx += piece.num_rows
            result_rows += piece.num_rows
            result.append(piece)
            if head.num_rows == self._head_idx:
                # Head batch fully consumed: discard it and reset the read cursor.
                self._head_idx = 0
                self._buffer.popleft()
                self._cumulative_len -= head.num_rows

        return pa.Table.from_batches(result)