in backup/train.py [0:0]
def train_epoch(self, epoch=0):
    """Run one training epoch of the binary (true vs. false target) predictor.

    Each batch from the training dataloader is unpacked as
    ``(inp, inp_len, y_true, y_false)``. The model scores the input against
    the true target (trained towards label 1) and against the false target
    (trained towards label 0). The two ``self.loss_fn`` terms are summed and
    backpropagated, and one optimizer step is taken per batch.

    Mean total/true/false losses are written to ``self.logbook`` every
    ``self.args.log_interval`` batches (the first batch included, since the
    check is ``i % log_interval == 0``). Any batches left over since the last
    periodic log are written once more at the end of the epoch. Every write
    increments ``self.train_step``.

    :param epoch: epoch index; forwarded to the dataloader and recorded in
        every logged metrics dict.
    :return: None
    """
    self.model.train()
    loss_a = []
    t_loss_a = []
    f_loss_a = []

    def _flush_metrics():
        # Log the mean losses accumulated since the previous flush, then reset.
        metrics = {
            "mode": "train",
            "minibatch": self.train_step,
            "loss": np.mean(loss_a),
            "true_loss": np.mean(t_loss_a),
            "false_loss": np.mean(f_loss_a),
            "epoch": epoch,
        }
        self.train_step += 1
        loss_a.clear()
        t_loss_a.clear()
        f_loss_a.clear()
        self.logbook.write_metric_logs(metrics)

    dl = self.data.get_dataloader(mode="train", epoch=epoch)
    for i, batch in enumerate(dl):
        inp, inp_len, y_true, y_false = batch
        # Score the same input against its true and its false target.
        pred_true = self.model(inp, inp_len, y_true)
        pred_false = self.model(inp, inp_len, y_false)
        # ones_like/zeros_like match the prediction's shape, dtype and device,
        # so targets line up with the predictions and need no host->device copy.
        t_loss = self.loss_fn(pred_true, torch.ones_like(pred_true))
        f_loss = self.loss_fn(pred_false, torch.zeros_like(pred_false))
        loss = t_loss + f_loss

        # One parameter update per batch.
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        loss_a.append(loss.item())
        t_loss_a.append(t_loss.item())
        f_loss_a.append(f_loss.item())
        if i % self.args.log_interval == 0:
            _flush_metrics()

    # Post epoch: flush batches not yet covered by a periodic log. Skipped
    # when nothing is pending (the last batch was just logged, or the
    # dataloader was empty), which would otherwise log NaN from np.mean([]).
    if loss_a:
        _flush_metrics()