in src/engines.py [0:0]
def episode(self):
    """Run the full training loop.

    Trains for up to ``self.max_epochs`` epochs, periodically evaluates on
    every dataset split (every ``self.opt['valid']`` epochs), checkpoints the
    model whenever a new best MRR is seen, aborts early on divergence, and
    finishes with a full evaluation on the test split.

    Reads (among others): ``self.dataset``, ``self.model``, ``self.optimizer``,
    ``self.loss``, ``self.regularizer``, ``self.world`` ('LCWA' or
    'sLCWA+bpr'), ``self.batch_size``, ``self.device``,
    ``self.score_{rhs,rel,lhs}``, ``self.w_rel``, ``self.w_lhs``,
    ``self.num_neg``, ``self.opt``, ``self.model_cache_path``.
    Side effects: optimizer steps, wandb logging, ``wandb.run.summary``
    updates, model checkpoints, optional ``torch.save`` of detailed results.
    """
    best_valid_mrr, init_epoch_id, step_idx = 0, 0, 0
    exp_train_sampler = self.dataset.get_sampler('train')

    for e in range(init_epoch_id, self.max_epochs):
        wandb.run.summary['epoch_id'] = e
        self.model.train()
        pbar = tqdm(total=exp_train_sampler.size)
        # Iterate through all batches inside an epoch.
        while exp_train_sampler.is_epoch(e):
            pbar.update(self.batch_size)
            if self.world == 'LCWA':
                input_batch_train = exp_train_sampler.batchify(self.batch_size,
                                                               self.device)
                predictions, factors = self.model.forward(
                    input_batch_train,
                    score_rel=self.score_rel,
                    score_rhs=self.score_rhs,
                    score_lhs=self.score_lhs)
                # Columns of input_batch_train: [:, 0]=lhs, [:, 1]=rel,
                # [:, 2]=rhs (inferred from which column each score head is
                # trained against -- TODO confirm against the sampler).
                if self.score_rel and self.score_rhs and self.score_lhs:
                    l_fit = self.loss(predictions[0], input_batch_train[:, 2]) \
                        + self.w_rel * self.loss(predictions[1], input_batch_train[:, 1]) \
                        + self.w_lhs * self.loss(predictions[2], input_batch_train[:, 0])
                elif self.score_rel and self.score_rhs:
                    l_fit = self.loss(predictions[0], input_batch_train[:, 2]) \
                        + self.w_rel * self.loss(predictions[1], input_batch_train[:, 1])
                elif self.score_lhs and self.score_rel:
                    # Previously a silent `pass`, which later crashed with a
                    # confusing NameError on `l_fit`; fail loudly instead.
                    raise NotImplementedError(
                        'loss for score_lhs + score_rel is not implemented')
                elif self.score_rhs and self.score_lhs:  # standard
                    l_fit = self.loss(predictions[0], input_batch_train[:, 2]) \
                        + self.loss(predictions[1], input_batch_train[:, 0])
                elif self.score_rhs:  # only rhs
                    l_fit = self.loss(predictions, input_batch_train[:, 2])
                elif self.score_rel:
                    l_fit = self.loss(predictions, input_batch_train[:, 1])
                elif self.score_lhs:
                    # Previously a silent `pass`; fail loudly instead.
                    raise NotImplementedError(
                        'loss for score_lhs only is not implemented')
                else:
                    raise ValueError(
                        'at least one of score_rhs/score_rel/score_lhs must be set')
                # Note: this shouldn't be included into the computational
                # graph of the lambda update.
                l_reg, l_reg_raw, avg_lmbda = self.regularizer.penalty(
                    input_batch_train, factors)
            elif self.world == 'sLCWA+bpr':
                pos_train, neg_train, label = exp_train_sampler.batchify(
                    self.batch_size, self.device, num_neg=self.num_neg)
                predictions, factors = self.model.forward_bpr(pos_train, neg_train)
                l_fit = self.loss(predictions, label)
                l_reg, l_reg_raw, avg_lmbda = self.regularizer.penalty(
                    torch.cat((pos_train, neg_train), dim=0),
                    factors)
            else:
                # Previously an implicit fall-through that crashed later with
                # a NameError on `l_fit`/`l_reg`.
                raise ValueError('unknown world: {}'.format(self.world))

            l = l_fit + l_reg
            self.optimizer.zero_grad()
            l.backward()
            self.optimizer.step()

            # Reduce logging frequency to accelerate: every 100 steps for the
            # first 1000 steps, every 1000 steps afterwards.
            if ((step_idx % 1000 == 0 and step_idx > 1000)
                    or (step_idx <= 1000 and step_idx % 100 == 0)):
                wandb.log({'step_wise/train/l': l.item()}, step=step_idx)
                wandb.log({'step_wise/train/l_fit': l_fit.item()}, step=step_idx)
                wandb.log({'step_wise/train/l_reg': l_reg.item()}, step=step_idx)
                wandb.log({'step_wise/train/l_reg_raw': l_reg_raw.item()}, step=step_idx)
            step_idx += 1
        # Release the progress bar so successive epochs don't leave stale bars.
        pbar.close()

        if e % self.opt['valid'] == 0:
            self.model.eval()
            res_all, res_all_detailed = [], []
            for split in self.dataset.splits:
                # Full evaluation on non-train splits; subsample 1000 queries
                # on train for an approximated training MRR.
                res_s = self.dataset.eval(model=self.model,
                                          split=split,
                                          n_queries=-1 if split != 'train' else 1000,
                                          n_epochs=e)
                res_all.append(avg_both(res_s[0], res_s[1]))
                res_all_detailed.append(res_s[2])
            res = dict(zip(self.dataset.splits, res_all))
            res_detailed = dict(zip(self.dataset.splits, res_all_detailed))

            print("\t Epoch: ", e)
            for split in self.dataset.splits:
                print("\t {}: {}".format(split.upper(), res[split]))
                wandb.log({'step_wise/{}/mrr'.format(split): res[split]['MRR']}, step=step_idx)
                wandb.log({'step_wise/{}/hits@1'.format(split): res[split]['hits@[1,3,10]'][0]}, step=step_idx)
                # NOTE(review): this compares *every* split's MRR (train and
                # test included) against best_valid_mrr; if only the 'valid'
                # split should drive checkpointing, guard on split == 'valid'.
                if res[split]['MRR'] > best_valid_mrr:
                    best_valid_mrr = res[split]['MRR']
                    self.model.checkpoint(model_cache_path=self.model_cache_path,
                                          epoch_id='best_valid')
                    if self.opt['cache_eval'] is not None:
                        # Persist detailed per-direction eval results.
                        for s in self.dataset.splits:
                            for m in ['lhs', 'rhs']:
                                torch.save(res_detailed[s][m],
                                           self.opt['cache_eval'] + '{s}_{m}.pt'.format(s=s, m=m))
                    wandb.run.summary['best_valid_mrr'] = best_valid_mrr
                    wandb.run.summary['best_valid_epoch'] = e
                    wandb.run.summary['corr_test_mrr'] = res['test']['MRR']
                    wandb.run.summary['corr_test_hits@1'] = res['test']['hits@[1,3,10]'][0]
                    wandb.run.summary['corr_test_hits@3'] = res['test']['hits@[1,3,10]'][1]
                    wandb.run.summary['corr_test_hits@10'] = res['test']['hits@[1,3,10]'][2]

            # Early termination on divergence: a perfect MRR or a near-zero
            # MRR with near-zero embedding weight both signal a broken run.
            if best_valid_mrr == 1:
                print('MRR 1, diverged!')
                break
            if best_valid_mrr > 0 and best_valid_mrr < 2e-4:
                if l_reg_raw.item() < 1e-4:
                    print('0 embedding weight, diverged!')
                    break

    # Final full evaluation on the test split with the current model weights.
    self.model.eval()
    mrrs, hits, _ = self.dataset.eval(self.model, 'test', -1)
    print("\n\nTEST : MRR {} Hits {}".format(mrrs, hits))
    wandb.run.summary['is_done'] = True