in experiments/codes/experiment/inference.py [0:0]
def adapt(self, k=0, eps=0.05, report=True, patience=7, data_k=-1, num_epochs=-1):
    """K-shot adaptation on the single world held in ``self.test_data``.

    Trains ``representation_fn`` and ``composition_fn`` on the ``"train"``
    split of that world. After every epoch it evaluates on ``"valid"`` and
    ``"test"`` via ``self.evaluate``.

    Note: ``self.test_data`` must map exactly one world name to its splits.
    It is replaced in place by that world's splits, so the original dict is
    not available after this call.

    Keyword Arguments:
        k {int} -- if > 0, stop as soon as ``self.train_step`` reaches ``k``
            minibatch updates. ``self.train_step`` is not reset here, so
            updates made before this call count too. (default: {0})
        eps {float} -- currently unused (default: {0.05})
        report {bool} -- print a metrics dict after every epoch
            (default: {True})
        patience {int} -- convergence mode only: stop after this many
            consecutive epochs whose validation loss is worse than the best
            seen so far (default: {7})
        data_k {int} -- if >= 0, stop after ``max(data_k, 1)`` epochs
            (default: {-1})
        num_epochs {int} -- if >= 0, stop after ``max(num_epochs, 1)`` epochs
            (default: {-1})

    If ``k``, ``data_k`` and ``num_epochs`` are all -1, train until the
    validation loss stops improving ("convergence mode"). In that mode the
    best checkpoint is reloaded at the end and its temporary directory is
    removed. If no stopping criterion is set at all (e.g. the default
    ``k=0``), no update is performed (0-shot) instead of looping forever.
    """
    self.composition_fn.train()
    self.representation_fn.train()
    mode = "train"
    convergence_mode = data_k == -1 and k == -1 and num_epochs == -1
    if convergence_mode:
        print("converging till best validation")
    # Without at least one of these the epoch loop below can never terminate.
    has_stop_criterion = (
        convergence_mode or k > 0 or data_k >= 0 or num_epochs >= 0
    )
    world_names = list(self.test_data.keys())
    assert len(world_names) == 1
    # Unwrap before any early return: self.evaluate (called below and
    # presumably by callers afterwards) indexes the splits directly.
    self.test_data = self.test_data[world_names[0]]
    if not has_stop_criterion:
        print("no adaptation stopping criterion set; skipping adaptation (0-shot)")
        return
    # inf (not a magic constant) guarantees the first epoch is checkpointed.
    best_epoch_loss = float("inf")
    counter = 0
    epoch_id = -1
    reached_k = False
    while True:
        epoch_id += 1
        epoch_loss = []
        epoch_acc = []
        for batch in self.test_data[mode]:
            batch.to(self.config.general.device)
            rel_emb = self.representation_fn(batch)
            logits = self.composition_fn(batch, rel_emb)
            loss = self.composition_fn.loss(logits, batch.targets)
            for opt in self.optimizers:
                opt.zero_grad()
            loss.backward()
            for opt in self.optimizers:
                opt.step()
            epoch_loss.append(loss.cpu().detach().item())
            predictions, _ = self.composition_fn.predict(logits)
            epoch_acc.append(
                self.composition_fn.accuracy(predictions, batch.targets)
                .cpu()
                .detach()
                .item()
            )
            self.train_step += 1
            if k > 0 and self.train_step >= k:
                # Update budget exhausted: finish this epoch's bookkeeping
                # (evaluation / report) and then leave the outer loop.
                reached_k = True
                break
        if data_k > 0:
            data_k -= 1
        epoch_loss_mean = np.mean(epoch_loss)
        valid_eval = self.evaluate(mode="valid")
        test_eval = self.evaluate(mode="test")
        if convergence_mode:
            if best_epoch_loss < valid_eval["loss"]:
                counter += 1
            else:
                # save the model in tmp loc
                self.best_validation_save_dir = self.save_model(
                    epoch=epoch_id, save_dir=self.best_validation_save_dir
                )
                print("saved best model in {}".format(self.best_validation_save_dir))
                counter = 0
        if report:
            metrics = {
                "mode": mode,
                "minibatch": self.train_step,
                "loss": epoch_loss_mean,
                "valid_loss": valid_eval["loss"],
                "best_valid_loss": best_epoch_loss,
                "test_acc": test_eval["accuracy"],
                "accuracy": np.mean(epoch_acc),
                "epoch": self.epoch,
                "rule_world": self.test_world,
                "patience_counter": counter,
            }
            print(metrics)
        if data_k == 0:
            break
        if convergence_mode:
            best_epoch_loss = min(best_epoch_loss, valid_eval["loss"])
            if counter >= patience:
                break
        if reached_k:
            break
        # epoch_id is 0-based, so epoch_id + 1 epochs have now completed.
        if num_epochs >= 0 and epoch_id + 1 >= num_epochs:
            break
    if convergence_mode:
        # reload the best model from the tmp loc; always clean the dir up
        try:
            self.load_model(load_dir=self.best_validation_save_dir)
        finally:
            shutil.rmtree(self.best_validation_save_dir)
            self.best_validation_save_dir = None