def one_batch()

in shaDow/minibatch.py


    def one_batch(self, mode=TRAIN, ret_raw_idx=False):
        """
        Prepare one batch of training subgraphs. For each root, the sampler returns the subgraph adj separately.
        To improve computation efficiency, we concatenate the batch's individual adjs into a single big adj,
        i.e., the resulting adj is of block-diagonal form.

        Such concatenation does not increase computational complexity (since the adj is in CSR format), but it facilitates parallelism.
        """
        if self.graph_sampler[mode] is None:    # [node pred only] no sampling, return the full graph. 
            self._update_batch_stat(mode, self.batch_size[mode])
            targets = self.raw_entity_set[mode]
            assert ret_raw_idx, "The no-sampling (full-graph) mode should only be used in preprocessing!"
            return OneBatchSubgraph(
                [self.adj[mode]],
                [self.feat_full],
                self.label_full[targets],
                None,
                [targets],
                None,
                [np.arange(self.adj[mode].shape[0])]
            )
        batch_size_ = self._get_cur_batch_size(mode)
        while self.pool_subg[mode].num_subg < batch_size_:
            self.par_graph_sample(mode)
        if batch_size_ != self.batch_size[mode]:        # end of epoch
            assert batch_size_ < self.batch_size[mode] and self.pool_subg[mode].num_subg == batch_size_
            self.graph_sampler[mode].validate_epoch_end()
        adj_ens, feat_ens, target_ens, feat_aug_ens = [], [], [], []
        label_idx = None
        subgs_ens, size_subg_ens = self.pool_subg[mode].collate(batch_size_)
        size_subg_ens = torch.tensor(size_subg_ens).to(self.dev_torch)
        idx_b_start = self.idx_entity_evaluated[mode]
        idx_b_end = idx_b_start + batch_size_
        label_batch = self.label_epoch[mode][idx_b_start : idx_b_end].to(self.dev_torch)
        for subgs in subgs_ens:
            # link prediction has 2 target nodes (src, dst) per root; other tasks have 1
            assert subgs.target.size == batch_size_ * (1 + (self.prediction_task == 'link'))
            if label_idx is None:
                label_idx = subgs.node[subgs.target]
            else:
                # all ensemble branches must share the same target (root) nodes
                assert np.all(label_idx == subgs.node[subgs.target])
            adj_ens.append(subgs.to_csr_sp())
            feat_ens.append(self.feat_full[subgs.node].to(self.dev_torch))
            target_ens.append(subgs.target)
            feat_aug_ens.append({})
            for candy_augs in {'hops', 'pprs', 'drnls'}.intersection(self.aug_feats):
                candy_aug = candy_augs[:-1] # remove 's'
                feat_aug_ens[-1][candy_augs] = getattr(subgs.entity_enc, f'{candy_aug}2onehot_vec')(
                    getattr(self, f'dim_1hot_{candy_aug}'), return_type='tensor'
                ).type(self.dtype).to(self.dev_torch)
        self._update_batch_stat(mode, batch_size_)
        ret = OneBatchSubgraph(
            adj_ens, feat_ens, label_batch, size_subg_ens, target_ens, feat_aug_ens
        )
        self.profiler.update_subgraph_batch(ret)
        self.profiler.profile()
        if ret_raw_idx:     # TODO: support ensembles as well; multiple subgraphs should share the same raw target idx
            assert ret.num_ens == 1
            ret.idx_raw = [subgs.node for subgs in subgs_ens]
        return ret
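
For reference, the `hops` / `pprs` / `drnls` branch above converts per-node structural encodings into one-hot feature tensors via the `*2onehot_vec` methods. Below is a minimal sketch of that kind of transformation for hop distances, assuming values are clipped to a fixed one-hot dimension; the names `hops` and `dim_1hot` and the clipping rule are illustrative, not the repo's actual `hop2onehot_vec`:

    import torch
    import torch.nn.functional as F

    # Hypothetical hop distances of subgraph nodes from their root (0 = the root itself).
    hops = torch.tensor([0, 1, 1, 2, 5])

    # Clip to the one-hot dimension (cf. self.dim_1hot_hop), then expand
    # into indicator vectors usable as augmented node features.
    dim_1hot = 4
    feat_hop = F.one_hot(hops.clamp(max=dim_1hot - 1), num_classes=dim_1hot).float()
    # feat_hop.shape == torch.Size([5, 4])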