in data.py [0:0]
def prepare_data(self, mode, online=False):
    """
    Pre-process one data split and persist per-epoch training files to disk.

    For every epoch, builds (X, Y, Y_hat) triples where X is the dialog
    context (all utterances before position i), Y the true next utterance,
    and Y_hat the BERT embedding (PCA-projected) of a corrupted negative
    sample. Each epoch's data is pickled to
    ``{args.exp_data_folder}/{mode}_epoch_{epoch}.pkl``.

    Whether utterances are represented as BERT sentence vectors or as raw
    word ids is controlled by ``self.args.vector_mode``, not a parameter.

    :param mode: "train" or "test" — selects which index split to use
    :param online: when vector_mode is on, query BERT on the fly instead of
        using the cached ``self.dial_vecs``
    :raises NotImplementedError: on an unknown ``mode`` or
        ``args.corrupt_type``
    :return: None — results are written to disk
    """
    vector = self.args.vector_mode
    if mode == "train":
        indices = self.train_indices
    elif mode == "test":
        indices = self.test_indices
    else:
        raise NotImplementedError("{} not implemented".format(mode))
    if vector:
        if online:
            dialogs = [self.dialogs[di] for di in indices]
            dial_vecs = [self.extract_sentence_bert(dial) for dial in dialogs]
        else:
            dial_vecs = [self.dial_vecs[di] for di in indices]
    else:
        # Raw-word mode: map every token of every utterance to its vocab id.
        dial_vecs = [self.dialog_tokens[di] for di in indices]
        dial_vecs = [
            [[self.get_word_id(w) for w in utt] for utt in dl] for dl in dial_vecs
        ]
    cd = CorruptDialog(self.args, self, False, bert_tokenize=True)
    # Save individual epoch data in a separate file per epoch.
    pbe = tqdm(total=self.args.epochs)
    for epoch in range(self.args.epochs):
        X = []
        Y = []
        Y_hat = []
        pb = tqdm(total=len(dial_vecs))
        for di, dial in enumerate(dial_vecs):
            dialog_id = indices[di]
            # Every proper prefix of the dialog predicts the next utterance.
            for i in range(1, len(dial)):
                inp = dial[0:i]
                outp = dial[i]
                if not vector:
                    # flatten the context into one token sequence
                    inp = [w for utt in inp for w in utt]
                X.append(inp)
                Y.append([outp])
                # Draw one negative sample per (context, response) pair.
                if self.args.corrupt_type == "rand_utt":
                    sc = cd.random_clean(dialog_id=dialog_id)
                elif self.args.corrupt_type == "drop":
                    sc = cd.random_drop(
                        self.dialogs[dialog_id][i], drop=self.args.drop_per
                    )
                elif self.args.corrupt_type == "shuffle":
                    sc = cd.change_word_order(self.dialogs[dialog_id][i])
                elif self.args.corrupt_type in ["model_true", "model_false"]:
                    sc = cd.get_nce_semantics(dialog_id, i)
                else:
                    raise NotImplementedError(
                        "args.corrupt_type {} not implemented".format(
                            self.args.corrupt_type
                        )
                    )
                Y_hat.append(sc)
            pb.update(1)
        pb.close()
        # Embed the negative samples with BERT in batches of `bs`, then
        # project each batch through the fitted PCA.
        Y_hat_h = []
        bs = 32
        self.logbook.write_message_logs("Extracting negative samples from BERT")
        pb = tqdm(total=len(range(0, len(Y_hat), bs)))
        for yi in range(0, len(Y_hat), bs):
            Y_hat_h.append(
                self.pca_predict(
                    [
                        list(
                            self.extract_sentence_bert(
                                Y_hat[yi : yi + bs], tokenize=False
                            )
                        )
                    ]
                )[0]
            )
            pb.update(1)
        pb.close()
        epoch_data = [X, Y, Y_hat_h]
        out_path = os.path.join(
            self.args.exp_data_folder, "{}_epoch_{}.pkl".format(mode, epoch)
        )
        # Context manager ensures the file handle is closed even on error
        # (the original leaked the handle returned by open()).
        with open(out_path, "wb") as f:
            pkl.dump(epoch_data, f)
        pbe.update(1)
    pbe.close()