in dialogue_personalization/utils/data_reader.py [0:0]
def get_balanced_loader(self, persona, batch_size, split, fold=-1, dial_num=1):
    """Build a (train, validation) pair of DataLoaders from one persona's dialogs.

    Selection of dialogs depends on ``split``:

    * ``"train"`` / ``"valid"``: two *distinct* dialog indices are drawn with
      ``randint``; one dialog supplies the validation examples, the other the
      training examples.
    * ``"test"``: dialog ``fold`` (or the last index, ``len - 1``, when
      ``fold == -1``) is the validation set; the first ``dial_num`` other
      dialogs, in iteration order of the persona's dialog collection, form the
      training set. A negative ``dial_num`` never counts down to zero, so every
      other dialog is used.
    * any other split: both loaders are built over empty lists.

    Args:
        persona: key selecting the persona inside ``self.type[split]``.
        batch_size: batch size for both loaders.
        split: which partition of ``self.type`` to read.
        fold: validation dialog index for the "test" split; -1 selects the last.
        dial_num: number of training dialogs to take in the "test" split.

    Returns:
        ``(data_loader_tr, data_loader_val)``; the training loader shuffles,
        the validation loader does not.

    Raises:
        ValueError: if the persona has fewer than two dialogs.
    """
    # NOTE(review): the integer indexing below (randint draws, ``len - 1``,
    # ``for i in ...: dial_persona[i]``) presumably means this is a dict keyed
    # by dialog index 0..n-1 -- confirm against the code that builds self.type.
    dial_persona = self.type[split][persona]
    n_dials = len(dial_persona)
    # Was ``== 1``: an empty persona slipped through, then crashed inside
    # randint (train/valid) or silently produced empty loaders (test).
    if n_dials < 2:
        raise ValueError("persona have less than two dialogs")

    tr = []
    val = []
    if split == "train" or split == "valid":
        # Rejection-sample two distinct indices. The draw order (val, then tr)
        # is kept so seeded runs reproduce the same selection.
        val_dial = tr_dial = 0
        while val_dial == tr_dial:
            val_dial = randint(0, n_dials - 1)
            tr_dial = randint(0, n_dials - 1)
        val.extend(dial_persona[val_dial])
        tr.extend(dial_persona[tr_dial])
    elif split == "test":
        val_dial = fold if fold != -1 else n_dials - 1
        # NOTE(review): fold is not checked against the available indices; an
        # unknown fold leaves ``val`` empty.
        remaining = dial_num
        for i in dial_persona:
            if i == val_dial:
                val.extend(dial_persona[i])
            elif remaining != 0:
                tr.extend(dial_persona[i])
                remaining -= 1

    def _make_loader(examples, shuffle):
        # Single place for the loader settings shared by both splits.
        return torch.utils.data.DataLoader(
            dataset=Dataset(examples, self.vocab),
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn,
        )

    data_loader_tr = _make_loader(tr, True)
    data_loader_val = _make_loader(val, False)
    return data_loader_tr, data_loader_val