in empchat/datasets/loader.py [0:0]
def build_train_dataloader(self, epoch_id):
    """Create the training ``DataLoader`` for the configured dataset.

    ``self.dataset_name`` selects the data source:

    * ``"empchat"`` -- a new ``EmpDataset`` over the ``"train"`` split,
      configured from ``self.opt`` and ``self.temp_dict``; the loader is
      created with ``num_workers=0``.
    * ``"reddit"`` -- whatever ``self.build_reddit_dataset`` returns for
      ``epoch_id % 999``; the loader is created with ``num_workers=8``.

    In both cases the loader uses ``self.opt.batch_size``, shuffles unless
    ``self.opt.no_shuffle`` is set, collates with ``self.batchify`` and
    sets ``pin_memory`` from ``self.opt.cuda``.

    Args:
        epoch_id: Epoch counter. Only the reddit branch reads it.

    Returns:
        The ``DataLoader`` wrapping the selected training dataset.

    Raises:
        ValueError: If ``self.dataset_name`` is neither ``"empchat"`` nor
            ``"reddit"``.
    """
    source = self.dataset_name

    # The two branches differ only in how the dataset is obtained and in
    # how many loader workers are requested; pick those first.
    if source == "empchat":
        train_set = EmpDataset(
            "train",
            self.temp_dict,
            data_folder=self.opt.empchat_folder,
            maxlen=self.opt.max_sent_len,
            reactonly=self.opt.reactonly,
            history_len=self.opt.max_hist_len,
            fasttext=self.opt.fasttext,
            fasttext_type=self.opt.fasttext_type,
            fasttext_path=self.opt.fasttext_path,
        )
        worker_count = 0
    elif source == "reddit":
        # NOTE(review): 999 presumably matches the number of reddit data
        # chunks available -- confirm against build_reddit_dataset.
        train_set = self.build_reddit_dataset(epoch_id % 999)
        worker_count = 8
    else:
        raise ValueError("Dataset name unrecognized!")

    # Everything else about the loader is shared between the datasets.
    loader_options = {
        "batch_size": self.opt.batch_size,
        "shuffle": not self.opt.no_shuffle,
        "num_workers": worker_count,
        "collate_fn": self.batchify,
        "pin_memory": self.opt.cuda,
    }
    return DataLoader(train_set, **loader_options)