in shaDow/minibatch.py [0:0]
def instantiate_sampler(self, sampler_config_ensemble, parallelism, bin_adj_files, modes=[TRAIN, VALID, TEST]):
    sampler_config_ensemble_ = deepcopy(sampler_config_ensemble)
    config_ensemble = []
    # e.g., input: [{"method": "ppr", "k": [50, 10]}, {"method": "khop", "depth": [2], "budget": [10]}]
    # output: [{"method": "ppr", "k": 50}, {"method": "ppr", "k": 10}, {"method": "khop", "depth": 2, "budget": 10}]
    for cfg in sampler_config_ensemble_["configs"]:     # different TYPEs of samplers
        method = cfg.pop('method')
        cnt_cur_sampler = [len(v) for k, v in cfg.items()]
        assert len(cnt_cur_sampler) == 0 or max(cnt_cur_sampler) == min(cnt_cur_sampler)
        cnt_cur_sampler = 1 if len(cnt_cur_sampler) == 0 else cnt_cur_sampler[0]
        cfg['method'] = [method] * cnt_cur_sampler
        cfg_decoupled = [{k: v[i] for k, v in cfg.items()} for i in range(cnt_cur_sampler)]
        config_ensemble.extend(cfg_decoupled)
    self.num_ensemble = len(config_ensemble)
    for cfg in config_ensemble:
        assert "method" in cfg
        assert "size_root" not in cfg or cfg["size_root"] == 1
    config_ensemble_mode = {}
    for m in modes:
        self.batch_size[m] = sampler_config_ensemble_["batch_size"]
        config_ensemble_mode[m] = deepcopy(config_ensemble)
    # TODO: support different sampler config in val and test
    for m, cfg_l in config_ensemble_mode.items():
        if m in [VALID, TEST]:
            for cfg in cfg_l:
                if cfg['method'] == 'ppr_st':
                    cfg['method'] = 'ppr'
    for cfg_mode, cfg_ensemble in config_ensemble_mode.items():
        for cfg in cfg_ensemble:
            cfg["size_root"] = 1 + (self.prediction_task == 'link')    # we want each target to have its own subgraph
            cfg["fix_target"] = True            # i.e., we differentiate root node from the neighbor nodes (compare with GraphSAINT)
            cfg["sequential_traversal"] = True  # (mode != "train")
            if self.prediction_task == 'link':
                cfg['include_target_conn'] = False
            if cfg["method"] in ["ppr", 'ppr_st']:
                cfg["type_"] = cfg_mode
                cfg['name_data'] = self.name_data
                cfg["dir_data"] = self.dir_data
                cfg['is_transductive'] = self.is_transductive
                if self.prediction_task == 'node':
                    _prep_target = self.raw_entity_set[cfg_mode]
                    _dup_modes = None
                else:
                    _prep_target = np.arange(self.adj[TEST].indptr.size - 1)
                    _dup_modes = [TRAIN, VALID, TEST]
                cfg['args_preproc'] = {'preproc_targets': _prep_target, 'duplicate_modes': _dup_modes}
        aug_feat_ens = [self.aug_feats] * len(cfg_ensemble)
        self.graph_sampler[cfg_mode] = GraphSamplerEnsemble(
            self.adj[cfg_mode],
            self.feat_full[:, :self.dim_feat_raw],
            cfg_ensemble,
            aug_feat_ens,
            max_num_threads=parallelism,
            num_subg_per_batch=500,
            bin_adj_files=bin_adj_files[cfg_mode],
            seed_cpp=self.seed_cpp
        )
        self.is_stochastic_sampler[cfg_mode] = self.graph_sampler[cfg_mode].is_stochastic
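
For reference, a minimal standalone sketch of the config-decoupling step performed by the first loop above (same logic pulled out of the class). The helper name `decouple_sampler_configs` is hypothetical; the example input/output mirrors the in-code comment.

# Illustrative sketch (not part of the original file): each list-valued field of a
# sampler config spawns one sampler per element, all lists sharing the same length.
from copy import deepcopy

def decouple_sampler_configs(configs):
    ensemble = []
    for cfg in deepcopy(configs):
        method = cfg.pop('method')
        counts = [len(v) for v in cfg.values()]
        # all list-valued fields of one config must have the same length
        assert len(counts) == 0 or max(counts) == min(counts)
        n = 1 if len(counts) == 0 else counts[0]
        cfg['method'] = [method] * n
        ensemble.extend({k: v[i] for k, v in cfg.items()} for i in range(n))
    return ensemble

# decouple_sampler_configs(
#     [{"method": "ppr", "k": [50, 10]}, {"method": "khop", "depth": [2], "budget": [10]}]
# )
# -> [{'k': 50, 'method': 'ppr'}, {'k': 10, 'method': 'ppr'},
#     {'depth': 2, 'budget': 10, 'method': 'khop'}]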