in vision/m4/training/dataset.py [0:0]
def __iter__(self):
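    """Iterate over the wrapped dataset, yielding packed batches of size `self.batch_size`.

    Each raw batch from `self.dataset` is passed through `self.mapper`, merged with
    any overflow carried over from previous iterations, optionally shuffled, and
    sliced into chunks of `self.batch_size`. A final chunk smaller than
    `self.batch_size` is kept as overflow for the next iteration. Each yielded item
    is a `(dummy_dataset_idx, dataset_name, dataset_state, batch)` tuple.
    """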
# Dummy dataset idx used for compatibility with CustomChainDataset
dummy_dataset_idx = 0
if self.rank is None or self.world_size is None:
raise ValueError("rank and world_size must be provided")
    # Relics from a previous implementation, but still needed to build the rng seed
    worker_id, worker_total_num = self._get_worker_id_and_worker_total_num()
    map_start_idx, last_key_idx, overflow_batch = self.worker_idx_tracker.get(worker_id, (0, -1, {}))
    i = map_start_idx
# Initialize rng_seed
self.rng_seed = [self.seed, self.epoch, self.rank, worker_id, i]
self.rng = np.random.default_rng(seed=self.rng_seed)
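    # self.rng drives the optional shuffle-after-packing below; it is reseeded
    # per yielded chunk so that shuffling is reproducible across restarts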
while True:
        # Seed torch from the worker index and the batch index, and restore the
        # previous rng state right after the batch is mapped. This makes
        # torch's random crop deterministic.
        rng_state = torch.get_rng_state()
        torch.manual_seed(int(f"{self.seed}{worker_id}{i}"))
try:
next_batch = next(self.dataset)
i += 1
except StopIteration:
logger.info(
f"{self.dataset_name.name.lower()} has finished one epoch and is moving on to the next one."
f" (epoch={self.epoch} - rank={self.rank} - worker_id={worker_id})"
)
break
curr_mapped_batch = self.mapper(
next_batch,
prefix_seed=(self.seed, self.epoch, self.rank, worker_id, i),
)
torch.set_rng_state(rng_state)
keys = list(curr_mapped_batch.keys())
overflow_batch_keys = overflow_batch.keys()
        # If overflow is left over from previous batches, merge it into the current batch.
        # The overflow is prepended so that it is yielded first, and the tail of the
        # current batch can become the next overflow.
if len(overflow_batch_keys) > 0:
if sorted(overflow_batch_keys) != sorted(keys):
raise ValueError(
"Overflow batch keys not equal to current keys. Make sure mapper is always returning"
" dictionary with the same keys. "
f"Overflow: {sorted(overflow_batch_keys)}, Mapping: {sorted(keys)}"
)
else:
mapped_batch = {}
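                # `pixel_values` is assumed to have shape
                # (batch_size, num_images, channels, height, width), and
                # `pixel_attention_mask` (batch_size, num_images, height, width)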
if "pixel_values" in overflow_batch or "pixel_values" in curr_mapped_batch:
total_batch_size = overflow_batch["input_ids"].size(0) + curr_mapped_batch["input_ids"].size(0)
max_num_images = max(
overflow_batch["pixel_values"].size(1) if "pixel_values" in overflow_batch else 0,
curr_mapped_batch["pixel_values"].size(1) if "pixel_values" in curr_mapped_batch else 0,
)
max_height = max(
overflow_batch["pixel_values"].size(3) if "pixel_values" in overflow_batch else 0,
curr_mapped_batch["pixel_values"].size(3) if "pixel_values" in curr_mapped_batch else 0,
)
max_width = max(
overflow_batch["pixel_values"].size(4) if "pixel_values" in overflow_batch else 0,
curr_mapped_batch["pixel_values"].size(4) if "pixel_values" in curr_mapped_batch else 0,
)
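                    # Allocate zero-padded buffers large enough for both sub-batches;
                    # the channel count of 3 (RGB) is hardcoded here. For example, merging
                    # pixel_values of shapes (2, 1, 3, 224, 224) and (4, 2, 3, 196, 196)
                    # yields a (6, 2, 3, 224, 224) buffer.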
padded_image_tensor = torch.zeros(total_batch_size, max_num_images, 3, max_height, max_width)
padded_pixel_attention_masks = torch.zeros(
total_batch_size, max_num_images, max_height, max_width, dtype=torch.bool
)
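                    # Copy each sub-batch into the shared padded buffers, stacking along
                    # the batch dimension (overflow first, then the current batch)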
start = 0
for batch in [overflow_batch, curr_mapped_batch]:
if "pixel_values" not in batch:
continue
px = batch["pixel_values"]
px_attn_mask = batch["pixel_attention_mask"]
                        end = start + px.size(0)
                        # Slice by this sub-batch's own num_images/height/width so that
                        # smaller sub-batches land in the top-left corner of the padded buffers
                        padded_image_tensor[start:end, : px.size(1), :, : px.size(3), : px.size(4)] = px
                        padded_pixel_attention_masks[start:end, : px.size(1), : px.size(3), : px.size(4)] = px_attn_mask
                        start = end
mapped_batch["pixel_values"] = padded_image_tensor.contiguous()
mapped_batch["pixel_attention_mask"] = padded_pixel_attention_masks.contiguous()
for key in keys:
if key in ["pixel_values", "pixel_attention_mask"]:
continue
mapped_batch[key] = torch.cat([overflow_batch[key], curr_mapped_batch[key]], dim=0)
overflow_batch = {}
else:
mapped_batch = curr_mapped_batch
first_key = keys[0]
mapped_batch_length = len(mapped_batch[first_key])
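        # Optionally shuffle the packed batch so that examples from the overflow
        # and the freshly mapped batch are interleaved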
if self.shuffle_after_packing:
indices = list(range(mapped_batch_length))
self.rng.shuffle(indices)
for key in mapped_batch.keys():
mapped_batch[key] = mapped_batch[key][indices, ...]
if mapped_batch_length < self.batch_size:
            # Not enough data yet for a full batch of size `self.batch_size`.
            # Carrying mapped_batch over as the overflow is enough: the next
            # iteration will merge more data into it.
overflow_batch = mapped_batch
else:
# Now, yield batches of size batch_size from the mapped batch
for key_idx in range(0, mapped_batch_length, self.batch_size):
# Set "reproducible" randomness
self.rng_seed = [self.seed, self.epoch, self.rank, worker_id, i, key_idx]
self.rng = np.random.default_rng(seed=self.rng_seed)
overflow_batch = {key: mapped_batch[key][key_idx : key_idx + self.batch_size] for key in keys}
if len(overflow_batch[first_key]) != self.batch_size:
                    # Partial final chunk: keep it as overflow for the next iteration
break
else:
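                    # Iteration state for this worker; presumably used to restore the
                    # position (see self.worker_idx_tracker) when resuming mid-epoch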
dataset_state = {
"worker_idx": worker_id,
"map_start_idx": i,
"last_key_idx": key_idx,
"previous_overflow_batch": {},
}
yield dummy_dataset_idx, self.dataset_name.name.lower(), dataset_state, overflow_batch
overflow_batch = {}