in dialogue_personalization/utils/data_reader.py [0:0]
def get_data_loader(self, persona, batch_size, split, fold=-1):
    """Build (train, validation) DataLoaders for one persona's dialogues.

    The dialogues stored under ``self.type[split][persona]`` are split so that
    one dialogue is held out as the validation set and the remaining ones form
    the training set:

    * ``split == "train"`` with exactly one dialogue: that dialogue's samples
      are placed in BOTH the training and the validation set.
    * ``split == "train"`` otherwise: the held-out dialogue key is drawn with
      ``randint(0, n - 1)``.
    * ``split`` is ``"test"`` or ``"valid"`` and ``fold != -1``: the dialogue
      keyed ``fold`` is held out.
    * any other case: the dialogue keyed ``n - 1`` is held out.

    NOTE(review): the held-out choice is compared against the mapping's keys,
    so this presumably assumes dialogues are keyed 0..n-1 — confirm against
    the code that fills ``self.type``. A ``fold`` matching no key yields an
    empty validation set.

    Args:
        persona: Key selecting the persona inside ``self.type[split]``.
        batch_size: Batch size used by both DataLoaders.
        split: Split name, e.g. "train", "valid" or "test".
        fold: Key of the dialogue to hold out for "test"/"valid" splits;
            -1 selects the default (last) dialogue.

    Returns:
        Tuple ``(train_loader, valid_loader)``. The training loader shuffles,
        the validation loader does not; both batch with ``collate_fn``.
    """
    dialogues = self.type[split][persona]
    n_dialogues = len(dialogues)
    train_samples = []
    valid_samples = []

    if split == "train" and n_dialogues == 1:
        # Only one dialogue: its samples go into both sets.
        for key in dialogues:
            # Iterate the samples once, then copy them into both lists.
            shared = list(dialogues[key])
            valid_samples.extend(shared)
            train_samples.extend(shared)
    else:
        if split == "train":
            held_out = randint(0, n_dialogues - 1)
        elif fold != -1 and split in ("test", "valid"):
            held_out = fold
        else:
            held_out = n_dialogues - 1
        for key in dialogues:
            target = valid_samples if key == held_out else train_samples
            target.extend(dialogues[key])

    def make_loader(samples, shuffle):
        # Wrap one sample list in a Dataset and a batching DataLoader.
        return torch.utils.data.DataLoader(
            dataset=Dataset(samples, self.vocab),
            batch_size=batch_size,
            shuffle=shuffle,
            collate_fn=collate_fn,
        )

    # Tuple elements are evaluated left to right: train is built before valid.
    return make_loader(train_samples, True), make_loader(valid_samples, False)