in data.py [0:0]
def split_train_test(self, ratio=0.9, force=False):
    """
    Split the "true_response" rows into train and test sets at dialog level.

    Rows sharing a ``dialog_id`` always end up in the same split. When the
    frame has no "split" column (or ``force`` is set) a new split is drawn,
    written into the "split" column and the frame is saved as
    ``<data_name>_<mode>_true_response.csv`` under ``self.args.data_loc``.
    Otherwise the split already stored in the "split" column is loaded.

    :param ratio: fraction of the unique dialog ids sampled into the
        training split; only used when ``self.args.mode == "train"``.
    :param force: recompute the split even if a "split" column exists.
    :return: None. The row indices are stored in ``self.train_indices``
        and ``self.test_indices``.
    """
    dialogs = self.dialogs["true_response"]
    if "split" not in dialogs or force:
        dialog_ids = list(dialogs["dialog_id"].unique())
        if self.args.mode == "train":
            # Sample whole dialogs (not rows) so that a dialog never
            # straddles both splits. A set keeps the lookups below O(1).
            train_dialogs = set(
                random.sample(dialog_ids, int(len(dialog_ids) * ratio))
            )
            train_indices = []
            test_indices = []
            # Every row whose dialog was not sampled is held out. The
            # membership test is on the dialog id itself, never on its
            # position in ``dialog_ids``.
            for idx, dialog_id in zip(dialogs.index, dialogs["dialog_id"]):
                if dialog_id in train_dialogs:
                    train_indices.append(idx)
                else:
                    test_indices.append(idx)
            self.train_indices = train_indices
            self.test_indices = test_indices
        else:
            # Outside of training everything is evaluated as test data.
            self.train_indices = []
            self.test_indices = list(range(len(dialogs)))
        self.logbook.write_message_logs(
            "Split done. Train rows : {}, Test rows : {}".format(
                len(self.train_indices), len(self.test_indices)
            )
        )
        # Persist the assignment so later runs can reload the same split.
        train_rows = set(self.train_indices)
        dialogs["split"] = [
            "train" if idx in train_rows else "test" for idx in dialogs.index
        ]
        file_path = os.path.join(
            self.args.data_loc,
            "{}_{}_{}.csv".format(
                self.args.data_name, self.args.mode, "true_response"
            ),
        )
        dialogs.to_csv(file_path)
    else:
        # load the split from the data
        self.train_indices = []
        self.test_indices = []
        for idx, split in zip(dialogs.index, dialogs["split"]):
            if split == "train":
                self.train_indices.append(idx)
            else:
                self.test_indices.append(idx)