def _init_training_device_and_state()

in fairdiplomacy/selfplay/exploit.py
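
Selects the training device: CPU when CUDA is unavailable, the per-rank GPU under distributed data parallel, otherwise plain CUDA. It then loads the policy model and, when configured, a separate value model, builds their optimizers and LR schedulers, optionally re-initializes the policy weights, restores state from a requeue checkpoint or the latest per-net checkpoints, and finally joins the NCCL process group when distributed data parallel is enabled.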


    def _init_training_device_and_state(self):
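        """Pick the training device, build nets and optimizers, and restore any checkpoints."""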
        if not torch.cuda.is_available():
            logging.warning("No CUDA found!")
            self.device = "cpu"
        else:
            if self.cfg.use_distributed_data_parallel:
                # For now, ddp rank equals gpu index
                self.device = f"cuda:{self.ectx.training_ddp_rank}"  # Training device.
            else:
                self.device = "cuda"  # Training device.

        # A separate value model is configured: load a policy-only net and a
        # value-only net, each with its own optimizer and LR scheduler.
        if self.cfg.value_model_path is not None:
            net, net_args = load_diplomacy_model_model_and_args(
                self.cfg.model_path,
                map_location=self.device,
                eval=True,
                override_has_policy=True,
                override_has_value=False,
            )
            optim, lr_scheduler = build_optimizer(net, self.cfg.optimizer)

            value_net, value_net_args = load_diplomacy_model_model_and_args(
                self.cfg.value_model_path,
                map_location=self.device,
                eval=True,
                override_has_policy=False,
                override_has_value=True,
            )

            value_optim, value_lr_scheduler = build_optimizer(
                value_net, self.cfg.value_optimizer or self.cfg.optimizer
            )
            value_state = NetTrainingState(
                args=value_net_args,
                model=value_net,
                optimizer=value_optim,
                scheduler=value_lr_scheduler,
            )
        else:
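            # No separate value model: a single net and optimizer are trained.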
            net, net_args = load_diplomacy_model_model_and_args(
                self.cfg.model_path, map_location=self.device, eval=True,
            )
            optim, lr_scheduler = build_optimizer(net, self.cfg.optimizer)
            value_state = None
        self.state = TrainerState(
            net_state=NetTrainingState(
                model=net, optimizer=optim, scheduler=lr_scheduler, args=net_args,
            ),
            value_net_state=value_state,
        )
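        # Optionally re-initialize the policy net's weights in place; the value
        # net, if any, is left untouched.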
        if self.cfg.reset_agent_weights:

            def _reset(module):
                if hasattr(module, "reset_parameters"):
                    module.reset_parameters()

            net.apply(_reset)
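        # A requeue checkpoint restores the full trainer state and takes
        # precedence over per-net checkpoints.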
        if REQUEUE_CKPT.exists():
            logging.info("Found requeue checkpoint: %s", REQUEUE_CKPT.resolve())
            self.state = TrainerState.load(REQUEUE_CKPT, self.state, self.device)
        elif self.cfg.requeue_ckpt_path:
            logging.info("Using explicit requeue checkpoint: %s", self.cfg.requeue_ckpt_path)
            p = pathlib.Path(self.cfg.requeue_ckpt_path)
            assert p.exists(), p
            self.state = TrainerState.load(p, self.state, self.device)
        else:
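            # No requeue checkpoint: resume each net from the newest checkpoint
            # (by name) in its own directory, if one exists.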
            if self.state.value_net_state is not None:
                if CKPT_VALUE_DIR.exists() and list(CKPT_VALUE_DIR.iterdir()):
                    last_ckpt = max(CKPT_VALUE_DIR.iterdir(), key=str)
                    logging.info(
                        "Found existing VALUE checkpoint folder. Will load last one: %s", last_ckpt
                    )
                    self.state.value_net_state = NetTrainingState.load(
                        last_ckpt, self.state.value_net_state, self.device
                    )
                else:
                    logging.info("No VALUE checkpoint found")
            if self.state.net_state is not None:
                if CKPT_MAIN_DIR.exists() and list(CKPT_MAIN_DIR.iterdir()):
                    last_ckpt = max(CKPT_MAIN_DIR.iterdir(), key=str)
                    logging.info(
                        "Found existing MAIN checkpoint folder. Will load last one: %s", last_ckpt
                    )
                    self.state.net_state = NetTrainingState.load(
                        last_ckpt, self.state.net_state, self.device
                    )
                else:
                    logging.info("No MAIN checkpoint found")

        # Finally, bring up the NCCL process group; all DDP ranks rendezvous
        # through a shared init file.
        if self.cfg.use_distributed_data_parallel:
            torch_ddp_init_fname = get_torch_ddp_init_fname()
            logging.info(f"Using {torch_ddp_init_fname} to coordinate launch of ddp processes.")
            logging.info("Waiting for distributed data parallel helpers to sync up")
            torch.distributed.init_process_group(
                "nccl",
                init_method=f"file://{torch_ddp_init_fname}",
                rank=self.ectx.training_ddp_rank,
                world_size=self.ectx.ddp_world_size,
            )
            logging.info("Distributed data parallel helpers synced, proceeding with training")