def __call__()

in fairdiplomacy/selfplay/exploit.py


    def __call__(self):
        # We call this immediately to send the data-generation workers off to their jobs.
        # After this returns, only the master process and training helper processes remain.
        self._init_data_loader__workers_never_return()

        self._init_training_device_and_state()

        self._init_metrics_logger()

        if self.ectx.is_training_master:
            CKPT_MAIN_DIR.mkdir(exist_ok=True, parents=True)
            if self.state.value_net_state is not None:
                CKPT_VALUE_DIR.mkdir(exist_ok=True, parents=True)

        # The training device and model state are correct at this point, so make sure
        # the model has been sent to the data-generation workers if it is not there yet.
        self.send_model_to_workers()

        self.state.model.train()
        if self.state.value_net_state is not None:
            self.state.value_net_state.model.train()
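        # The flags below keep training the weights while forcing parts of the model to
        # run with eval-mode behavior (e.g., frozen batch-norm statistics, no dropout).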
        if self.cfg.trainer.train_as_eval:
            assert self.state.value_net_state is None, "Not supported"
            self.state.model.eval()
            # Put cuDNN RNNs back in train mode: cuDNN RNN backward can only
            # be called when the module is in training mode.
            self.state.model.apply(_lstm_to_train)
        elif self.cfg.trainer.train_encoder_as_eval:
            assert self.state.value_net_state is None, "Not supported"
            self.state.model.encoder.eval()
        elif self.cfg.trainer.train_decoder_as_eval:
            assert self.state.value_net_state is None, "Not supported"
            self.state.model.policy_decoder.eval()
            self.state.model.policy_decoder.apply(_lstm_to_train)
        elif self.cfg.trainer.train_as_eval_but_batchnorm:
            assert self.state.value_net_state is None, "Not supported"
            self.state.model.eval()
            self.state.model.apply(_lstm_to_train)
            self.state.model.apply(_bn_to_train)

        if not self.research:
            assert self.cfg.num_train_gpus == 1, "Only one training GPU for policy gradients"
            self.state.model.to(self.device)
        else:
            assert self.cfg.num_train_gpus in (1, 2, 4)
            if self.cfg.num_train_gpus > 1:
                assert torch.cuda.device_count() == 8, "Can only go multi-gpu on full machine"
                if self.cfg.use_distributed_data_parallel:
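                    # Wrapping in DistributedDataParallel assumes the process group has
                    # already been initialized via torch.distributed.init_process_group
                    # (presumably as part of the ectx setup).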
                    self.state.net_state.model = torch.nn.parallel.DistributedDataParallel(
                        self.state.net_state.model,
                        device_ids=(self.ectx.training_ddp_rank,),
                        output_device=self.ectx.training_ddp_rank,
                    )
                    if self.state.value_net_state is not None:
                        self.state.value_net_state.model = torch.nn.parallel.DistributedDataParallel(
                            self.state.value_net_state.model,
                            device_ids=(self.ectx.training_ddp_rank,),
                            output_device=self.ectx.training_ddp_rank,
                        )
                else:
                    self.state.net_state.model = torch.nn.DataParallel(
                        self.state.net_state.model,
                        device_ids=tuple(range(self.cfg.num_train_gpus)),
                    )
                    if self.state.value_net_state is not None:
                        self.state.value_net_state.model = torch.nn.DataParallel(
                            self.state.value_net_state.model,
                            device_ids=tuple(range(self.cfg.num_train_gpus)),
                        )
            else:
                self.state.net_state.model.to(self.device)
                if self.state.value_net_state is not None:
                    self.state.value_net_state.model.to(self.device)

        if torch.cuda.is_available():
            # Initialize NVML so GPU utilization/memory stats can be queried later.
            pynvml.nvmlInit()

        if self.research and self.cfg.search_rollout.benchmark_only:
            self.perform_benchmark()
            return

        # Install the requeue handler at the last moment so that we don't requeue a zombie job.
        heyhi.maybe_init_requeue_handler(self.on_requeue)

        self.run_training_loop()
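
The _lstm_to_train and _bn_to_train helpers applied above are module-level functions
defined elsewhere in exploit.py and are not shown here. Below is a minimal sketch of
what they plausibly look like; the exact bodies are assumptions, but any implementation
has to be compatible with torch.nn.Module.apply(fn), which calls fn on every submodule:

    import torch

    def _lstm_to_train(module: torch.nn.Module) -> None:
        # cuDNN RNN backward can only be called in training mode, so RNNs are
        # flipped back to train() even when the surrounding model is in eval().
        if isinstance(module, torch.nn.modules.rnn.RNNBase):
            module.train()

    def _bn_to_train(module: torch.nn.Module) -> None:
        # Re-enable train-mode batch norm (per-batch statistics) inside an
        # otherwise eval-mode model.
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module.train()

Using apply this way lets the trainer run most of the model with eval-mode behavior
while selectively restoring the train-mode pieces that gradient computation (cuDNN
RNNs) or the chosen configuration (batch norm) still requires.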