def initialize_datagen()

in cfvpy/selfplay.py [0:0]


    def initialize_datagen(self):
        """Set up self-play data generation: inference models, replay buffers, and CFR threads.

        Builds one or more frozen reference models per generation device (CPU or
        extra GPUs beyond cuda:0, which is reserved for training), wraps them in
        ModelLockers, creates the prioritized replay buffer(s), and spawns
        `num_threads` CFR environment threads round-robined over the lockers.

        Returns:
            dict with keys:
                ref_models: list of all built reference models (ownership kept
                    here so they are not garbage-collected while lockers use them).
                model_lockers: one cfvpy.rela.ModelLocker per act device.
                replay: value replay buffer.
                policy_replay: policy replay buffer, or None if train_policy is off.
                context: cfvpy.utils.TimedContext owning the generation threads.
        """
        # Need to preserve ownership of the ref models!
        ref_models = []
        model_lockers = []
        # cuda:0 is the training device; generation needs at least one more GPU
        # (unless cpu_gen_threads is set, but the invariant is asserted up front).
        assert torch.cuda.device_count() >= 2, torch.cuda.device_count()
        if self.cfg.selfplay.cpu_gen_threads:
            num_threads = self.cfg.selfplay.cpu_gen_threads
            act_devices = ["cpu"] * num_threads
            logging.info("Will generate on CPU with %d threads", num_threads)
            assert self.cfg.selfplay.models_per_gpu == 1
        else:
            act_devices = [f"cuda:{i}" for i in range(1, torch.cuda.device_count())]
            if self.is_master and self.cfg.selfplay.num_master_threads is not None:
                num_threads = self.cfg.selfplay.num_master_threads
            else:
                num_threads = self.cfg.selfplay.threads_per_gpu * len(act_devices)

        # Don't need more devices than threads.
        act_devices = act_devices[:num_threads]

        logging.info("Gpus for generations: %s", act_devices)
        logging.info("Threads: %s", num_threads)
        for act_device in act_devices:
            # Clone the current training weights into frozen inference copies
            # for this device.
            ref_model = [
                _build_model(
                    act_device,
                    self.cfg.env,
                    self.cfg.model,
                    self.get_model().state_dict(),
                    half=self.cfg.half_inference,
                    jit=True,
                )
                for _ in range(self.cfg.selfplay.models_per_gpu)
            ]
            # BUGFIX: iterate the freshly built per-device list (`ref_model`),
            # not the accumulator (`ref_models`) — the old code left the last
            # device's models (and on the first pass, all of them) in train mode.
            for model in ref_model:
                model.eval()
            ref_models.extend(ref_model)
            model_locker = cfvpy.rela.ModelLocker(ref_model, act_device)
            model_lockers.append(model_locker)

        replay_params = dict(
            capacity=2 ** 20,
            seed=10001 + self.rank,
            alpha=1.0,
            beta=0.4,
            prefetch=3,
            use_priority=True,
        )
        # Config overrides take precedence over the defaults above.
        if self.cfg.replay:
            replay_params.update(self.cfg.replay)
        logging.info("Replay params (per buffer): %s", replay_params)
        replay = cfvpy.rela.ValuePrioritizedReplay(
            **replay_params, compressed_values=False
        )
        if self.cfg.train_policy:
            policy_replay = cfvpy.rela.ValuePrioritizedReplay(
                **replay_params, compressed_values=self.cfg.compress_policy_values
            )
        else:
            policy_replay = None

        context = cfvpy.utils.TimedContext()
        cfr_cfg = create_mdp_config(self.cfg.env)
        for i in range(num_threads):
            # Round-robin threads over available model lockers; seed is offset
            # by rank so distributed workers generate distinct streams.
            thread = cfvpy.rela.create_cfr_thread(
                model_lockers[i % len(model_lockers)],
                replay,
                cfr_cfg,
                self.rank * 1000 + i,
            )
            context.push_env_thread(thread)

        return dict(
            ref_models=ref_models,
            model_lockers=model_lockers,
            replay=replay,
            policy_replay=policy_replay,
            context=context,
        )