in timm/data/naflex_dataset.py [0:0]
def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
    """Iterates through pre-calculated batches for the current epoch.

    Yields:
        Tuple of (input_dict, targets) for each batch.
    """
    worker_info = torch.utils.data.get_worker_info()
    num_workers = worker_info.num_workers if worker_info else 1
    worker_id = worker_info.id if worker_info else 0

    # Distribute pre-calculated batches among workers for this rank
    # Each worker processes a slice of the batches prepared in _prepare_epoch_batches
    batches_for_worker = self._epoch_batches[worker_id::num_workers]
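    # Example with illustrative numbers (not from any config): given 8 pre-built
    # batches and num_workers=3, the interleaved slice gives worker 0 batches
    # [0, 3, 6], worker 1 batches [1, 4, 7], and worker 2 batches [2, 5], so
    # every batch is processed by exactly one worker.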
    for seq_len, indices in batches_for_worker:
        if not indices:  # Skip if a batch ended up with no indices (shouldn't happen often)
            continue

        # Select patch size for this batch
        patch_idx = 0
        if self.variable_patch_size:
            # Use torch.multinomial for weighted random choice
            patch_idx = torch.multinomial(torch.tensor(self.patch_size_probs), 1).item()

        # Get the pre-initialized transform and patchifier using patch_idx
        transform_key = (seq_len, patch_idx)
        transform = self.transforms.get(transform_key)
        batch_patchifier = self.patchifiers[patch_idx]
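        # Illustrative example (hypothetical values): with patch_size_probs = [0.7, 0.3],
        # multinomial selects patch_idx 0 about 70% of the time; a batch built for
        # seq_len=256 with patch_idx=1 then uses the transform pre-built under the
        # key (256, 1) and the matching patchifier.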
        batch_imgs = []
        batch_targets = []
        for idx in indices:
            try:
                # Get original image and label from map-style dataset
                img, label = self.base_dataset[idx]

                # Apply transform if available
                # Handle cases where transform might return None or fail
                processed_img = transform(img) if transform else img
                if processed_img is None:
                    warnings.warn(f"Transform returned None for index {idx}. Skipping sample.")
                    continue

                batch_imgs.append(processed_img)
                batch_targets.append(label)
            except IndexError:
                warnings.warn(f"IndexError encountered for index {idx} (possibly due to padding/repeated indices). Skipping sample.")
                continue
            except Exception as e:
                # Log other potential errors during data loading/processing
                warnings.warn(f"Error processing sample index {idx}. Error: {e}. Skipping sample.")
                continue  # Skip problematic sample

        if self.mixup_fn is not None:
            batch_imgs, batch_targets = self.mixup_fn(batch_imgs, batch_targets)

        batch_imgs = [batch_patchifier(img) for img in batch_imgs]
        batch_samples = list(zip(batch_imgs, batch_targets))
        if batch_samples:  # Only yield if we successfully processed samples
            # Collate the processed samples into a batch
            yield self.collate_fns[seq_len](batch_samples)
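
Since __iter__ yields fully collated (input_dict, targets) tuples, a minimal usage sketch follows. It assumes an instance of this file's dataset wrapper has already been constructed as naflex_dataset (construction details omitted); batch_size=None tells DataLoader not to apply its own batching on top of the pre-collated batches.

    from torch.utils.data import DataLoader

    # Hypothetical sketch: `naflex_dataset` is an instance of the IterableDataset
    # wrapper defined in this file; its constructor arguments are omitted here.
    loader = DataLoader(
        naflex_dataset,
        batch_size=None,  # batches are pre-collated in __iter__, so disable auto-batching
        num_workers=4,    # each worker consumes an interleaved slice of the epoch's batches
        pin_memory=True,
    )

    for input_dict, targets in loader:
        # input_dict holds the patchified inputs produced by the per-seq_len collate_fn,
        # targets the (possibly mixed-up) labels for the batch.
        ...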