def instantiate_sampler()

in shaDow/minibatch.py [0:0]


    def instantiate_sampler(self, sampler_config_ensemble, parallelism, bin_adj_files, modes=[TRAIN, VALID, TEST]):
        sampler_config_ensemble_ = deepcopy(sampler_config_ensemble)
        config_ensemble = []
        # e.g., input: [{"method": "ppr", "k": [50, 10]}, {"method": "khop", "depth": [2], "budget": [10]}]
        #       output: [{"method": "ppr", "k": 50}, {"method": "ppr", "k": 10}, {"method": "khop", "depth": 2, "budget": 10}]
        for cfg in sampler_config_ensemble_["configs"]:  # different TYPEs of samplers
            method = cfg.pop('method')
            cnt_cur_sampler = [len(v) for k, v in cfg.items()]
            assert len(cnt_cur_sampler) == 0 or max(cnt_cur_sampler) == min(cnt_cur_sampler)
            cnt_cur_sampler = 1 if len(cnt_cur_sampler) == 0 else cnt_cur_sampler[0]
            cfg['method'] = [method] * cnt_cur_sampler
            cfg_decoupled = [{k: v[i] for k, v in cfg.items()} for i in range(cnt_cur_sampler)]
            config_ensemble.extend(cfg_decoupled)
        self.num_ensemble = len(config_ensemble)
        for cfg in config_ensemble:
            assert "method" in cfg
            assert "size_root" not in cfg or cfg["size_root"] == 1
        config_ensemble_mode = {}
        for m in modes:
            self.batch_size[m] = sampler_config_ensemble_["batch_size"]
            config_ensemble_mode[m] = deepcopy(config_ensemble)
        # TODO: support different sampler config in val and test
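        # val / test never use 'ppr_st'; those modes fall back to the plain 'ppr' sampler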
        for m, cfg_l in config_ensemble_mode.items():
            if m in [VALID, TEST]:
                for cfg in cfg_l:
                    if cfg['method'] == 'ppr_st':
                        cfg['method'] = 'ppr'
        for cfg_mode, cfg_ensemble in config_ensemble_mode.items():
            for cfg in cfg_ensemble:
                cfg["size_root"] = 1 + (self.prediction_task == 'link')     # we want each target to have its own subgraph
                cfg["fix_target"] = True             # i.e., we differentiate root node from the neighbor nodes (compare with GraphSAINT)
                cfg["sequential_traversal"] = True   # (mode != "train")
                if self.prediction_task == 'link':
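                    # for link prediction, exclude the direct connection between the target nodes (presumably to avoid label leakage)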
                    cfg['include_target_conn'] = False
                if cfg["method"] in ["ppr", 'ppr_st']:
                    cfg["type_"] = cfg_mode
                    cfg['name_data'] = self.name_data
                    cfg["dir_data"] = self.dir_data
                    cfg['is_transductive'] = self.is_transductive
                    if self.prediction_task == 'node':
                        _prep_target = self.raw_entity_set[cfg_mode]
                        _dup_modes = None
                    else:
                        _prep_target = np.arange(self.adj[TEST].indptr.size - 1)
                        _dup_modes = [TRAIN, VALID, TEST]
                    cfg['args_preproc'] = {'preproc_targets': _prep_target, 'duplicate_modes': _dup_modes}
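            # build one sampler ensemble per mode, over the mode-specific adjacency and the raw (non-augmented) feature columns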
            aug_feat_ens = [self.aug_feats] * len(cfg_ensemble)
            self.graph_sampler[cfg_mode] = GraphSamplerEnsemble(
                self.adj[cfg_mode], 
                self.feat_full[:, :self.dim_feat_raw], 
                cfg_ensemble, 
                aug_feat_ens, 
                max_num_threads=parallelism, 
                num_subg_per_batch=500, 
                bin_adj_files=bin_adj_files[cfg_mode], 
                seed_cpp=self.seed_cpp
            )
            self.is_stochastic_sampler[cfg_mode] = self.graph_sampler[cfg_mode].is_stochastic
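
For reference, a minimal standalone sketch of the config-decoupling step at the top of this function. The helper name decouple_sampler_configs and the example values are illustrative only, not part of the repo:

    from copy import deepcopy

    def decouple_sampler_configs(sampler_configs):
        # Expand list-valued fields so that each sampler instance gets one flat config dict.
        ensemble = []
        for cfg in deepcopy(sampler_configs):
            method = cfg.pop("method")
            lengths = [len(v) for v in cfg.values()]
            # all list-valued fields of one sampler type must have the same length
            assert len(lengths) == 0 or max(lengths) == min(lengths)
            n = 1 if not lengths else lengths[0]
            cfg["method"] = [method] * n
            ensemble.extend({k: v[i] for k, v in cfg.items()} for i in range(n))
        return ensemble

    configs = [
        {"method": "ppr", "k": [50, 10]},
        {"method": "khop", "depth": [2], "budget": [10]},
    ]
    print(decouple_sampler_configs(configs))
    # [{'k': 50, 'method': 'ppr'}, {'k': 10, 'method': 'ppr'},
    #  {'depth': 2, 'budget': 10, 'method': 'khop'}]

The length of the flattened list is what instantiate_sampler stores as self.num_ensemble, and each entry becomes one sampler in the per-mode GraphSamplerEnsemble.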