in codes/net.py [0:0]
def get_dataloader(self, mode: str = "train", datamode: str = "train"):
    """Build a ``DataLoader`` over one index split of a loaded dataset.

    Args:
        mode: Which index split of the loaded data to iterate over:
            ``"train"`` selects ``data.train_indices`` and ``"test"``
            selects ``data.test_indices``.
        datamode: Which dataset to load: ``"train"`` or ``"test"``. The
            training extractor is built once and kept on
            ``self.train_data``; the test extractor is rebuilt on every
            call and stored on ``self.test_data``.

    Returns:
        A ``DataLoader`` wrapping a ``DialogDataLoader``. The collate
        function is ``self.collate_fn`` when ``hparams.train_mode`` is
        ``"ref_score"`` or ``"nce"`` and ``self.context_collate_fn``
        otherwise.

    Raises:
        AssertionError: If ``datamode == "test"`` is combined with
            ``mode == "train"``.
        NotImplementedError: If ``mode`` or ``datamode`` is not one of
            ``"train"`` / ``"test"``.

    Note:
        Errors now propagate to the caller. Earlier, every exception
        (including the ones raised here on purpose) was caught, printed
        and turned into an implicit ``None`` return, which hid the real
        cause of a failure.
    """
    # Validate the arguments before any expensive data loading happens.
    if datamode == "test" and mode == "train":
        raise AssertionError("datamode test does not have training indices")
    if mode not in ("train", "test"):
        raise NotImplementedError("get_dataloader mode not implemented")
    if datamode not in ("train", "test"):
        raise NotImplementedError("get_dataloader datamode not implemented")

    # Work on a copy so the stored hyper-parameters are not mutated.
    hparams = copy.deepcopy(self.copy_hparams)
    hparams.mode = datamode

    if datamode == "train":
        if not self.train_data:
            self.logbook.write_message_logs("init loading data for training")
            self.train_data = ParlAIExtractor(hparams, self.logbook)
            self.hparams = self.train_data.args
            self.train_data.load()
            self.preflight_steps()
        data = self.train_data
    else:
        self.logbook.write_message_logs("init loading data for testing")
        self.test_data = ParlAIExtractor(hparams, self.logbook)
        self.hparams = self.test_data.args
        # Must be switched off before load() is called.
        self.test_data.args.load_model_responses = False
        self.test_data.load()
        self.preflight_steps()
        data = self.test_data

    indices = data.train_indices if mode == "train" else data.test_indices

    ddl = DialogDataLoader(
        self.hparams,
        data,
        indices=indices,
        bert_input=self.bert_input,
        is_transition_fn=self.is_transition_fn,
    )

    dist_sampler = None
    batch_size = self.hparams.batch_size
    if self.use_ddp:
        # Under DDP each process gets its own sampler and an equal share
        # of the configured batch size.
        dist_sampler = DistributedSampler(ddl, rank=self.trainer.proc_rank)
        batch_size = self.hparams.batch_size // self.trainer.world_size
        self.logbook.write_message_logs(
            "per-process batch size: {}".format(batch_size)
        )

    # Only the collate function differs between the training modes.
    if self.hparams.train_mode in ["ref_score", "nce"]:
        collate_fn = self.collate_fn
    else:
        collate_fn = self.context_collate_fn

    return DataLoader(
        ddl,
        collate_fn=collate_fn,
        batch_size=batch_size,
        sampler=dist_sampler,
        num_workers=self.hparams.num_workers,
    )