def adapt()

in experiments/codes/experiment/inference.py [0:0]
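Note: the method relies on module-level imports of numpy (as np) and shutil, and on helpers defined elsewhere in the class: self.evaluate, self.save_model, self.load_model.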


    def adapt(self, k=0, eps=0.05, report=True, patience=7, data_k=-1, num_epochs=-1):
        """ K-shot adaptation. Here k defines the number of minibatches (or updates) 
        the model is exposed to
        
        Keyword Arguments:
            k {int} -- k shot. if k = -1, train till convergence (default: {0})
        """
        self.composition_fn.train()
        self.representation_fn.train()
        mode = "train"
        break_while = False
        convergence_mode = data_k == -1 and k == -1 and num_epochs == -1
        if convergence_mode:
            print("converging till best validation")
        best_epoch_loss = float("inf")  # best validation loss seen so far
        counter = 0
        epoch_id = -1
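        # self.test_data is keyed by rule world; adaptation expects exactly one world, so unwrap it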
        world_keys = list(self.test_data.keys())
        assert len(world_keys) == 1
        self.test_data = self.test_data[world_keys[0]]
        while True:
            epoch_id += 1
            epoch_loss = []
            epoch_acc = []
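            # one pass over the adaptation split; every minibatch is a gradient update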
            for batch_idx, batch in enumerate(self.test_data[mode]):
                batch.to(self.config.general.device)
                rel_emb = self.representation_fn(batch)
                logits = self.composition_fn(batch, rel_emb)
                loss = self.composition_fn.loss(logits, batch.targets)
                for opt in self.optimizers:
                    opt.zero_grad()
                loss.backward()
                for opt in self.optimizers:
                    opt.step()
                epoch_loss.append(loss.cpu().detach().item())
                predictions, conf = self.composition_fn.predict(logits)
                epoch_acc.append(
                    self.composition_fn.accuracy(predictions, batch.targets)
                    .cpu()
                    .detach()
                    .item()
                )
                self.train_step += 1
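                # k-shot mode: stop once self.train_step reaches k updates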
                if k > 0:
                    if self.train_step >= k:
                        break_while = True
                        break
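            # data_k mode: count down the remaining passes over the adaptation data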
            if data_k > 0:
                data_k -= 1
            epoch_loss_mean = np.mean(epoch_loss)
            valid_eval = self.evaluate(mode="valid")
            test_eval = self.evaluate(mode="test")
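            # convergence mode: checkpoint on validation improvement, otherwise grow the patience counter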
            if convergence_mode:
                if best_epoch_loss < valid_eval["loss"]:
                    counter += 1
                else:
                    # save the model in tmp loc
                    self.best_validation_save_dir = self.save_model(
                        epoch=epoch_id, save_dir=self.best_validation_save_dir
                    )
                    print(
                        "saved best model in {}".format(self.best_validation_save_dir)
                    )
                    counter = 0
            if report:
                metrics = {
                    "mode": mode,
                    "minibatch": self.train_step,
                    "loss": epoch_loss_mean,
                    "valid_loss": valid_eval["loss"],
                    "best_valid_loss": best_epoch_loss,
                    "test_acc": test_eval["accuracy"],
                    "accuracy": np.mean(epoch_acc),
                    "epoch": self.epoch,
                    "rule_world": self.test_world,
                    "patience_counter": counter,
                }
                print(metrics)
                # self.logbook.write_metric_logs(metrics)
            # if np.mean(epoch_loss_mean) <= eps:
            #     break
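            # data_k mode: stop after the requested number of passes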
            if data_k == 0:
                break
            # else:
            #     patience = original_patience
            if convergence_mode:
                best_epoch_loss = min(best_epoch_loss, valid_eval["loss"])
                if counter >= patience:
                    break
            if break_while:
                break
            if num_epochs >= 0:
                if epoch_id >= num_epochs:
                    break
        if convergence_mode:
            # reload model from tmp loc
            self.load_model(load_dir=self.best_validation_save_dir)
            shutil.rmtree(self.best_validation_save_dir)
            self.best_validation_save_dir = None
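
Convergence mode is reached by calling adapt(k=-1) and leaving data_k and num_epochs at their -1 defaults; the loop then reduces to patience-based early stopping on validation loss. The sketch below shows that pattern in isolation. adapt_until_convergence, train_one_epoch, and validate are hypothetical stand-ins (not part of this repository) for the method, its minibatch loop, and self.evaluate(mode="valid").

def adapt_until_convergence(train_one_epoch, validate, patience=7):
    # Minimal sketch (not repository code): run epochs until the validation
    # loss has not improved for `patience` consecutive epochs.
    best_valid_loss = float("inf")
    counter = 0
    epoch_id = 0
    while True:
        train_one_epoch()                  # gradient updates over the adaptation split
        valid_loss = validate()            # stand-in for self.evaluate(mode="valid")["loss"]
        if valid_loss > best_valid_loss:
            counter += 1                   # no improvement this epoch
        else:
            best_valid_loss = valid_loss   # improvement: the method checkpoints here
            counter = 0
        epoch_id += 1
        if counter >= patience:
            break                          # the method then reloads the best checkpoint
    return best_valid_loss, epoch_id

The method above wraps checkpointing around this pattern: self.save_model whenever the validation loss matches or improves on the best so far, and self.load_model of that best checkpoint once patience is exhausted.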