in src/hyperpod_nemo_adapter/collections/data/hf_image_data_module.py [0:0]
def get_dataloader(self, training=True):
    """Build the training or validation dataloader.

    Picks a dataset, stores it on ``self._train_ds`` (training) or
    ``self._validation_ds`` (validation), sets ``self.collate_fn`` to an
    ``OCRVQADataCollator`` wrapping the model's ``AutoProcessor``, and
    delegates to ``self._build_dataloader``.

    Dataset selection:
      * ``cfg.model.data.train_dir`` is ``None``: use ``get_custom_dataset``
        with the ``"train"`` split when training, ``"test"`` otherwise.
      * otherwise: load a ``HuggingFacePretrainingVisionDataset`` from
        ``train_dir`` (training) or ``val_dir`` (validation).

    Args:
        training: ``True`` for the training dataloader, ``False`` for the
            validation dataloader.

    Returns:
        The value returned by ``self._build_dataloader``, or ``None`` when
        ``train_dir`` is set but the directory for the requested split is
        unset or empty.

    Raises:
        AssertionError: if ``cfg.model.hf_model_name_or_path`` is ``None``.
    """
    model_cfg = self.cfg.model
    if model_cfg.hf_model_name_or_path is None:
        # Raised explicitly instead of via an ``assert`` statement so the check
        # still runs under ``python -O``; the exception type and message are
        # kept so existing callers observe the same failure.
        raise AssertionError(
            "Currently HuggingFaceVisionDataModule only support to run with hf_model_name_or_path"
        )

    data_cfg = model_cfg.data
    if data_cfg.train_dir is None:
        _logger.info("train_dir is not provided, using ocrvqa data for testing")
        dataset = get_custom_dataset("train" if training else "test")
    else:
        # val_dir is only read when the validation loader is requested.
        input_path = data_cfg.train_dir if training else data_cfg.val_dir
        if not input_path:
            # No directory configured for this split: nothing to load.
            return None
        dataset = HuggingFacePretrainingVisionDataset(
            input_path=input_path, partition="train" if training else "val"
        ).dataset

    if training:
        self._train_ds = dataset
    else:
        self._validation_ds = dataset

    token = model_cfg.get("hf_access_token", None)
    processor = AutoProcessor.from_pretrained(model_cfg.hf_model_name_or_path, token=token)
    processor.tokenizer.padding_side = "right"
    self.collate_fn = OCRVQADataCollator(processor)

    if training:
        return self._build_dataloader(self._train_ds, batch_size=model_cfg.train_batch_size)
    return self._build_dataloader(self._validation_ds, batch_size=model_cfg.val_batch_size)