in ludwig/utils/batcher.py [0:0]
def next_batch(self):
    """Return the next sub-batch drawn from a randomly chosen bucket.

    If ``last_batch()`` reports that the epoch is exhausted, the bucket index
    lists are shuffled (when ``should_shuffle`` is set), the batcher is reset
    and ``epoch`` is incremented before drawing. A bucket is then picked
    uniformly at random among the eligible ones:

    * ``ignore_last`` unset: buckets whose cursor is still below their size;
    * ``ignore_last`` set: buckets where ``cursor + batch_size`` is still
      below their size.

    Up to ``batch_size`` sample indices are taken from the chosen bucket
    starting at its cursor, and every key of ``self.dataset.get_dataset()``
    is fetched for those indices. When ``should_trim`` is set, the
    ``bucketing_field`` column is cut down to the largest per-row count of
    non-zero entries, keeping the first columns (``trim_side == 'right'``)
    or the last columns (``trim_side == 'left'``).

    Afterwards the chosen bucket's cursor advances by ``batch_size`` and
    ``step`` is incremented.

    Returns:
        dict mapping each dataset key to the samples selected for this batch.

    Raises:
        ValueError: if no bucket is eligible to draw from, or if trimming is
            requested with a ``trim_side`` other than ``'right'``/``'left'``.
    """
    if self.last_batch():
        if self.should_shuffle:
            self.shuffle(self.buckets_idcs)
        self.reset()
        self.epoch += 1

    if self.ignore_last:
        # NOTE(review): the strict `<` also skips a bucket whose remaining
        # samples form exactly one full batch. Kept as-is so it stays in
        # sync with last_batch() (not visible here) -- confirm whether `<=`
        # is intended in both places.
        idcs_below_size = self.indices + self.batch_size < self.bucket_sizes
    else:
        idcs_below_size = self.indices < self.bucket_sizes

    candidate_buckets = np.arange(0, len(self.buckets_idcs))[idcs_below_size]
    if candidate_buckets.size == 0:
        # np.random.choice would raise an opaque ValueError on an empty
        # array; raise the same type with an explanation instead.
        raise ValueError(
            'No bucket has samples left to draw a batch from '
            '(batch_size={}, bucket_sizes={})'.format(
                self.batch_size, self.bucket_sizes))
    i = np.random.choice(candidate_buckets)

    selected_bucket = self.buckets_idcs[i]
    start = self.indices[i]
    selected_idcs = selected_bucket[start:start + self.batch_size]

    sub_batch = {}
    for key in self.dataset.get_dataset():
        selected_samples = self.dataset.get(key, selected_idcs)
        if key == self.bucketing_field and self.should_trim:
            # Row length is measured as the number of non-zero entries
            # (assumes the padding value is 0 -- TODO confirm). The
            # [:, ...] slicing below requires a 2-D array.
            max_length = int(np.count_nonzero(selected_samples, axis=1).max())
            if self.trim_side == 'right':
                selected_samples = selected_samples[:, :max_length]
            elif self.trim_side == 'left':
                # Use an explicit start column: `[:, -max_length:]` would
                # keep *every* column when max_length == 0 (since -0 == 0).
                width = selected_samples.shape[1]
                selected_samples = selected_samples[:, width - max_length:]
            else:
                raise ValueError(
                    'Invalid trim side: {}'.format(self.trim_side))
        sub_batch[key] = selected_samples

    self.indices[i] += self.batch_size
    self.step += 1
    return sub_batch