in experiments/codes/experiment/inference.py [0:0]
def evaluate_representations(self, mode="test", k=0, ale_mode=False):
    """Evaluate the model on every test world and collect relation embeddings.

    Args:
        mode: split of each test-world dataloader to evaluate ("test" by default).
        k: bookkeeping value copied verbatim into the returned metrics.
        ale_mode: unused; kept for backward compatibility with existing callers.

    Returns:
        Tuple of (metrics, rep_emb_worlds) where metrics is a dict of
        loss/accuracy statistics aggregated over all evaluated worlds, and
        rep_emb_worlds maps world name -> {batch_idx: first relation
        embedding of that batch as a CPU numpy array}.
    """
    # Bug fix: the original split only when "," was present and otherwise
    # left `worlds` as a plain string, so the loop below iterated over its
    # *characters* (none of which match a dataloader key), silently skipping
    # every world and yielding nan metrics. str.split(",") covers both the
    # single- and multi-world cases.
    worlds = self.test_world.split(",")
    task_loss = []
    task_acc = []
    rep_emb_worlds = {}
    for test_world in worlds:
        # Skip worlds without a dataloader (e.g. stale config entries).
        if test_world not in self.dataloaders["test"]:
            continue
        rep_emb_worlds[test_world] = {}
        self.test_data = self.dataloaders["test"][test_world]
        self.composition_fn.eval()
        self.representation_fn.eval()
        epoch_loss = []
        epoch_acc = []
        for batch_idx, batch in enumerate(self.test_data[mode]):
            batch.to(self.config.general.device)
            rel_emb = self.representation_fn(batch)
            logits = self.composition_fn(batch, rel_emb)
            loss = self.composition_fn.loss(logits, batch.targets)
            predictions, conf = self.composition_fn.predict(logits)
            epoch_loss.append(loss.cpu().detach().item())
            epoch_acc.append(
                self.composition_fn.accuracy(predictions, batch.targets)
                .cpu()
                .detach()
                .item()
            )
            # Keep only the first relation embedding of each batch, on CPU,
            # so stored representations stay small and device-independent.
            rep_emb_worlds[test_world][batch_idx] = rel_emb[0].to("cpu").numpy()
        task_loss.append(np.mean(epoch_loss))
        task_acc.append(np.mean(epoch_acc))
    metrics = {
        "mode": mode,
        "minibatch": self.train_step,
        "epoch": self.epoch,
        "accuracy": np.mean(task_acc),
        "acc_std": np.std(task_acc),
        "loss": np.mean(task_loss),
        "k": k,
        "top_mode": "test",
        "rule_world": self.test_world,
    }
    return metrics, rep_emb_worlds