in src/sagemaker_sklearn_extension/externals/read_data.py [0:0]
def _add_batch(self, batch):
    """Append a batch, truncated if necessary, to the lists awaiting concatenation.

    The batch is converted to array data, the running size estimate is updated, and
    if the estimated total now exceeds ``self.max_size_in_bytes`` the batch is cut
    down to the proportion of rows that still fits before being appended.

    Parameters
    ----------
    batch : mlio Example
        An MLIO batch returned by CsvReader, with data encoded as strings.
        Can be used to easily iterate over columns of data in the batch.

    Returns
    -------
    bool
        True if adding batches should continue, False once the estimated total
        size has exceeded ``self.max_size_in_bytes``.
    """
    # The first batch seen also drives the one-time state initialization.
    if not self._initialized:
        self._initialize_state(batch)

    # Build the array data for this batch (features first, then target).
    features = self._construct_features_array_data(batch)
    target = self._construct_target_array_data(batch)

    # Refresh the size estimates with this batch included.
    batch_nbytes, total_nbytes = self._update_array_size_estimate(features, target)

    over_budget = total_nbytes > self.max_size_in_bytes
    if over_budget:
        # Scale the row count by the share of this batch's estimated bytes that
        # still fits under the limit.
        overshoot_nbytes = total_nbytes - self.max_size_in_bytes
        keep_fraction = (batch_nbytes - overshoot_nbytes) / batch_nbytes
        rows_kept = int(keep_fraction * self._n_rows(features))

        # NOTE(review): when rows_kept is 0 nothing is resized, so the batch is
        # appended below at full size -- confirm this is intended.
        if rows_kept > 0:
            features = self._resize_features_array_data(features, rows_kept)
            if self._split_target:
                target = self._resize_target_array_data(target, rows_kept)

    self._extend_features_batches(features)
    if self._split_target:
        self._extend_target_batches(target)

    # Going over budget means this was the last batch to add.
    return not over_budget