in training/dataset/sam2_datasets.py [0:0]
def get_loader(self, epoch) -> Iterable:
    dataloaders = []
    for d_idx, (dataset, batch_size) in enumerate(
        zip(self.datasets, self.batch_sizes)
    ):
        if self.phases_per_epoch > 1:
            # Major epoch that loops over the entire dataset:
            # one main epoch spans phases_per_epoch consecutive (sub-)epochs.
            main_epoch = epoch // self.phases_per_epoch

            # Phase within the main epoch
            local_phase = epoch % self.phases_per_epoch

            # Start of a new data-epoch, or the job is resumed after preemption.
            if local_phase == 0 or self.chunks[d_idx] is None:
                # Set the seed for the dataset epoch.
                # If using RepeatFactorWrapper, this step correctly re-samples indices before chunking.
                self._set_dataset_epoch(dataset, main_epoch)

                # Separate random generator for subset sampling
                g = torch.Generator()
                g.manual_seed(main_epoch)
                self.chunks[d_idx] = torch.chunk(
                    torch.randperm(len(dataset), generator=g),
                    self.phases_per_epoch,
                )

            # Train on the chunk that corresponds to the current phase.
            dataset = Subset(dataset, self.chunks[d_idx][local_phase])
        else:
            self._set_dataset_epoch(dataset, epoch)

        # Shard the (sub-)dataset across ranks and batch it.
        sampler = DistributedSampler(dataset, shuffle=self.shuffle)
        sampler.set_epoch(epoch)

        batch_sampler = BatchSampler(sampler, batch_size, drop_last=self.drop_last)
        dataloaders.append(
            DataLoader(
                dataset,
                num_workers=self.num_workers,
                pin_memory=self.pin_memory,
                batch_sampler=batch_sampler,
                collate_fn=self.collate_fn,
                worker_init_fn=self.worker_init_fn,
            )
        )
    return MixedDataLoader(dataloaders, self.dataset_prob)
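
A minimal standalone sketch of the phase-chunking scheme used above. The dataset size (100), phases_per_epoch (4), and epoch value are illustrative, not taken from this file; the point is that one permutation is drawn per main epoch and reused across its phases, so the phases jointly cover the whole dataset.

import torch
from torch.utils.data import Subset, TensorDataset

phases_per_epoch = 4
dataset = TensorDataset(torch.arange(100))  # stand-in for a real dataset

epoch = 6                               # global epoch counter as passed to get_loader
main_epoch = epoch // phases_per_epoch  # 1: which full pass over the data we are in
local_phase = epoch % phases_per_epoch  # 2: which quarter of that pass

# Same recipe as above: one permutation per main epoch, split into phases_per_epoch chunks.
g = torch.Generator()
g.manual_seed(main_epoch)
chunks = torch.chunk(torch.randperm(len(dataset), generator=g), phases_per_epoch)

phase_subset = Subset(dataset, chunks[local_phase])
print(len(phase_subset))  # 25 == len(dataset) // phases_per_epoch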
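
The sampler wiring in the loop body can also be exercised on its own. The sketch below uses a hypothetical 32-element dataset and passes num_replicas/rank to DistributedSampler explicitly so it runs without an initialized process group; in real training those come from torch.distributed, as in the code above.

import torch
from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(32))

sampler = DistributedSampler(dataset, num_replicas=1, rank=0, shuffle=True)
sampler.set_epoch(0)  # must be updated every epoch, otherwise the shuffle order repeats

batch_sampler = BatchSampler(sampler, batch_size=8, drop_last=True)
loader = DataLoader(dataset, batch_sampler=batch_sampler)

for (batch,) in loader:
    print(batch.shape)  # torch.Size([8]) -- four batches of eight shuffled samples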