in timm/data/readers/reader_tfds.py [0:0]
def __iter__(self):
if self.ds is None or self.reinit_each_iter:
self._lazy_init()
# Compute a rounded-up sample count that is used to:
# 1. make batches even across workers & replicas in distributed validation.
# This adds extra samples and will slightly alter validation results.
# 2. determine the loop-ending condition in training w/ repeat enabled so that only full batch_size
# batches are produced (the underlying tfds iterator wraps around).
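# e.g. (illustrative numbers) 10000 samples across 4 replicas x 2 workers -> 1250 each; with batch_size 32
# this is rounded up to 1280 (40 full batches) per worker. See the standalone sketch after this method.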
target_sample_count = self._num_samples_per_worker()
# Iterate until the dataset is exhausted, or until the sample count hits the target when training (ds.repeat enabled)
sample_count = 0
for sample in self.ds:
input_data = sample[self.input_key]
if self.input_img_mode:
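# image arrives as a decoded ndarray; for greyscale ('L') output with a trailing channel dim,
# keep only the first channel before converting to a PIL Image in the requested mode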
if self.input_img_mode == 'L' and input_data.ndim == 3:
input_data = input_data[:, :, 0]
input_data = Image.fromarray(input_data, mode=self.input_img_mode)
target_data = sample[self.target_key]
if self.target_img_mode:
# dense pixel target (e.g. a segmentation mask), convert to a PIL Image in the target image mode
target_data = Image.fromarray(target_data, mode=self.target_img_mode)
elif self.remap_class:
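# remap the raw dataset label to the training class index via class_to_idx (built from the class map)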
target_data = self.class_to_idx[target_data]
yield input_data, target_data
sample_count += 1
if self.is_training and sample_count >= target_sample_count:
# Need to break out of the loop when repeat() is enabled for training w/ oversampling;
# this results in extra samples per epoch but seems more desirable than dropping
# up to N*J batches per epoch (where N = num distributed processes and J = num worker processes)
break
# Pad across distributed nodes (make counts equal by adding samples)
if not self.is_training and self.dist_num_replicas > 1 and self.subsplit is not None and \
0 < sample_count < target_sample_count:
# Validation batch padding is only done for distributed training, where results are reduced across nodes.
# For the single-process case, it doesn't matter if workers return different batch sizes.
# If using input_context or %-based splits, sample counts can vary significantly across workers and this
# approach should not be used (hence it is disabled if self.subsplit isn't set).
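# e.g. (illustrative) if this worker produced 1270 of a 1280-sample target, the last sample is
# re-yielded 10 more times so every worker/replica contributes an equal sample count to the reduction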
while sample_count < target_sample_count:
yield input_data, target_data # yield prev sample again
sample_count += 1
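# --- Illustrative sketch only (hypothetical helper, not the _num_samples_per_worker used above) ---
# Shows the rounding described in the comments at the top of __iter__: split the dataset across
# replicas * workers, then round up to a whole number of batch_size batches so a training epoch
# (with ds.repeat enabled) always emits full batches. Names and arguments here are assumptions.
import math

def rounded_worker_samples(num_samples, num_replicas, num_workers, batch_size):
    # samples this worker is responsible for, before rounding
    per_worker = num_samples / max(1, num_replicas * num_workers)
    # round up to a whole number of full batches
    return int(math.ceil(per_worker / batch_size) * batch_size)

# rounded_worker_samples(10000, num_replicas=4, num_workers=2, batch_size=32) == 1280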