# codes/models.py
# imports inferred from usage below; module-level in the original file
import os
import pickle as pkl

import torch
from sklearn.decomposition import IncrementalPCA
from tqdm import tqdm


def preflight_steps(self):
    """
    Extract all training BERT embeddings and train the PCA.
    Do it only if a saved PCA file does not already exist.
    :return:
    """
    if not self.hparams.learn_down and not self.hparams.fix_down:
        self.logbook.write_message_logs(
            "Checking pca file in ... {}".format(self.hparams.pca_file)
        )
        if not self.down_model:
            # os.path.isfile already implies the path exists
            if os.path.isfile(self.hparams.pca_file):
                self.logbook.write_message_logs(
                    "Loading PCA model from {}".format(self.hparams.pca_file)
                )
                # use a context manager so the file handle is closed
                with open(self.hparams.pca_file, "rb") as f:
                    data_dump = pkl.load(f)
                self.down_model = data_dump["pca"]
            else:
                self.logbook.write_message_logs(
                    "Not found. Proceeding to extract and train..."
                )
                self.down_model = IncrementalPCA(
                    n_components=self.hparams.down_dim, whiten=True
                )
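                # sklearn's IncrementalPCA is fit on mini-batches via
                # partial_fit (used below), so the full embedding matrix
                # never has to be held in memory; whiten=True rescales the
                # projected components to unit variance.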
                # extract embeddings batch-by-batch and fit the PCA
                # incrementally (the embeddings themselves are not saved)
                train_loader = self.get_dataloader(mode="train")
                self.logbook.write_message_logs("Extracting embeddings ...")
                pb = tqdm(total=len(train_loader))
                for batch in train_loader:
                    (
                        inp,
                        inp_len,
                        inp_dial_len,
                        y_true,
                        y_true_len,
                        y_false,
                        y_false_len,
                    ) = batch
                    # skip the ragged final batch so every partial_fit call
                    # sees a full batch of rows
                    if inp.size(0) < self.hparams.batch_size:
                        continue
                    with torch.no_grad():
                        # use a fresh name for the batch dimension rather
                        # than shadowing the loop variable `batch`
                        batch_size, num_dials, num_words = inp.shape
                        inp = inp.view(-1, num_words).to(self.hparams.device)
                        inp_dial_len = inp_dial_len.to(self.hparams.device)
                        inp_vec = self.extract_sentence_bert(inp, inp_dial_len)
                        inp_vec = inp_vec.view(
                            batch_size, num_dials, -1
                        )  # B x D x dim
                        inp_vec = (
                            inp_vec.view(-1, inp_vec.size(2)).to("cpu").numpy()
                        )  # (B x D) x dim
                        self.down_model.partial_fit(inp_vec)
                    # free GPU memory between batches; empty_cache() is a
                    # temporary workaround
                    del inp
                    del inp_len
                    del inp_vec
                    del y_true
                    del y_false
                    torch.cuda.empty_cache()
                    pb.update(1)
                pb.close()
                self.logbook.write_message_logs(
                    "Saving PCA model in {}".format(self.hparams.pca_file)
                )
                with open(self.hparams.pca_file, "wb") as f:
                    pkl.dump({"pca": self.down_model}, f)
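
# ---------------------------------------------------------------------------
# A minimal, self-contained sketch of the same pattern (illustrative only,
# not the project's API): fit sklearn's IncrementalPCA chunk-by-chunk on
# stand-in "embeddings", then project with transform(). The shapes, dims,
# and random data below are assumptions for demonstration.
import numpy as np
from sklearn.decomposition import IncrementalPCA

rng = np.random.default_rng(0)
fake_embeddings = rng.standard_normal((256, 768)).astype(np.float32)  # (B x D) x dim

pca = IncrementalPCA(n_components=64, whiten=True)
# every chunk passed to partial_fit needs at least n_components rows,
# which is why preflight_steps skips ragged final batches
for chunk in np.array_split(fake_embeddings, 4):
    pca.partial_fit(chunk)

reduced = pca.transform(fake_embeddings)
print(reduced.shape)  # (256, 64)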