def get_data_loader()

in dialogue_personalization/utils/data_reader.py [0:0]


    def get_data_loader(self, persona, batch_size, split, fold=-1):
        """Build train/validation DataLoaders from one persona's dialogues.

        The dialogues stored in ``self.type[split][persona]`` are divided into
        a "train" part and a "validation" part by holding out one dialogue:

        * ``split == "train"`` and the persona has exactly one dialogue: that
          dialogue's examples go into BOTH loaders (nothing else to hold out).
        * ``split == "train"`` otherwise: the held-out index is drawn uniformly
          at random from ``0 .. len(dial_persona) - 1`` with ``randint``.
        * ``split`` is "test" or "valid" and ``fold != -1``: dialogue ``fold``
          is held out.
        * any other case: index ``len(dial_persona) - 1`` is held out.

        Args:
            persona: key selecting the persona inside ``self.type[split]``.
            batch_size: batch size used for both loaders.
            split: name of the split to read from ``self.type``.
            fold: index of the dialogue to hold out for "test"/"valid";
                -1 means "hold out the last index".

        Returns:
            ``(data_loader_tr, data_loader_val)``. The train loader shuffles
            (when it has any data); the validation loader never shuffles.
        """
        # NOTE(review): dial_persona is iterated as a mapping and its keys are
        # compared against integer indices, so it presumably maps dialogue
        # indices 0..n-1 to lists of examples — confirm against the code that
        # builds self.type. A key outside that range is never held out.
        dial_persona = self.type[split][persona]
        tr = []
        val = []
        if len(dial_persona) == 1 and split == "train":
            # Single dialogue: reuse it for both training and validation.
            for examples in dial_persona.values():
                val.extend(examples)
                tr.extend(examples)
        else:
            if split == "train":
                # randint is inclusive on both ends.
                val_dial = randint(0, len(dial_persona) - 1)
            elif fold != -1 and split in ("test", "valid"):
                val_dial = fold
            else:
                val_dial = len(dial_persona) - 1
            for i, examples in dial_persona.items():
                if i == val_dial:
                    val.extend(examples)
                else:
                    tr.extend(examples)

        dataset_train = Dataset(tr, self.vocab)
        # torch's RandomSampler (used when shuffle=True) raises ValueError on an
        # empty dataset. That happens e.g. for a "test"/"valid" persona with a
        # single dialogue, whose only dialogue is held out for validation.
        # Only request shuffling when there is something to shuffle.
        data_loader_tr = torch.utils.data.DataLoader(dataset=dataset_train,
                                                     batch_size=batch_size,
                                                     shuffle=len(dataset_train) > 0,
                                                     collate_fn=collate_fn)

        dataset_valid = Dataset(val, self.vocab)
        data_loader_val = torch.utils.data.DataLoader(dataset=dataset_valid,
                                                      batch_size=batch_size,
                                                      shuffle=False,
                                                      collate_fn=collate_fn)

        return data_loader_tr, data_loader_val