In src/hyperpod_nemo_adapter/collections/data/vision_dataset.py:
import torch.distributed as dist
from torch.utils.data import DistributedSampler
from transformers import DataCollatorForSeq2Seq, default_data_collator

# NOTE: the import path for the length-based samplers is an assumption; they follow
# the llama-recipes sampler API and may live in a local module in this repository.
from llama_recipes.data.sampler import (
    DistributedLengthBasedBatchSampler,
    LengthBasedBatchSampler,
)


def get_dataloader_kwargs(train_config, dataset, dataset_processer, mode):
    """Build torch DataLoader keyword arguments for the configured batching strategy."""
    kwargs = {}
    batch_size = train_config.batch_size_training if mode == "train" else train_config.val_batch_size

    if train_config.batching_strategy == "padding":
        # Length-based batching groups samples of similar length to minimize padding.
        if train_config.enable_fsdp:
            # Shard batches across ranks for distributed (FSDP) training.
            kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
                dataset,
                batch_size=batch_size,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=mode == "train",
            )
        else:
            kwargs["batch_sampler"] = LengthBasedBatchSampler(
                dataset, batch_size, drop_last=True, shuffle=mode == "train"
            )
        # Pad each batch up to its longest sequence.
        kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
    elif train_config.batching_strategy == "packing":
        # Packed samples are already fixed-length, so a plain (distributed) sampler suffices.
        if train_config.enable_fsdp:
            kwargs["sampler"] = DistributedSampler(
                dataset,
                rank=dist.get_rank(),
                num_replicas=dist.get_world_size(),
                shuffle=mode == "train",
                drop_last=True,
            )
        kwargs["batch_size"] = batch_size
        kwargs["drop_last"] = True
        kwargs["collate_fn"] = default_data_collator
    else:
        raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")

    return kwargs
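
A minimal usage sketch, not part of the original file: the returned kwargs are meant to be unpacked into a torch DataLoader. The names train_config, train_dataset, and processor below are placeholder assumptions.

from torch.utils.data import DataLoader

# Hypothetical call site; train_config, train_dataset, and processor are placeholders.
dl_kwargs = get_dataloader_kwargs(train_config, train_dataset, processor, mode="train")
train_dataloader = DataLoader(
    train_dataset,
    num_workers=4,
    pin_memory=True,
    **dl_kwargs,  # supplies batch_sampler (padding) or sampler/batch_size (packing), plus collate_fn
)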