def evaluate_representations()

in experiments/codes/experiment/inference.py [0:0]


    def evaluate_representations(self, mode="test", k=0, ale_mode=False):
        """Evaluate the model on each test world and collect representations.

        For every world named in ``self.test_world`` that has an entry in
        ``self.dataloaders["test"]``, this runs ``representation_fn`` and
        ``composition_fn`` over each batch of split ``mode``. It records the
        per-batch loss and accuracy, and stores the first element of the
        representation output.

        Args:
            mode: Key selecting the split inside
                ``self.dataloaders["test"][world]`` to iterate over.
            k: Copied unchanged into the returned metrics.
            ale_mode: Currently ignored; kept for interface compatibility.

        Returns:
            A ``(metrics, rep_emb_worlds)`` tuple.

            ``metrics`` contains:

            - ``accuracy``: mean of the per-world mean batch accuracies.
            - ``acc_std``: standard deviation of those per-world accuracies.
            - ``loss``: mean of the per-world mean batch losses.
            - bookkeeping fields: mode, step, epoch, k and ``rule_world``.

            ``rep_emb_worlds`` maps ``world -> {batch_idx: numpy array}``.
            Worlds without a test dataloader are skipped. If no world is
            evaluated, the aggregate metrics are NaN, because numpy
            reduces an empty list to NaN.

        Side effects:
            Puts ``composition_fn`` and ``representation_fn`` in eval mode.
            Sets ``self.test_data`` to the dataloaders of each evaluated world
            in turn, leaving it at the last one.
        """
        # ``test_world`` may be one name, a comma-separated string of names,
        # or an iterable of names. A bare string must become a one-element
        # list: iterating it directly would walk its characters and evaluate
        # nothing.
        if isinstance(self.test_world, str):
            worlds = [w.strip() for w in self.test_world.split(",") if w.strip()]
        else:
            worlds = list(self.test_world)

        device = self.config.general.device
        task_loss = []
        task_acc = []
        rep_emb_worlds = {}
        for test_world in worlds:
            if test_world not in self.dataloaders["test"]:
                continue
            rep_emb_worlds[test_world] = {}
            self.test_data = self.dataloaders["test"][test_world]
            self.composition_fn.eval()
            self.representation_fn.eval()
            epoch_loss = []
            epoch_acc = []
            for batch_idx, batch in enumerate(self.test_data[mode]):
                # NOTE(review): the return value is discarded, so this assumes
                # ``batch.to`` moves the batch in place -- confirm for the
                # batch type in use.
                batch.to(device)
                rel_emb = self.representation_fn(batch)
                logits = self.composition_fn(batch, rel_emb)
                loss = self.composition_fn.loss(logits, batch.targets)
                predictions, _ = self.composition_fn.predict(logits)
                epoch_loss.append(loss.cpu().detach().item())
                epoch_acc.append(
                    self.composition_fn.accuracy(predictions, batch.targets)
                    .cpu()
                    .detach()
                    .item()
                )
                # Detach first: ``Tensor.numpy()`` raises on tensors that
                # still require grad.
                rep_emb_worlds[test_world][batch_idx] = (
                    rel_emb[0].detach().to("cpu").numpy()
                )
            task_loss.append(np.mean(epoch_loss))
            task_acc.append(np.mean(epoch_acc))

        metrics = {
            "mode": mode,
            "minibatch": self.train_step,
            "epoch": self.epoch,
            "accuracy": np.mean(task_acc),
            "acc_std": np.std(task_acc),
            "loss": np.mean(task_loss),
            "k": k,
            "top_mode": "test",
            "rule_world": self.test_world,
        }
        return metrics, rep_emb_worlds