in experiments/codes/experiment/checkpointable_multitask_experiment.py [0:0]
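# Module-level import assumed by this excerpt (not shown): import numpy as np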
def train(self, data, rule_world, epoch=0, report=True, task_idx=None):
    """
    Train for one epoch, or up to `config.general.overfit` minibatches.

    :param data: dict of dataloaders keyed by split; the "train" split is used
    :param rule_world: path of the rule world (task) being trained on
    :param epoch: current epoch number, used only for logging
    :param report: if True, write loss/accuracy metrics to the logbook
    :param task_idx: optional task index to include in the logged metrics
    :return: None; metrics are written via the logbook
    """
    mode = "train"
    # config.general.overfit caps minibatches per epoch; 0 means the full split
    train_nb = self.config.general.overfit
    epoch_loss = []
    epoch_acc = []
    self.composition_fn.train()
    self.representation_fn.train()
    num_batches = len(data[mode])
    num_batches_to_train = num_batches if train_nb == 0 else train_nb
    for batch_idx, batch in enumerate(data[mode]):
        if batch_idx >= num_batches_to_train:
            break  # cap reached; no need to keep iterating the loader
        # Move the batch to the configured device (assumed in-place for this
        # Batch type), then run the forward pass: embed relations, score them
        batch.to(self.config.general.device)
        rel_emb = self.representation_fn(batch)
        logits = self.composition_fn(batch, rel_emb)
        loss = self.composition_fn.loss(logits, batch.targets)
        # Standard step: clear grads on every optimizer, backprop, then update
        for opt in self.optimizers:
            opt.zero_grad()
        loss.backward()
        for opt in self.optimizers:
            opt.step()
        epoch_loss.append(loss.cpu().detach().item())
        predictions, conf = self.composition_fn.predict(logits)
        epoch_acc.append(
            self.composition_fn.accuracy(predictions, batch.targets)
            .cpu()
            .detach()
            .item()
        )
        if report:
            rule_world_last = rule_world.split("/")[-1]
            metrics = {
                "mode": mode,
                "minibatch": self.train_step,
                "loss": np.mean(epoch_loss),
                "accuracy": np.mean(epoch_acc),
                "epoch": epoch,
                # Log the trailing path component so the world name stays short
                "rule_world": rule_world_last,
            }
            # Compare against None so task index 0 is not silently dropped
            if task_idx is not None:
                metrics["task_idx"] = task_idx
            self.logbook.write_metric_logs(metrics)
            # Reset accumulators so each report covers only the latest window
            epoch_loss = []
            epoch_acc = []
        self.train_step += 1
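
# --- Usage sketch (illustrative; not part of the source file) ---
# A minimal sketch of how this method might be driven, assuming `exp` is an
# instance of the surrounding experiment class with `config`, `optimizers`,
# and `logbook` already initialized, and `dataloaders` is a dict with a
# "train" split. The names `exp`, `dataloaders`, `num_epochs`, and the
# rule-world path below are all hypothetical.
#
#     dataloaders = {"train": train_loader}
#     for epoch in range(num_epochs):
#         exp.train(
#             data=dataloaders,
#             rule_world="worlds/rule_world_7",
#             epoch=epoch,
#             report=True,
#             task_idx=0,
#         )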