def get_dataloader()

in codes/net.py [0:0]


    def get_dataloader(self, mode="train", datamode="train"):
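        # `datamode` selects which dataset split to build (train vs. test files);
        # `mode` selects which index set of that split the loader will iterate over.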
        try:
            if datamode == "test" and mode == "train":
                raise AssertionError("datamode test does not have training indices")
            hparams = copy.deepcopy(self.copy_hparams)
            hparams.mode = datamode
            data = None
            if datamode == "train":
                if not self.train_data:
                    self.logbook.write_message_logs("init loading data for training")
                    self.train_data = ParlAIExtractor(hparams, self.logbook)
                    self.hparams = self.train_data.args
                    self.train_data.load()
                    self.preflight_steps()
                    # Optional embedding downsampling (hparams.downsample /
                    # hparams.down_dim) is disabled here.
                data = self.train_data
            elif datamode == "test":
                self.logbook.write_message_logs("init loading data for testing")
                hparams.mode = "test"
                self.test_data = ParlAIExtractor(hparams, self.logbook)
                self.hparams = self.test_data.args
                self.test_data.args.load_model_responses = False
                self.test_data.load()
                self.preflight_steps()
                data = self.test_data
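            # Pick which indices of the loaded split the loader iterates over.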
            if mode == "train":
                indices = data.train_indices
            elif mode == "test":
                indices = data.test_indices
            else:
                raise NotImplementedError("get_dataloader mode not implemented")
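            # Wrap the loaded split in a DialogDataLoader restricted to the
            # selected indices.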
            ddl = DialogDataLoader(
                self.hparams,
                data,
                indices=indices,
                bert_input=self.bert_input,
                is_transition_fn=self.is_transition_fn,
            )
            dist_sampler = None
            batch_size = self.hparams.batch_size
            if self.use_ddp:
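                # Shard the dataset across DDP processes and split the global
                # batch size evenly across ranks.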
                dist_sampler = DistributedSampler(ddl, rank=self.trainer.proc_rank)
                batch_size = self.hparams.batch_size // self.trainer.world_size
            self.logbook.write_message_logs("effective batch size: {}".format(batch_size))
            if self.hparams.train_mode in ["ref_score", "nce"]:
                return DataLoader(
                    ddl,
                    collate_fn=self.collate_fn,
                    batch_size=batch_size,
                    sampler=dist_sampler,
                    num_workers=self.hparams.num_workers,
                )
            else:
                return DataLoader(
                    ddl,
                    collate_fn=self.context_collate_fn,
                    batch_size=batch_size,
                    sampler=dist_sampler,
                    num_workers=self.hparams.num_workers,
                )
        except Exception as e:
            # Errors during data loading are only printed; the method then
            # falls through and implicitly returns None.
            print(e)
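
This getter is normally not called directly but wired into the trainer's dataloader hooks. The sketch below shows one plausible wiring; the hook names follow the pytorch-lightning convention that the references to `self.trainer.proc_rank`, `self.use_ddp`, and `DistributedSampler` above suggest, so they are an assumption about the surrounding class rather than something this file confirms.

    # Hypothetical dataloader hooks, assuming the surrounding class is a
    # pytorch-lightning-style module; names and split choices are illustrative only.
    def train_dataloader(self):
        # Train split, train indices.
        return self.get_dataloader(mode="train", datamode="train")

    def val_dataloader(self):
        # Held-out (test) indices of the train split serve as validation data.
        return self.get_dataloader(mode="test", datamode="train")

    def test_dataloader(self):
        # Test split; only mode="test" is valid here because the test split
        # has no training indices.
        return self.get_dataloader(mode="test", datamode="test")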