in shaDow/minibatch.py [0:0]
def one_batch(self, mode=TRAIN, ret_raw_idx=False):
    """
    Prepare one batch of training subgraphs. For each root, the sampler returns the
    subgraph adj separately. To improve computation efficiency, we concatenate the
    batch number of individual adj into a single big adj, i.e., the resulting adj is
    of the block-diagonal form. Such concatenation does not increase computation
    complexity (since adj is in CSR) but facilitates parallelism.

    Args:
        mode: phase selector (e.g. TRAIN / VALID / TEST) indexing per-mode state.
        ret_raw_idx: if True, attach the raw node indices of each ensemble
            subgraph to the returned batch (only supported for a single-ensemble
            configuration; see TODO below).

    Returns:
        OneBatchSubgraph holding, per ensemble branch: adj (CSR), node features,
        batch labels, subgraph sizes, target indices and augmented features.
    """
    if self.graph_sampler[mode] is None:  # [node pred only] no sampling, return the full graph.
        self._update_batch_stat(mode, self.batch_size[mode])
        targets = self.raw_entity_set[mode]
        assert ret_raw_idx, "None subg mode should only be used in preproc!"
        return OneBatchSubgraph(
            [self.adj[mode]],
            [self.feat_full],
            self.label_full[targets],
            None,
            [targets],
            None,
            [np.arange(self.adj[mode].shape[0])]
        )
    batch_size_ = self._get_cur_batch_size(mode)
    # Keep sampling (possibly in parallel) until the pool holds enough subgraphs.
    while self.pool_subg[mode].num_subg < batch_size_:
        self.par_graph_sample(mode)
    if batch_size_ != self.batch_size[mode]:  # end of epoch: last (smaller) batch
        assert batch_size_ < self.batch_size[mode] and self.pool_subg[mode].num_subg == batch_size_
        self.graph_sampler[mode].validate_epoch_end()
    adj_ens, feat_ens, target_ens, feat_aug_ens = [], [], [], []
    label_idx = None
    subgs_ens, size_subg_ens = self.pool_subg[mode].collate(batch_size_)
    size_subg_ens = torch.tensor(size_subg_ens).to(self.dev_torch)
    idx_b_start = self.idx_entity_evaluated[mode]
    idx_b_end = idx_b_start + batch_size_
    # Labels depend only on the batch window, not on the ensemble branch:
    # slice and transfer to device ONCE instead of once per branch (was inside
    # the loop below, causing a redundant GPU transfer per ensemble member).
    label_batch = self.label_epoch[mode][idx_b_start : idx_b_end].to(self.dev_torch)
    for subgs in subgs_ens:
        # Link prediction has 2 targets (src, dst) per root; node prediction has 1.
        assert subgs.target.size == batch_size_ * (1 + (self.prediction_task == 'link'))
        # All ensemble branches must agree on which raw nodes are the targets.
        if label_idx is None:
            label_idx = subgs.node[subgs.target]
        else:
            assert np.all(label_idx == subgs.node[subgs.target])
        adj_ens.append(subgs.to_csr_sp())
        feat_ens.append(self.feat_full[subgs.node].to(self.dev_torch))
        target_ens.append(subgs.target)
        feat_aug_ens.append({})
        # Optional one-hot augmented features (any of hops / pprs / drnls that
        # are enabled in self.aug_feats), computed per ensemble branch.
        for candy_augs in {'hops', 'pprs', 'drnls'}.intersection(self.aug_feats):
            candy_aug = candy_augs[:-1]  # remove trailing 's' to get the singular name
            feat_aug_ens[-1][candy_augs] = getattr(subgs.entity_enc, f'{candy_aug}2onehot_vec')(
                getattr(self, f'dim_1hot_{candy_aug}'), return_type='tensor'
            ).type(self.dtype).to(self.dev_torch)
    self._update_batch_stat(mode, batch_size_)
    ret = OneBatchSubgraph(
        adj_ens, feat_ens, label_batch, size_subg_ens, target_ens, feat_aug_ens
    )
    self.profiler.update_subgraph_batch(ret)
    self.profiler.profile()
    if ret_raw_idx:  # TODO: this should support ens as well. multiple subg should have the same raw target idx
        assert ret.num_ens == 1
        ret.idx_raw = [subgs.node for subgs in subgs_ens]
    return ret