in cfvpy/selfplay.py [0:0]
def initialize_datagen(self):
    """Set up models, replay buffers, and CFR env threads for self-play data generation.

    Builds `models_per_gpu` inference copies of the current model per actor
    device, wraps each device's copies in a ModelLocker, creates the value
    (and optionally policy) prioritized replay buffers, and spawns
    `num_threads` CFR generation threads round-robined over the lockers.

    Returns:
        dict with keys:
            ref_models: flat list of all built models. Kept alive here so the
                C++ ModelLockers never hold dangling references.
            model_lockers: one cfvpy.rela.ModelLocker per actor device.
            replay: value-prioritized replay buffer.
            policy_replay: policy replay buffer, or None if cfg.train_policy
                is falsy.
            context: cfvpy.utils.TimedContext owning the env threads.
    """
    # Need to preserve ownership of the ref models!
    ref_models = []
    model_lockers = []
    # NOTE(review): this requirement holds even for CPU generation —
    # presumably cuda:0 is reserved for training; confirm intended.
    assert torch.cuda.device_count() >= 2, torch.cuda.device_count()
    if self.cfg.selfplay.cpu_gen_threads:
        num_threads = self.cfg.selfplay.cpu_gen_threads
        act_devices = ["cpu"] * num_threads
        logging.info("Will generate on CPU with %d threads", num_threads)
        assert self.cfg.selfplay.models_per_gpu == 1
    else:
        # cuda:0 is left for training; every other GPU is an actor device.
        act_devices = [f"cuda:{i}" for i in range(1, torch.cuda.device_count())]
        if self.is_master and self.cfg.selfplay.num_master_threads is not None:
            num_threads = self.cfg.selfplay.num_master_threads
        else:
            num_threads = self.cfg.selfplay.threads_per_gpu * len(act_devices)
        # Don't need more devices than threads.
        act_devices = act_devices[:num_threads]
    logging.info("Gpus for generations: %s", act_devices)
    logging.info("Threads: %s", num_threads)
    for act_device in act_devices:
        ref_model = [
            _build_model(
                act_device,
                self.cfg.env,
                self.cfg.model,
                self.get_model().state_dict(),
                half=self.cfg.half_inference,
                jit=True,
            )
            for _ in range(self.cfg.selfplay.models_per_gpu)
        ]
        # BUG FIX: the original iterated `ref_models` (the accumulator) here,
        # so the batch just built was handed to its ModelLocker without ever
        # being switched to eval mode, and the last batch never was at all.
        # Put the *new* models into eval mode before locking them.
        for model in ref_model:
            model.eval()
        ref_models.extend(ref_model)
        model_locker = cfvpy.rela.ModelLocker(ref_model, act_device)
        model_lockers.append(model_locker)
    # Defaults below can be overridden wholesale by cfg.replay.
    replay_params = dict(
        capacity=2 ** 20,
        seed=10001 + self.rank,  # distinct seed per distributed rank
        alpha=1.0,
        beta=0.4,
        prefetch=3,
        use_priority=True,
    )
    if self.cfg.replay:
        replay_params.update(self.cfg.replay)
    logging.info("Replay params (per buffer): %s", replay_params)
    replay = cfvpy.rela.ValuePrioritizedReplay(
        **replay_params, compressed_values=False
    )
    if self.cfg.train_policy:
        policy_replay = cfvpy.rela.ValuePrioritizedReplay(
            **replay_params, compressed_values=self.cfg.compress_policy_values
        )
    else:
        policy_replay = None
    context = cfvpy.utils.TimedContext()
    cfr_cfg = create_mdp_config(self.cfg.env)
    for i in range(num_threads):
        # Round-robin threads over the available model lockers; seed is
        # unique per (rank, thread) pair.
        thread = cfvpy.rela.create_cfr_thread(
            model_lockers[i % len(model_lockers)],
            replay,
            cfr_cfg,
            self.rank * 1000 + i,
        )
        context.push_env_thread(thread)
    return dict(
        ref_models=ref_models,
        model_lockers=model_lockers,
        replay=replay,
        policy_replay=policy_replay,
        context=context,
    )