def prepare_data()

in data.py [0:0]


    def prepare_data(self, mode, online=False):
        """
        Pre-process data and save them in disk
        :param mode: train or test
        :param vector: return the bert sentence vector of dialogs, else return raw words
        :param online: if vector, and if online, return vectors directly querying bert
        :return:
        """
        vector = self.args.vector_mode
        if mode == "train":
            indices = self.train_indices
        elif mode == "test":
            indices = self.test_indices
        else:
            raise NotImplementedError("{} not implemented".format(mode))
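        # vector mode: represent each utterance by a BERT sentence vector
        # (cached or queried online); otherwise fall back to word ids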
        if vector:
            if online:
                dialogs = [self.dialogs[di] for di in indices]
                dial_vecs = [self.extract_sentence_bert(dial) for dial in dialogs]
            else:
                dial_vecs = [self.dial_vecs[di] for di in indices]
        else:
            dial_vecs = [self.dialog_tokens[di] for di in indices]
            dial_vecs = [
                [[self.get_word_id(w) for w in utt] for utt in dl] for dl in dial_vecs
            ]
        dialogs = [self.dialogs[di] for di in indices]
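        # CorruptDialog produces corrupted versions of target utterances,
        # which serve as negative samples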
        cd = CorruptDialog(self.args, self, False, bert_tokenize=True)
        # save each epoch's data in its own file
        pbe = tqdm(total=self.args.epochs)
        for epoch in range(self.args.epochs):
            X = []
            Y = []
            Y_hat = []
            pb = tqdm(total=len(dial_vecs))
            for di, dial in enumerate(dial_vecs):
                dialog_id = indices[di]
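                # every dialog prefix [0:i] is an input; utterance i is the target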
                for i in range(1, len(dial)):
                    inp = dial[0:i]
                    outp = dial[i]
                    if not vector:
                        # flatten into one sentence
                        inp = [w for utt in inp for w in utt]
                    X.append(inp)
                    Y.append([outp])
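                    # corrupt the target utterance according to args.corrupt_type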
                    if self.args.corrupt_type == "rand_utt":
                        sc = cd.random_clean(dialog_id=dialog_id)
                    elif self.args.corrupt_type == "drop":
                        sc = cd.random_drop(
                            self.dialogs[dialog_id][i], drop=self.args.drop_per
                        )
                    elif self.args.corrupt_type == "shuffle":
                        sc = cd.change_word_order(self.dialogs[dialog_id][i])
                    elif self.args.corrupt_type in ["model_true", "model_false"]:
                        sc = cd.get_nce_semantics(dialog_id, i)
                    else:
                        raise NotImplementedError(
                            "args.corrupt_type {} not implemented".format(
                                self.args.corrupt_type
                            )
                        )
                    Y_hat.append(sc)
                pb.update(1)
            pb.close()
            # extract Y_hat from BERT
            Y_hat_h = []
            bs = 32
            self.logbook.write_message_logs("Extracting negative samples from BERT")
            pb = tqdm(total=len(range(0, len(Y_hat), bs)))
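            # each batch of corrupted utterances is embedded with BERT and
            # reduced via pca_predict before being stored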
            for yi in range(0, len(Y_hat), bs):
                Y_hat_h.append(
                    self.pca_predict(
                        [
                            list(
                                self.extract_sentence_bert(
                                    Y_hat[yi : yi + bs], tokenize=False
                                )
                            )
                        ]
                    )[0]
                )
                pb.update(1)
            pb.close()
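            # write this epoch's (X, Y, Y_hat_h) triples to {mode}_epoch_{epoch}.pkl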
            epoch_data = [X, Y, Y_hat_h]
            with open(
                os.path.join(
                    self.args.exp_data_folder, "{}_epoch_{}.pkl".format(mode, epoch)
                ),
                "wb",
            ) as fp:
                pkl.dump(epoch_data, fp)
            pbe.update(1)
        pbe.close()
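
A minimal usage sketch (the class name DialogData and the driver code are assumptions for illustration; only prepare_data and the args fields it reads come from the method above):

    # hypothetical driver; DialogData stands in for the class defined in data.py
    args = get_args()                        # assumed to provide vector_mode, corrupt_type,
                                             # drop_per, epochs and exp_data_folder
    data = DialogData(args)
    data.prepare_data("train")               # writes train_epoch_<e>.pkl per epoch
    data.prepare_data("test", online=True)   # in vector mode, query BERT on the fly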