petastorm/reader_impl/pytorch_shuffling_buffer.py (123 lines of code) (raw):

# Copyright (c) 2017-2020 Uber Technologies, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import abc import six import torch @six.add_metaclass(abc.ABCMeta) class BatchedShufflingBufferBase(object): """Shuffling implements a shuffling algorithm. Items can be added to the shuffling buffer and removed in a different order as defined by the concrete shuffling algorithm. A shuffling buffer is intended to be used from a single thread, hence, not thread safe. Functionality is similar to ShufflingBufferBase except operations are batched and based on PyTorch.""" def __init__(self, batch_size=1): self._keys = None self.batch_size = batch_size def add_many(self, items): items = [torch.as_tensor(v) for v in items] return self._add_many(items) @abc.abstractmethod def _add_many(self, items): """Adds multiple items to the buffer. :param items: items to be added to the shuffling buffer. :return: None """ @abc.abstractmethod def retrieve(self): """Selects an batch of items from the buffer and returns the batch to the caller. The items are removed from the buffer. :return: The selected batch. """ @abc.abstractmethod def can_add(self): """Checks the state of the buffer and returns whether a new item can be added to the buffer at the time. :return: A boolean indicating whether an item can be added to the buffer at the time. """ @abc.abstractmethod def can_retrieve(self): """Checks the state of the buffer and returns whether a batch can be removed from the buffer.. :return: A boolean indicating whether an batch can be returned from the buffer at the time. """ @abc.abstractproperty def size(self): """Returns the number of elements currently present in the buffer. :return: number of elements currently present in the buffer """ @abc.abstractmethod def finish(self): """Call this method when no more :func:`add_many` calls will be made. This allows a user to deplete the buffer. Typically during last epoch. Otherwise, we would always have leftovers in the buffer at the end of the lifecycle. :return: number of elements currently present in the buffer """ class BatchedNoopShufflingBuffer(BatchedShufflingBufferBase): """A 'no-operation' (noop) implementation of a shuffling buffer. Useful in cases where no shuffling is desired, such as test scenarios or iterating over a dataset in a predeterministic order. """ def __init__(self, batch_size=1): super(BatchedNoopShufflingBuffer, self).__init__(batch_size=batch_size) self._size = 0 self._buffer = [] self._done_adding = False self._batch_start_idx = 0 def _add_many(self, items): self._size += len(items[0]) if len(self._buffer) == 0: self._buffer = items else: # Merge with previous rowgroup leftover for i, v in enumerate(items): self._buffer[i] = torch.cat([self._buffer[i][self._batch_start_idx:], v], 0) # Batch start idx starts from 0 for a fresh rowgroup. self._batch_start_idx = 0 def retrieve(self): batch = [] cur_batch_size = min(self._size, self.batch_size) for v in self._buffer: v_batch = v[self._batch_start_idx:self._batch_start_idx+cur_batch_size] batch.append(v_batch) # Increase batch start idx with current batch size self._batch_start_idx += cur_batch_size # Decrease size with current batch size self._size -= cur_batch_size return batch def can_retrieve(self): if not self._done_adding: return self._size >= self.batch_size else: return self._size > 0 def can_add(self): return True @property def size(self): return self._size def finish(self): self._done_adding = True class BatchedRandomShufflingBuffer(BatchedShufflingBufferBase): """ A random shuffling buffer implementation. Items can be added to the buffer and retrieved in a random order. """ def __init__(self, shuffling_buffer_capacity, min_after_retrieve, extra_capacity=1000, batch_size=1): """Initializes a new BatchedRandomShufflingBuffer instance. Items may be retrieved from the buffer once ``min_after_retrieve`` items were added to the queue (indicated by ``can_retrieve``). Items may be added to the buffer as long as the number of items in the buffer (not including the items passed to :func:`add_many`) does not exceed ``shuffling_queue_capacity``. The amount of items in the buffer may actually become more than ``shuffling_buffer_capacity`` since :func:`add_many` is passed a list of items. The *hard limit* on the number of items in the buffer is ``shuffling_buffer_capacity + extra_capacity``. Explanation: This batch loader performs some non-conventional operations: Let's say we enqueued several samples: [1, 2, 3, 4, 5, 6, 7] Now during a retrieve() we sample the order these samples will be retrieved: [2, 4, 5, 1, 3, 0, 6] Once an order has been sampled, we slice the order into batches of ``batch_size`` samples. And index 1 batch at a time: [1, 2, X, 4, X, 6, 7] -> [3, 5] (batch 1) [1, X, X, 4, X, X, 7] -> [6 ,2] (batch 2) We could compress the buffer after every retrieve(), but that would require custom ops. When we call add_many we first rearrange the remaining elements: [1, 4, 7] Then append new elements: [1, 4, 7, 8, 9, 10] After add_many we have to resample a permutation for the buffer. :param shuffling_buffer_capacity: Items may be added to the buffer as long as the amount of items in the buffer does not exceed the value of ``shuffling_queue_capacity`` (not including the items passed to :func:`add_many`). :param min_after_retrieve: Minimal amount of items in the buffer that allows retrieval. This is needed to guarantee good random shuffling of elements. Once :func:`finish` is called, items can be retrieved even if the condition does not hold. :param extra_capacity: The amount of items in the buffer may grow above ``shuffling_buffer_capacity`` (due to a call to :func:`add_many` with a list of items), but must remain under ``extra_capacity``. Should be set to the upper bound of the number of items that can be added in a single call to :func:`add_many` (can be a loose bound). :param batch_size: The number of items to be retrieved for each self.retrieve() call. This also affects the can_add and can can_retrieve accordingly. """ super(BatchedRandomShufflingBuffer, self).__init__(batch_size=batch_size) self._extra_capacity = extra_capacity # Preallocate the shuffling buffer. self._items = None self._shuffling_queue_capacity = shuffling_buffer_capacity self._min_after_dequeue = min_after_retrieve self._size = 0 self._done_adding = False self._random_indices = None self.next_sample_head = 0 def _add_many(self, items): if self._done_adding: raise RuntimeError('Can not call add_many after done_adding() was called.') if not self.can_add(): raise RuntimeError('Can not enqueue. Check the return value of "can_enqueue()" to check if more ' 'items can be added.') expected_size = self._size + len(items[0]) maximal_capacity = self._shuffling_queue_capacity + self._extra_capacity if expected_size > maximal_capacity: raise RuntimeError('Attempt to enqueue more elements than the capacity allows. ' 'Current size: {}, new size {}, maximum allowed: {}'.format(self._size, expected_size, maximal_capacity)) new_capacity = self._shuffling_queue_capacity while new_capacity < expected_size: # Will double capacity until it is large enough to fit new batch new_capacity *= 2 if self._items is None: # Create Buffer: self._items = [] for v in items: self._items.append(torch.empty((new_capacity,) + v.shape[1:], dtype=v.dtype, device=v.device)) if self.next_sample_head > 0: # Before we can append a new batch, we compress the remaining samples for k, v in enumerate(self._items): # We need to clone the right-side to avoid racing conditions self._items[k][:self.size] = self._items[k][self._random_indices[self.next_sample_head:]].clone() self._random_indices = None self.next_sample_head = 0 if new_capacity > self._items[0].shape[0]: for k, v in enumerate(self._items): self._items[k] = torch.empty((new_capacity,) + v.shape[1:], dtype=v.dtype, device=v.device) self._items[k][:self._size] = v[:self._size] # Copy new items over for k, v in enumerate(items): self._items[k][self._size:expected_size] = v self._size = expected_size def retrieve(self): if not self._done_adding and not self.can_retrieve(): raise RuntimeError('Can not dequeue. Check the return value of "can_dequeue()" to check if any ' 'items are available.') batch_size = min(self.batch_size, self._size) if self._random_indices is None: # We randomize the order of all samples ahead of time and then slice it into chunks with ```batch_size``` self.next_sample_head = 0 self._random_indices = torch.randperm(int(self._size), device=self._items[0].device) idx = self._random_indices[self.next_sample_head:self.next_sample_head + batch_size] self.next_sample_head += batch_size sample = [v[idx] for v in self._items] self._size -= batch_size return sample def can_add(self): return self._size < self._shuffling_queue_capacity and not self._done_adding def can_retrieve(self): return self._size >= self._min_after_dequeue + self.batch_size - 1 or (self._done_adding and self._size > 0) @property def size(self): return self._size def finish(self): self._done_adding = True