def __iter__()

in timm/data/naflex_loader.py

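Prefetching iterator for the NaFlex loader: each batch is staged on a CUDA/NPU side stream (host-to-device copy, dtype cast, patch normalization, optional random erasing) and yielded with a one-step lag, so data transfer overlaps the consumer's compute.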

    def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
        """Iterate through the loader with prefetching and normalization.

        Yields:
            Tuple of (input_dict, targets) with normalized patches.
        """
        first = True
        if self.is_cuda:
            stream = torch.cuda.Stream(device=self.device)
            stream_context = partial(torch.cuda.stream, stream=stream)
        elif self.is_npu:
            stream = torch.npu.Stream(device=self.device)
            stream_context = partial(torch.npu.stream, stream=stream)
        else:
            stream = None
            stream_context = suppress

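        # Stage each incoming batch on the side stream so the H2D copy, dtype
        # cast, normalization, and optional erasing overlap the consumer's compute.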
        for next_input_dict, next_target in self.loader:
            with stream_context():
                # Move all tensors in input_dict to device
                for k, v in next_input_dict.items():
                    if isinstance(v, torch.Tensor):
                        dtype = self.img_dtype if k == 'patches' else None
                        next_input_dict[k] = next_input_dict[k].to(
                            device=self.device,
                            non_blocking=True,
                            dtype=dtype,
                        )

                next_target = next_target.to(device=self.device, non_blocking=True)

                # Normalize patch values - handle both [B, N, P*P*C] and [B, N, Ph, Pw, C] formats
                patches_tensor = next_input_dict['patches']
                original_shape = patches_tensor.shape

                if patches_tensor.ndim == 3:
                    # Format: [B, N, P*P*C] - flattened patches
                    batch_size, num_patches, patch_pixels = original_shape
                    # To [B, N, P*P, C] for normalization and erasing
                    patches = patches_tensor.view(batch_size, num_patches, -1, self.channels)
                elif patches_tensor.ndim == 5:
                    # Format: [B, N, Ph, Pw, C] - unflattened patches (variable patch size mode)
                    batch_size, num_patches, patch_h, patch_w, channels = original_shape
                    assert channels == self.channels, f"Expected {self.channels} channels, got {channels}"
                    # To [B, N, Ph*Pw, C] for normalization and erasing
                    patches = patches_tensor.view(batch_size, num_patches, -1, self.channels)
                else:
                    raise ValueError(f"Unexpected patches tensor dimensions: {patches_tensor.ndim}. Expected 3 or 5.")

                # Apply normalization
                patches = patches.sub(self.mean).div(self.std)
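                # (self.mean / self.std must broadcast against the channels-last
                # [B, N, L, C] layout, e.g. shape [1, 1, 1, C])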

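                # Optional patch-aware random erasing, guided by per-patch
                # coordinates and the validity mask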
                if self.random_erasing is not None:
                    patches = self.random_erasing(
                        patches,
                        patch_coord=next_input_dict['patch_coord'],
                        patch_valid=next_input_dict.get('patch_valid', None),
                    )

                # Reshape back to original format
                next_input_dict['patches'] = patches.view(original_shape)

            # Yield the batch prepared on the previous iteration; the first
            # pass only primes the pipeline, so there is nothing to emit yet.
            if not first:
                yield input_dict, target
            else:
                first = False

            if stream is not None:
                # Have the main stream wait on the side stream so the tensors
                # yielded on the next iteration are safe to consume.
                if self.is_cuda:
                    torch.cuda.current_stream(device=self.device).wait_stream(stream)
                elif self.is_npu:
                    torch.npu.current_stream(device=self.device).wait_stream(stream)

            # Stash the freshly staged batch; it is yielded on the next pass.
            input_dict = next_input_dict
            target = next_target

        # Flush the final staged batch once the source loader is exhausted.
        yield input_dict, target
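
For orientation, here is a minimal, self-contained sketch of the same one-batch-lag prefetch pattern with the device-stream work stubbed out; `prefetch_pairs` is an illustrative name, not a timm API, and like the method above it assumes the source loader yields at least one batch.

from typing import Iterable, Iterator, Tuple

import torch


def prefetch_pairs(
    loader: Iterable[Tuple[torch.Tensor, torch.Tensor]],
) -> Iterator[Tuple[torch.Tensor, torch.Tensor]]:
    first = True
    for next_input, next_target in loader:
        # In the real loader, the device copy and normalization of the
        # incoming batch happen here, on a side stream.
        if not first:
            # Yield the batch staged on the previous iteration; its async
            # preparation has had a full step to complete.
            yield inp, tgt
        else:
            first = False
        inp, tgt = next_input, next_target
    # Flush the final staged batch.
    yield inp, tgt


batches = [(torch.randn(2, 8), torch.zeros(2, dtype=torch.long)) for _ in range(3)]
for x, y in prefetch_pairs(batches):
    print(x.shape, y.shape)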