def preflight_steps()

in codes/models.py


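    # Module-level imports assumed by this method (not shown in the excerpt):
    # import os
    # import pickle as pkl
    # import torch
    # from tqdm import tqdm
    # from sklearn.decomposition import IncrementalPCA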
    def preflight_steps(self):
        """
        Extract all training BERT embeddings and train pca
        Do it only if we do not have a saved file
        :return:
        """
        if not self.hparams.learn_down and not self.hparams.fix_down:
            self.logbook.write_message_logs(
                "Checking pca file in ... {}".format(self.hparams.pca_file)
            )
            if not self.down_model:
                if os.path.exists(self.hparams.pca_file) and os.path.isfile(
                    self.hparams.pca_file
                ):
                    self.logbook.write_message_logs(
                        "Loading PCA model from {}".format(self.hparams.pca_file)
                    )
                    with open(self.hparams.pca_file, "rb") as f:
                        data_dump = pkl.load(f)
                    self.down_model = data_dump["pca"]
                else:
                    self.logbook.write_message_logs(
                        "Not found. Proceeding to extract and train..."
                    )
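                    # whiten=True rescales the projected components to unit
                    # variance; n_components comes from hparams.down_dim.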
                    self.down_model = IncrementalPCA(
                        n_components=self.hparams.down_dim, whiten=True
                    )
                    # extract and save embeddings
                    train_loader = self.get_dataloader(mode="train")
                    all_vecs = []
                    self.logbook.write_message_logs("Extracting embeddings ...")
                    pb = tqdm(total=len(train_loader))
                    for bi, batch in enumerate(train_loader):
                        (
                            inp,
                            inp_len,
                            inp_dial_len,
                            y_true,
                            y_true_len,
                            y_false,
                            y_false_len,
                        ) = batch
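                        # Skip the smaller final batch, so every call below
                        # feeds partial_fit a uniform number of samples
                        # (IncrementalPCA needs at least n_components per call).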
                        if inp.size(0) < self.hparams.batch_size:
                            continue
                        with torch.no_grad():
                            # use bsz, not batch: avoid shadowing the loop variable
                            bsz, num_dials, num_words = inp.shape
                            inp = inp.view(-1, num_words).to(self.hparams.device)
                            inp_dial_len = inp_dial_len.to(self.hparams.device)
                            inp_vec = self.extract_sentence_bert(inp, inp_dial_len)
                            inp_vec = inp_vec.view(bsz, num_dials, -1)  # B x D x dim
                            inp_vec = (
                                inp_vec.view(-1, inp_vec.size(2)).to("cpu").numpy()
                            )  # (B x D) x dim
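                            # Incrementally update the PCA statistics; only one
                            # batch of embeddings is held in memory at a time.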
                            self.down_model.partial_fit(inp_vec)
                        del inp
                        del inp_len
                        del inp_vec
                        del y_true
                        del y_false
                        # Workaround: explicitly free cached GPU memory between batches
                        torch.cuda.empty_cache()
                        pb.update(1)
                    pb.close()
                    self.logbook.write_message_logs(
                        "Saving PCA model in {}".format(self.hparams.pca_file)
                    )
                    with open(self.hparams.pca_file, "wb") as f:
                        pkl.dump({"pca": self.down_model}, f)
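
Once saved, the fitted model can be reloaded and applied to new sentence embeddings with scikit-learn's transform. A minimal sketch, assuming 768-dimensional BERT vectors and a placeholder pca.pkl path (in the method above the real path comes from self.hparams.pca_file):

    import pickle as pkl
    import numpy as np

    # Load the {"pca": ...} dump written by preflight_steps().
    # "pca.pkl" is a stand-in for the actual hparams.pca_file path.
    with open("pca.pkl", "rb") as f:
        down_model = pkl.load(f)["pca"]

    # Stand-in for extract_sentence_bert output: 32 sentences x 768 dims.
    embeddings = np.random.randn(32, 768).astype(np.float32)

    reduced = down_model.transform(embeddings)  # shape: (32, down_dim)
    print(reduced.shape)

transform projects each vector onto the learned components and, because the model was fitted with whiten=True, also rescales the components to unit variance.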