def par_sample_ensemble()

in para_graph_sampler/graph_engine/frontend/samplers_ensemble.py [0:0]


    def par_sample_ensemble(self, prediction_task):
        """Run every configured sampler once and return the merged results.

        Backends are looked up by key in ``self.para_sampler``. The ``'cpp'``
        backend is driven through a single ``parallel_sampler_ensemble`` call
        covering all of its configured samplers; each entry under ``'python'``
        is driven through its own ``parallel_sample`` call. The per-backend
        outputs are then combined by ``self._merge_backend_lists``.

        Args:
            prediction_task: task name. Only the comparison with ``'link'``
                matters here: ``'link'`` requires two roots per subgraph,
                anything else requires one.

        Returns:
            Whatever ``self._merge_backend_lists`` returns -- presumably one
            sequence of subgraphs per sampler (TODO confirm against that
            helper).

        Raises:
            AssertionError: if the C++ samplers return differing numbers of
                subgraphs, or if subgraphs at the same position disagree on
                the number of roots or on the first root.
        """
        outputs = dict.fromkeys(self.para_sampler.keys())

        # C++ backend goes first: one batched call for all its samplers.
        if 'cpp' in self.para_sampler:
            backends = self.sampler_backends['cpp']
            for idx, backend in enumerate(backends):
                # The C++ config carries this flag as a string, not a bool.
                if self.return_target_only['cpp'][idx]:
                    backend.cpp_config['return_target_only'] = 'true'
                else:
                    backend.cpp_config['return_target_only'] = 'false'
            cpp_configs = [backend.cpp_config for backend in backends]
            cpp_augs = [backend.cpp_aug for backend in backends]
            cpp_out = self.para_sampler['cpp'].parallel_sampler_ensemble(cpp_configs, cpp_augs)
            outputs['cpp'] = cpp_out
            # Overwrite each raw result in place with its extracted form.
            for idx, raw in enumerate(cpp_out):
                wants_full_subgraph = not self.return_target_only['cpp'][idx]
                cpp_out[idx] = self._extract_subgraph_return(
                    raw, cpp_configs[idx], cpp_augs[idx], wants_full_subgraph
                )
            assert len(cpp_out) == len(backends)
            per_sampler_counts = [len(subgraphs) for subgraphs in cpp_out]
            assert min(per_sampler_counts) == max(per_sampler_counts)

        # Python backend: each sampler is invoked individually.
        if 'python' in self.para_sampler:
            outputs['python'] = [
                sampler.parallel_sample(return_target_only=self.return_target_only['python'][idx])
                for idx, sampler in enumerate(self.para_sampler['python'])
            ]

        merged = self._merge_backend_lists(outputs)

        # Cross-sampler consistency check, independent of backend: subgraphs
        # at the same position must agree on their roots.
        expected_num_roots = 1 + (prediction_task == 'link')
        for aligned in zip(*merged):
            roots = []
            for subgraph in aligned:
                if subgraph.target.size == 0:
                    # Empty target: the whole node array is taken as the roots.
                    roots.append(subgraph.node)
                else:
                    roots.append(subgraph.node[subgraph.target])
            root_counts = {root.size for root in roots}
            assert len(root_counts) == 1 and root_counts.pop() == expected_num_roots
            assert len({root[0] for root in roots}) == 1
        return merged