in vision/m4/training/dataset.py [0:0]
def __iter__(self):
    # Dummy dataset idx used for compatibility with CustomChainDataset
    dummy_dataset_idx = 0
    if self.rank is None or self.world_size is None:
        raise ValueError("rank and world_size must be provided")
    # Get worker indices & details for resuming
    worker_indices, worker_id = self._get_worker_indices()
    num_worker_indices = len(worker_indices)
    # Set the start idx of the loop based on `self.worker_idx_tracker`
    map_start_idx, last_key_idx, overflow_batch = self.worker_idx_tracker.get(worker_id, (0, -1, {}))
    self.rng_seed = [self.seed, self.epoch, self.rank, worker_id, map_start_idx]
    self.rng = np.random.default_rng(seed=self.rng_seed)
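    # NOTE: `np.random.default_rng` accepts a sequence of ints as its seed, so each
    # (seed, epoch, rank, worker_id, position) tuple gives an independent, reproducible stream.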
    for i in range(map_start_idx, num_worker_indices, self.mapper_batch_size):
        # Seed torch's RNG from the worker id and map index, then restore the previous RNG
        # state afterwards. This needs to be done so that torch's random crop is deterministic.
        rng_state = torch.get_rng_state()
        torch.manual_seed(int(f"{self.seed}{worker_id}{i}"))
        # Feed `prefix_seed` to the mapper to get "deterministic randomness" that we don't
        # have to track
        curr_mapped_batch = self.mapper(
            self.dataset[worker_indices[i : i + self.mapper_batch_size]],
            prefix_seed=(self.seed, self.epoch, self.rank, worker_id, i),
        )
        torch.set_rng_state(rng_state)
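        # A mid-epoch resume replays the same (seed, epoch, rank, worker_id, i) prefixes, so
        # the mapper reproduces the same augmentations it produced in the original run.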
        keys = list(curr_mapped_batch.keys())
        overflow_batch_keys = overflow_batch.keys()
        # If overflow from previous batches is left, add it to the current batch.
        # Specifically, we prepend the overflow batch so that it goes out first and the
        # current batch possibly becomes the next overflow batch.
        if len(overflow_batch_keys) > 0:
            if sorted(overflow_batch_keys) != sorted(keys):
                raise ValueError(
                    "Overflow batch keys not equal to current keys. Make sure the mapper always"
                    " returns a dictionary with the same keys. "
                    f"Overflow: {sorted(overflow_batch_keys)}, Mapping: {sorted(keys)}"
                )
            else:
                mapped_batch = {}
                if "pixel_values" in overflow_batch or "pixel_values" in curr_mapped_batch:
                    total_batch_size = overflow_batch["input_ids"].size(0) + curr_mapped_batch["input_ids"].size(0)
                    max_num_images = max(
                        overflow_batch["pixel_values"].size(1) if "pixel_values" in overflow_batch else 0,
                        curr_mapped_batch["pixel_values"].size(1) if "pixel_values" in curr_mapped_batch else 0,
                    )
                    max_height = max(
                        overflow_batch["pixel_values"].size(3) if "pixel_values" in overflow_batch else 0,
                        curr_mapped_batch["pixel_values"].size(3) if "pixel_values" in curr_mapped_batch else 0,
                    )
                    max_width = max(
                        overflow_batch["pixel_values"].size(4) if "pixel_values" in overflow_batch else 0,
                        curr_mapped_batch["pixel_values"].size(4) if "pixel_values" in curr_mapped_batch else 0,
                    )
                    padded_image_tensor = torch.zeros(total_batch_size, max_num_images, 3, max_height, max_width)
                    padded_pixel_attention_masks = torch.zeros(
                        total_batch_size, max_num_images, max_height, max_width, dtype=torch.bool
                    )
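                    # Copy each source batch into the zero-padded tensors, top-left aligned;
                    # zeros left in the attention mask mark padded pixels as invalid.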
                    start = 0
                    for batch in [overflow_batch, curr_mapped_batch]:
                        if "pixel_values" not in batch:
                            continue
                        px = batch["pixel_values"]
                        px_attn_mask = batch["pixel_attention_mask"]
                        end = start + px.size(0)
                        # Slice every padded dim (num_images, height, width) to this batch's
                        # actual sizes so the assignment shapes match.
                        padded_image_tensor[start:end, : px.size(1), :, : px.size(3), : px.size(4)] = px
                        padded_pixel_attention_masks[start:end, : px.size(1), : px.size(3), : px.size(4)] = px_attn_mask
                        start += px.size(0)
                    mapped_batch["pixel_values"] = padded_image_tensor.contiguous()
                    mapped_batch["pixel_attention_mask"] = padded_pixel_attention_masks.contiguous()
                for key in keys:
                    if key in ["pixel_values", "pixel_attention_mask"]:
                        continue
                    mapped_batch[key] = torch.cat([overflow_batch[key], curr_mapped_batch[key]], dim=0)
                previous_overflow_batch = copy.deepcopy(overflow_batch)
                overflow_batch = {}
        else:
            previous_overflow_batch = {}
            mapped_batch = curr_mapped_batch
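        # At this point `mapped_batch` holds the overflow samples (if any) followed by the
        # freshly mapped samples.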
        first_key = keys[0]
        mapped_batch_length = len(mapped_batch[first_key])
        if self.shuffle_after_packing:
            indices = list(range(mapped_batch_length))
            self.rng.shuffle(indices)
            for key in mapped_batch.keys():
                mapped_batch[key] = mapped_batch[key][indices, ...]
        if mapped_batch_length < self.batch_size:
            # We need more data to make this batch of size `self.batch_size`. Stashing
            # `mapped_batch` as `overflow_batch` is enough: the next iteration will add
            # more data to it.
            overflow_batch = mapped_batch
        else:
            # Yield batches of size `batch_size` from the mapped batch
            for key_idx in range(0, mapped_batch_length, self.batch_size):
                # Set "reproducible" randomness
                self.rng_seed = [self.seed, self.epoch, self.rank, worker_id, i, key_idx]
                self.rng = np.random.default_rng(seed=self.rng_seed)
                if i == map_start_idx and key_idx <= last_key_idx:
                    # Handle resume (only for the first loop iteration): skip batches that
                    # were already yielded, advancing the random state up to `last_key_idx`
                    self.rng.random()
                else:
                    overflow_batch = {key: mapped_batch[key][key_idx : key_idx + self.batch_size] for key in keys}
                    if len(overflow_batch[first_key]) != self.batch_size:
                        # Partial last slice: keep it as overflow for the next iteration
                        break
                    else:
                        dataset_state = {
                            "worker_idx": worker_id,
                            "map_start_idx": i,
                            "last_key_idx": key_idx,
                            "previous_overflow_batch": previous_overflow_batch,
                        }
                        yield dummy_dataset_idx, self.dataset_name.name.lower(), dataset_state, overflow_batch
                        overflow_batch = {}
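
A minimal sketch of how this iterator might be consumed and its resume state captured,
assuming the surrounding class is a torch `IterableDataset` and that a restart rebuilds
`worker_idx_tracker` from the yielded `dataset_state`. `build_m4_dataset` and `train_step`
are illustrative placeholders, not the repo's actual API:

from torch.utils.data import DataLoader

rank, world_size = 0, 1  # single-process example
# Hypothetical constructor standing in for however the repo builds this dataset
dataset = build_m4_dataset(seed=42, epoch=0, rank=rank, world_size=world_size)
# __iter__ yields fully formed batches, so disable the loader's automatic batching
loader = DataLoader(dataset, batch_size=None, num_workers=4)

resume_state = {}
for dataset_idx, dataset_name, dataset_state, batch in loader:
    train_step(batch)  # placeholder training step
    # Track the latest per-worker position; on restart, worker_idx_tracker[worker_idx]
    # can be rebuilt as (map_start_idx, last_key_idx, previous_overflow_batch).
    resume_state[dataset_state["worker_idx"]] = (
        dataset_state["map_start_idx"],
        dataset_state["last_key_idx"],
        dataset_state["previous_overflow_batch"],
    )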