def get_balanced_loader()

in dialogue_personalization/utils/data_reader.py [0:0]


    def get_balanced_loader(self, persona, batch_size, split, fold=-1, dial_num=1):
        """Build a (train, validation) pair of DataLoaders for one persona.

        The dialogs of ``persona`` are looked up in ``self.type[split]``.

        * ``split == "train"`` or ``"valid"``: two distinct dialogs are drawn
          at random; one becomes the training data, the other the validation
          data. ``fold`` and ``dial_num`` are ignored.
        * ``split == "test"``: the dialog whose key equals ``fold`` (or
          ``len - 1`` when ``fold == -1``) becomes the validation data, and
          up to ``dial_num`` of the remaining dialogs, taken in iteration
          order, become the training data. A negative ``dial_num`` never
          counts down to zero, so then every remaining dialog is used.

        Args:
            persona: Key of the persona inside ``self.type[split]``.
            batch_size: Batch size used for both loaders.
            split: Data split to read ("train", "valid" or "test").
            fold: Test split only; key of the held-out dialog, -1 for the
                last one.
            dial_num: Test split only; maximum number of non-held-out
                dialogs used as training data.

        Returns:
            ``(data_loader_tr, data_loader_val)``: the training loader
            shuffles, the validation loader does not.

        Raises:
            ValueError: if the persona has fewer than two dialogs.
        """
        # NOTE(review): dial_persona is indexed with ints in [0, len - 1] and
        # its iteration keys are compared to such ints, so it is presumably a
        # dict keyed by dialog index 0..n-1 — confirm against the reader.
        dial_persona = self.type[split][persona]
        # Fewer than two dialogs cannot be split into train/validation: with
        # exactly one the rejection loop below would never terminate, and
        # with zero the randint range would be empty.
        if len(dial_persona) < 2:
            raise ValueError("persona have less than two dialogs")

        tr = []
        val = []
        if split == "train" or split == "valid":
            # Rejection-sample two distinct dialog indices. The randint calls
            # are kept as-is so seeded runs pick the same dialogs.
            # NOTE(review): assumes randint is random.randint (inclusive upper
            # bound); an exclusive-bound randint would never reach len - 1.
            val_dial = tr_dial = 0
            while val_dial == tr_dial:
                val_dial = randint(0, len(dial_persona) - 1)
                tr_dial = randint(0, len(dial_persona) - 1)
            val.extend(dial_persona[val_dial])
            tr.extend(dial_persona[tr_dial])
        elif fold != -1 and split == "test":
            val_dial = fold
        else:
            val_dial = len(dial_persona) - 1

        if split == "test":
            # Count down a local copy instead of mutating the parameter.
            remaining = dial_num
            for i in dial_persona:
                if i == val_dial:
                    val.extend(dial_persona[i])
                elif remaining != 0:
                    tr.extend(dial_persona[i])
                    remaining -= 1

        def _make_loader(samples, shuffle):
            # Both loaders share the vocabulary, batch size and collate_fn.
            return torch.utils.data.DataLoader(
                dataset=Dataset(samples, self.vocab),
                batch_size=batch_size,
                shuffle=shuffle,
                collate_fn=collate_fn,
            )

        data_loader_tr = _make_loader(tr, shuffle=True)
        data_loader_val = _make_loader(val, shuffle=False)
        return data_loader_tr, data_loader_val