in timm/data/naflex_loader.py [0:0]
def __iter__(self) -> Iterator[Tuple[Dict[str, torch.Tensor], torch.Tensor]]:
    """Iterate over the wrapped loader, yielding normalized batches on ``self.device``.

    Every batch from ``self.loader`` is moved to ``self.device`` (tensor entries
    of the input dict and the target), its ``'patches'`` entry is normalized with
    ``self.mean`` / ``self.std`` and, when ``self.random_erasing`` is set, passed
    through it. On CUDA / NPU devices that work is issued on a side stream, and
    yields are delayed by one batch: the work for batch ``i + 1`` is issued
    before batch ``i`` is handed to the consumer.

    An empty ``self.loader`` produces an empty iteration.

    Yields:
        Tuple of (input_dict, targets) with normalized patches. ``'patches'``
        keeps the shape it had in the source batch.

    Raises:
        ValueError: If ``input_dict['patches']`` is neither 3- nor 5-dimensional.
        AssertionError: If 5-dimensional patches do not have ``self.channels``
            as their last dimension.
    """
    first = True
    if self.is_cuda:
        stream = torch.cuda.Stream(device=self.device)
        stream_context = partial(torch.cuda.stream, stream=stream)
    elif self.is_npu:
        stream = torch.npu.Stream(device=self.device)
        stream_context = partial(torch.npu.stream, stream=stream)
    else:
        # No side stream available: ``suppress()`` serves as a no-op context manager.
        stream = None
        stream_context = suppress

    for next_input_dict, next_target in self.loader:
        with stream_context():
            # Move all tensors in input_dict to device; only 'patches' is cast
            # to the image dtype, every other tensor keeps its own dtype.
            for k, v in next_input_dict.items():
                if isinstance(v, torch.Tensor):
                    dtype = self.img_dtype if k == 'patches' else None
                    next_input_dict[k] = v.to(
                        device=self.device,
                        non_blocking=True,
                        dtype=dtype,
                    )
            next_target = next_target.to(device=self.device, non_blocking=True)

            # Normalize patch values - handle both [B, N, P*P*C] and [B, N, Ph, Pw, C] formats
            patches_tensor = next_input_dict['patches']
            original_shape = patches_tensor.shape
            if patches_tensor.ndim == 5:
                # Format: [B, N, Ph, Pw, C] - unflattened patches (variable patch size mode)
                channels = original_shape[-1]
                assert channels == self.channels, f"Expected {self.channels} channels, got {channels}"
            elif patches_tensor.ndim != 3:
                # Anything other than the flattened [B, N, P*P*C] format is unsupported.
                raise ValueError(f"Unexpected patches tensor dimensions: {patches_tensor.ndim}. Expected 3 or 5.")

            # Both formats are viewed as [B, N, pixels, C] so the channel axis is
            # last for normalization and erasing.
            # NOTE(review): assumes self.mean / self.std broadcast against this
            # layout (e.g. shape [1, 1, 1, C]) - confirm against __init__.
            patches = patches_tensor.view(original_shape[0], original_shape[1], -1, self.channels)

            # Apply normalization
            patches = patches.sub(self.mean).div(self.std)

            if self.random_erasing is not None:
                patches = self.random_erasing(
                    patches,
                    patch_coord=next_input_dict['patch_coord'],
                    patch_valid=next_input_dict.get('patch_valid', None),
                )

            # Reshape back to original format
            next_input_dict['patches'] = patches.view(original_shape)

        # Hand out the previous batch only after the next one has been issued.
        if not first:
            yield input_dict, target
        else:
            first = False

        # Make the consumer's stream wait for the side-stream work before the
        # freshly prepared batch becomes the one to be yielded.
        if stream is not None:
            if self.is_cuda:
                torch.cuda.current_stream(device=self.device).wait_stream(stream)
            elif self.is_npu:
                torch.npu.current_stream(device=self.device).wait_stream(stream)

        input_dict = next_input_dict
        target = next_target

    # Flush the last prepared batch. Guarded so that an empty loader ends the
    # iteration cleanly instead of reading the never-assigned input_dict/target.
    if not first:
        yield input_dict, target