optimum/habana/sentence_transformers/st_gaudi_trainer.py
def get_test_dataloader(self, test_dataset: Union[Dataset, DatasetDict, IterableDataset]) -> DataLoader:
"""
Returns the test [`~torch.utils.data.DataLoader`].
Subclass and override this method if you want to inject some custom behavior.
Args:
test_dataset (Union[`datasets.Dataset`, `datasets.DatasetDict`, `datasets.IterableDataset`]):
The test dataset to use. If it is a [`~datasets.Dataset`], columns not accepted by the
`model.forward()` method are automatically removed. It must implement `__len__`.
"""
data_collator = self.data_collator
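# Use a dedicated RNG for batch sampling so shuffling is reproducible whenever a seed is set.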
generator = torch.Generator()
if self.args.seed:
generator.manual_seed(self.args.seed)
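# DataLoader arguments shared by every dataset type handled below.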
dataloader_params = {
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
"prefetch_factor": self.args.dataloader_prefetch_factor,
}
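# An IterableDataset streams samples, so no batch sampler is built; only a batch size and drop_last flag apply.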
if isinstance(test_dataset, IterableDataset):
dataloader_params.update(
{
"batch_size": self.args.eval_batch_size,
"drop_last": self.args.dataloader_drop_last,
}
)
elif isinstance(test_dataset, IterableDatasetDict):
raise ValueError(
"Sentence Transformers is not compatible with IterableDatasetDict. Please use a DatasetDict instead."
)
elif isinstance(test_dataset, DatasetDict):
for dataset in test_dataset.values():
if isinstance(dataset, IterableDataset):
raise ValueError(
"Sentence Transformers is not compatible with a DatasetDict containing an IterableDataset."
)
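# Build one batch sampler per split in the DatasetDict, all sharing the generator created above.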
batch_samplers = [
self.get_batch_sampler(
dataset,
batch_size=self.args.eval_batch_size,
drop_last=self.args.dataloader_drop_last,
valid_label_columns=data_collator.valid_label_columns,
generator=generator,
)
for dataset in test_dataset.values()
]
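# Concatenate the splits into a single dataset and let the multi-dataset batch sampler draw batches across them.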
test_dataset = ConcatDataset(test_dataset.values())
batch_sampler = self.get_multi_dataset_batch_sampler(
dataset=test_dataset,
batch_samplers=batch_samplers,
generator=generator,
seed=self.args.seed,
)
dataloader_params["batch_sampler"] = batch_sampler
elif isinstance(test_dataset, Dataset):
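# A single Dataset only needs one batch sampler.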
batch_sampler = self.get_batch_sampler(
test_dataset,
batch_size=self.args.eval_batch_size,
drop_last=self.args.dataloader_drop_last,
valid_label_columns=data_collator.valid_label_columns,
generator=generator,
)
dataloader_params["batch_sampler"] = batch_sampler
else:
raise ValueError(
"Unsupported `test_dataset` type. Use a Dataset, DatasetDict, or IterableDataset for testing."
)
# If 'even_batches' is True, it will use the initial few samples to pad out the last batch. This can
# cause issues with multi-dataset training, so we want to set this to False during training.
# For evaluation, setting 'even_batches' to False results in hanging, so we keep it as True here.
self.accelerator.even_batches = True
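# accelerator.prepare wraps the DataLoader for the current distributed setup (per-process sharding and device placement).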
return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))