# def episode() — in src/engines.py [0:0]

    def episode(self):
        """Run the full training loop for up to ``self.max_epochs`` epochs.

        Each epoch iterates over training batches (in either the 'LCWA' or the
        'sLCWA+bpr' world), optimizes fit + regularization loss, and logs
        step-wise metrics to wandb. Every ``self.opt['valid']`` epochs all
        dataset splits are evaluated; whenever the tracked MRR improves, the
        model is checkpointed (and, optionally, detailed eval results cached).
        Training aborts early on divergence (MRR == 1, or near-zero MRR with
        vanishing regularizer norm), and ends with a final test evaluation.

        Raises:
            NotImplementedError: for score-flag combinations that have no loss
                implementation (score_lhs+score_rel without score_rhs, or
                score_lhs alone).
            ValueError: if ``self.world`` is unknown or no score flag is set.
        """
        best_valid_mrr, init_epoch_id, step_idx = 0, 0, 0
        exp_train_sampler = self.dataset.get_sampler('train')

        for e in range(init_epoch_id, self.max_epochs):
            wandb.run.summary['epoch_id'] = e
            self.model.train()
            pbar = tqdm(total=exp_train_sampler.size)
            while exp_train_sampler.is_epoch(e):  # iterate through all batches inside an epoch
                pbar.update(self.batch_size)
                if self.world == 'LCWA':
                    input_batch_train = exp_train_sampler.batchify(self.batch_size,
                                                                    self.device)
                    predictions, factors = self.model.forward(input_batch_train, score_rel=self.score_rel, score_rhs=self.score_rhs, score_lhs=self.score_lhs)

                    # Column layout of input_batch_train: [:, 0] = lhs entity,
                    # [:, 1] = relation, [:, 2] = rhs entity.
                    if self.score_rel and self.score_rhs and self.score_lhs:
                        l_fit = self.loss(predictions[0], input_batch_train[:, 2]) \
                                + self.w_rel * self.loss(predictions[1], input_batch_train[:, 1]) \
                                + self.w_lhs * self.loss(predictions[2], input_batch_train[:, 0])
                    elif self.score_rel and self.score_rhs:
                        l_fit = self.loss(predictions[0], input_batch_train[:, 2]) + self.w_rel * self.loss(predictions[1], input_batch_train[:, 1])
                    elif self.score_lhs and self.score_rel:
                        # Fix: was `pass`, which left `l_fit` unbound (NameError on
                        # the first batch) or silently reused the previous batch's
                        # stale loss on later batches. Fail loudly instead.
                        raise NotImplementedError(
                            'score_lhs + score_rel (without score_rhs) is not supported')
                    elif self.score_rhs and self.score_lhs: # standard
                        l_fit = self.loss(predictions[0], input_batch_train[:, 2]) + self.loss(predictions[1], input_batch_train[:, 0])
                    elif self.score_rhs: # only rhs
                        l_fit = self.loss(predictions, input_batch_train[:, 2])
                    elif self.score_rel:
                        l_fit = self.loss(predictions, input_batch_train[:, 1])
                    elif self.score_lhs:
                        # Fix: was `pass` — same unbound-`l_fit` problem as above.
                        raise NotImplementedError('score_lhs alone is not supported')
                    else:
                        raise ValueError(
                            'at least one of score_rhs / score_rel / score_lhs must be set')

                    # Note: this shouldn't be included into the computational graph
                    # of the lambda update.
                    l_reg, l_reg_raw, avg_lmbda = self.regularizer.penalty(input_batch_train, factors)
                elif self.world == 'sLCWA+bpr':
                    pos_train, neg_train, label = exp_train_sampler.batchify(self.batch_size,
                                                                                self.device,
                                                                                num_neg=self.num_neg)
                    predictions, factors = self.model.forward_bpr(pos_train, neg_train)
                    l_fit = self.loss(predictions, label)
                    l_reg, l_reg_raw, avg_lmbda = self.regularizer.penalty(
                        torch.cat((pos_train, neg_train), dim=0),
                        factors)
                else:
                    # Fix: an unrecognized world previously fell through and crashed
                    # below with an unrelated NameError on `l_fit`.
                    raise ValueError('unknown world: {}'.format(self.world))
                l = l_fit + l_reg
                self.optimizer.zero_grad()
                l.backward()
                self.optimizer.step()

                # Reduce logging frequency to accelerate: every 100 steps early on,
                # every 1000 steps after the first 1000.
                if ((step_idx % 1000 == 0 and step_idx > 1000) or (step_idx <= 1000 and step_idx % 100 == 0)):
                    wandb.log({'step_wise/train/l': l.item()}, step=step_idx)
                    wandb.log({'step_wise/train/l_fit': l_fit.item()}, step=step_idx)
                    wandb.log({'step_wise/train/l_reg': l_reg.item()}, step=step_idx)
                    wandb.log({'step_wise/train/l_reg_raw': l_reg_raw.item()}, step=step_idx)
                step_idx += 1
            # Fix: the per-epoch progress bar was never closed, leaking tqdm
            # instances and garbling terminal output across epochs.
            pbar.close()

            if e % self.opt['valid'] == 0:
                self.model.eval()
                res_all, res_all_detailed = [], []
                for split in self.dataset.splits:
                    # Subsample 1000 queries for an approximate training MRR;
                    # evaluate all other splits in full (n_queries=-1).
                    res_s = self.dataset.eval(model=self.model,
                                              split=split,
                                              n_queries=-1 if split != 'train' else 1000,
                                              n_epochs=e)
                    res_all.append(avg_both(res_s[0], res_s[1]))
                    res_all_detailed.append(res_s[2])

                res = dict(zip(self.dataset.splits, res_all))
                res_detailed = dict(zip(self.dataset.splits, res_all_detailed))

                print("\t Epoch: ", e)
                for split in self.dataset.splits:
                    print("\t {}: {}".format(split.upper(), res[split]))
                    wandb.log({'step_wise/{}/mrr'.format(split): res[split]['MRR']}, step=step_idx)
                    wandb.log({'step_wise/{}/hits@1'.format(split): res[split]['hits@[1,3,10]'][0]}, step=step_idx)

                # NOTE(review): `split` here is whatever entry of
                # self.dataset.splits came last in the loop above, so this only
                # tracks *validation* MRR if 'valid' is the final split — confirm
                # the intended split ordering.
                if res[split]['MRR'] > best_valid_mrr:
                    best_valid_mrr = res[split]['MRR']
                    self.model.checkpoint(model_cache_path=self.model_cache_path, epoch_id='best_valid')
                    if self.opt['cache_eval'] is not None:
                        # Cache detailed per-direction eval results for every split.
                        for s in self.dataset.splits:
                            for m in ['lhs', 'rhs']:
                                torch.save(res_detailed[s][m],
                                           self.opt['cache_eval']+'{s}_{m}.pt'.format(s=s, m=m))
                    wandb.run.summary['best_valid_mrr'] = best_valid_mrr
                    wandb.run.summary['best_valid_epoch'] = e
                    wandb.run.summary['corr_test_mrr'] = res['test']['MRR']
                    wandb.run.summary['corr_test_hits@1'] = res['test']['hits@[1,3,10]'][0]
                    wandb.run.summary['corr_test_hits@3'] = res['test']['hits@[1,3,10]'][1]
                    wandb.run.summary['corr_test_hits@10'] = res['test']['hits@[1,3,10]'][2]
            # Divergence checks: perfect MRR, or near-zero MRR with collapsed
            # embedding norms, both indicate a broken run — stop early.
            if best_valid_mrr == 1:
                print('MRR 1, diverged!')
                break
            if 0 < best_valid_mrr < 2e-4:
                if l_reg_raw.item() < 1e-4:
                    print('0 embedding weight, diverged!')
                    break

        # Final full evaluation on the test split.
        self.model.eval()
        mrrs, hits, _ = self.dataset.eval(self.model, 'test', -1)
        print("\n\nTEST : MRR {} Hits {}".format(mrrs, hits))
        wandb.run.summary['is_done'] = True