in backup/baselines/ruber/train.py [0:0]
def train_epoch(self, epoch=0):
    """
    Train the unreferenced metric for one epoch.
    :param epoch: index of the current epoch (used only for logging)
    :return: None
    """
    self.model.train()
    X, Y, Y_hat = self.data.prepare_data(self.data.train_indices, False)
    loss_a = []
    t_loss_a = []
    f_loss_a = []
    num_batches = len(list(range(0, len(X), self.args.batch_size)))
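    # Iterate over minibatches: each context is scored against its ground-truth
    # response and a negatively sampled response, and the model is trained to
    # rank the true pair above the sampled one.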
    for i in range(0, len(X), self.args.batch_size):
        inp_vec, inp_len = batchify(X[i : i + self.args.batch_size], False)
        outp_vec, outp_len = batchify(Y[i : i + self.args.batch_size], False)
        inp_vec = inp_vec.to(self.device)
        outp_vec = outp_vec.to(self.device)
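        # Score the (context, true response) pair and the (context, sampled
        # response) pair with the unreferenced model.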
        diff_true = self.model(inp_vec, inp_len, outp_vec, outp_len)
        y_false, y_len = batchify(Y_hat[i : i + self.args.batch_size], False)
        y_false = y_false.to(self.device)
        diff_false = self.model(inp_vec, inp_len, y_false, y_len)
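        # Max-margin (hinge) ranking objective: penalize whenever the sampled
        # pair scores within `margin` of the true pair.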
        loss = torch.clamp(
            torch.ones_like(diff_true) * self.args.margin - diff_true + diff_false,
            min=0.0,
        )
        loss = loss.mean()
        self.optimizer.zero_grad()
        loss.backward()
        loss_a.append(loss.item())
        t_loss_a.append(diff_true.mean().item())
        f_loss_a.append(diff_false.mean().item())
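        # Periodically log the running averages (and always on the last batch).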
        if i % self.args.log_interval == 0 or (i + 1) > (
            len(X) - self.args.batch_size
        ):
            print(i)
            metrics = {
                "mode": "train",
                "minibatch": self.train_step,
                "loss": np.mean(loss_a),
                "true_loss": np.mean(t_loss_a),
                "false_loss": np.mean(f_loss_a),
                "epoch": epoch,
            }
            self.train_step += 1
            loss_a = []
            t_loss_a = []
            f_loss_a = []
            self.logbook.write_metric_logs(metrics)
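        # Apply the gradient from this minibatch.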
        self.optimizer.step()
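
For reference (not part of the original file): the clamp-based objective above is a standard max-margin ranking loss and matches PyTorch's built-in nn.MarginRankingLoss. A minimal standalone sketch with illustrative scores:

import torch
import torch.nn as nn

margin = 0.5
score_true = torch.tensor([0.9, 0.2, 0.7])   # scores for true (context, response) pairs
score_false = torch.tensor([0.1, 0.4, 0.6])  # scores for sampled negative pairs

# Hand-rolled hinge, as in the loop above
manual = torch.clamp(margin - score_true + score_false, min=0.0).mean()

# Built-in equivalent: target of 1 means "score_true should rank higher"
builtin = nn.MarginRankingLoss(margin=margin)(
    score_true, score_false, torch.ones_like(score_true)
)

assert torch.allclose(manual, builtin)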