def train()

in experiments/codes/experiment/checkpointable_multitask_experiment.py [0:0]


    def train(self, data, rule_world, epoch=0, report=True, task_idx=None):
        """Run one training pass over ``data["train"]`` and optionally log metrics.

        For every batch: move it to ``self.config.general.device``, compute
        relation embeddings with ``self.representation_fn``, score them with
        ``self.composition_fn``, back-propagate the loss and step every
        optimizer in ``self.optimizers``.

        :param data: mapping whose ``"train"`` entry is a sized iterable of
            batches; each batch must expose ``.to(device)`` and ``.targets``.
        :param rule_world: identifier of the current rule world; logged as-is.
        :param epoch: epoch number; logged as-is.
        :param report: when True, write the mean loss / accuracy of this pass
            to ``self.logbook``.
        :param task_idx: optional task index; added to the logged metrics
            whenever it is not None (0 is a valid index and is logged).
        :return: None
        """
        mode = "train"
        # A non-zero ``overfit`` setting caps the number of batches trained in
        # this pass (presumably to overfit a small subset while debugging).
        train_nb = self.config.general.overfit
        num_batches_to_train = len(data[mode]) if train_nb == 0 else train_nb

        self.composition_fn.train()
        self.representation_fn.train()

        epoch_loss = []
        epoch_acc = []
        for batch_idx, batch in enumerate(data[mode]):
            # Stop pulling batches once the cap is reached; continuing would
            # load (and discard) the rest of the loader for nothing.
            if batch_idx >= num_batches_to_train:
                break
            # NOTE(review): return value is discarded, so this assumes the
            # batch object moves its tensors in place — confirm.
            batch.to(self.config.general.device)
            rel_emb = self.representation_fn(batch)
            logits = self.composition_fn(batch, rel_emb)
            loss = self.composition_fn.loss(logits, batch.targets)
            for opt in self.optimizers:
                opt.zero_grad()
            loss.backward()
            for opt in self.optimizers:
                opt.step()
            epoch_loss.append(loss.item())
            predictions, _ = self.composition_fn.predict(logits)
            epoch_acc.append(
                self.composition_fn.accuracy(predictions, batch.targets).item()
            )

        if report:
            metrics = {
                "mode": mode,
                "minibatch": self.train_step,
                # Empty pass -> NaN, without numpy's "mean of empty slice" warning.
                "loss": np.mean(epoch_loss) if epoch_loss else float("nan"),
                "accuracy": np.mean(epoch_acc) if epoch_acc else float("nan"),
                "epoch": epoch,
                "rule_world": rule_world,
            }
            if task_idx is not None:
                metrics["task_idx"] = task_idx
            self.logbook.write_metric_logs(metrics)
        # NOTE(review): incremented once per call (i.e. per pass), although it
        # is logged under the "minibatch" key.
        self.train_step += 1