in optimum/habana/sentence_transformers/st_gaudi_trainer.py [0:0]
def get_train_dataloader(self) -> DataLoader:
"""
Returns the training [`~torch.utils.data.DataLoader`].
Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed
training if necessary) otherwise.
Subclass and override this method if you want to inject some custom behavior.
"""
if self.train_dataset is None:
raise ValueError("Training requires specifying a train_dataset to the SentenceTransformerGaudiTrainer.")
train_dataset = self.train_dataset
data_collator = self.data_collator
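# Dedicated RNG so that any shuffling done by the batch samplers is reproducible when a seed is set.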
generator = torch.Generator()
if self.args.seed:
generator.manual_seed(self.args.seed)
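# DataLoader arguments shared by every branch below; the batching strategy is filled in per dataset type.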
dataloader_params = {
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
"prefetch_factor": self.args.dataloader_prefetch_factor,
}
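# Streaming datasets have no known length, so no batch sampler can be built: fall back to a plain
# batch_size and draw samples in the order the dataset yields them.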
if isinstance(train_dataset, IterableDataset):
dataloader_params.update(
{
"batch_size": self.args.train_batch_size,
"drop_last": self.args.dataloader_drop_last,
}
)
if self.args.batch_sampler != BatchSamplers.BATCH_SAMPLER:
logger.warning("When using an IterableDataset, you cannot specify a batch sampler.")
elif isinstance(train_dataset, IterableDatasetDict):
raise ValueError(
"Sentence Transformers is not compatible with IterableDatasetDict. Please use a DatasetDict instead."
)
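# A DatasetDict means multi-dataset training; streaming sub-datasets are rejected first because they
# cannot be sampled.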
elif isinstance(train_dataset, DatasetDict):
for dataset in train_dataset.values():
if isinstance(dataset, IterableDataset):
raise ValueError(
"Sentence Transformers is not compatible with a DatasetDict containing an IterableDataset."
)
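# One batch sampler per sub-dataset, all sharing the seeded generator, so that every batch is drawn
# from a single sub-dataset only.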
batch_samplers = [
self.get_batch_sampler(
dataset,
batch_size=self.args.train_batch_size,
drop_last=self.args.dataloader_drop_last,
valid_label_columns=data_collator.valid_label_columns,
generator=generator,
)
for dataset in train_dataset.values()
]
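# Concatenate the sub-datasets and let the multi-dataset batch sampler decide which sub-dataset each
# batch comes from.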
train_dataset = ConcatDataset(train_dataset.values())
batch_sampler = self.get_multi_dataset_batch_sampler(
dataset=train_dataset,
batch_samplers=batch_samplers,
generator=generator,
seed=self.args.seed,
)
dataloader_params["batch_sampler"] = batch_sampler
elif isinstance(train_dataset, Dataset):
batch_sampler = self.get_batch_sampler(
train_dataset,
batch_size=self.args.train_batch_size,
drop_last=self.args.dataloader_drop_last,
valid_label_columns=data_collator.valid_label_columns,
generator=generator,
)
dataloader_params["batch_sampler"] = batch_sampler
else:
raise ValueError(
"Unsupported `train_dataset` type. Use a Dataset, DatasetDict, or IterableDataset for training."
)
# If 'even_batches' is True, Accelerate reuses the first few samples to pad out the last batch. This can
# cause issues with multi-dataset training, so we set it to False here.
# For evaluation, setting 'even_batches' to False results in hanging, so it is kept as True there.
self.accelerator.even_batches = False
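# Let Accelerate wrap the dataloader so batches are sharded across processes and moved to the right
# device in distributed runs.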
self._train_dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
return self._train_dataloader
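# ---------------------------------------------------------------------------
# Illustrative sketch, not part of this file: the docstring above suggests subclassing to inject custom
# behavior. The subclass name, the logger, and the logging it performs are hypothetical; only
# `SentenceTransformerGaudiTrainer` and `get_train_dataloader` come from the code above, and the import
# path is taken from the file header.
import logging

from torch.utils.data import DataLoader

from optimum.habana.sentence_transformers.st_gaudi_trainer import SentenceTransformerGaudiTrainer

example_logger = logging.getLogger(__name__)


class LoggingGaudiTrainer(SentenceTransformerGaudiTrainer):
    def get_train_dataloader(self) -> DataLoader:
        # Reuse the construction above, then report which batch sampler was built
        # (the attribute is None when the training data is a streaming IterableDataset).
        dataloader = super().get_train_dataloader()
        example_logger.info(
            "Train dataloader ready, batch sampler: %s", type(dataloader.batch_sampler).__name__
        )
        return dataloader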