def do_step_research()

in fairdiplomacy/selfplay/exploit.py


    def do_step_research(self, *, counters: collections.defaultdict, use_grad_clip: bool):
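        """Run one ReSearch training step.

        Fetches a rollout batch, moves it to the device, computes the critic MSE
        and (with some probability) the search-policy loss, runs backward,
        optionally clips gradients, steps the optimizers that received gradients,
        and records losses, metrics, and timings into counters.
        """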
        device = self.device
        timings = TimingCtx()
        with timings("data_gen"):
            research_batch: ReSearchRolloutBatch = self.data_loader.get_batch()  # type: ignore

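        # The search-policy loss is only computed when the policy is being trained,
        # and only on a random fraction of steps (cfg.search_policy_update_prob).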
        do_search_policy_loss = (
            self.is_policy_being_trained
            and self.random.random() <= self.cfg.search_policy_update_prob
        )

        main_net_has_grads = value_net_has_grads = False
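        # Each DDP worker consumes an equal slice of the global rollout batch.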
        per_dataloader_batch_size = self.cfg.search_rollout.batch_size // self.ectx.ddp_world_size
        with timings("to_cuda"):
            rewards = research_batch.rewards.to(device)
            if self.cfg.search_rollout.buffer.shuffle:
                assert list(rewards.shape) == [
                    1,
                    self.cfg.search_rollout.chunk_length * per_dataloader_batch_size,
                    len(POWERS),
                ], rewards.shape
            else:
                assert list(rewards.shape) == [
                    self.cfg.search_rollout.chunk_length,
                    per_dataloader_batch_size,
                    len(POWERS),
                ], rewards.shape
            obs = {k: v.to(device) for k, v in research_batch.observations.items()}
            done = research_batch.done.to(device)
            is_search_policy_valid = research_batch.is_search_policy_valid.to(device)
            is_explore = research_batch.is_explore.to(device)
            is_dead = (research_batch.scores < 1e-3).float().to(device)
            targets = research_batch.targets.to(device)
            is_move_phase = (research_batch.phase_type == ord("M")).to(device)
            years = research_batch.years.to(device)
            if is_explore.all():
                logging.warning("Whole batch of explore!!! Skipping")
                return
            if do_search_policy_loss:
                search_policy_probs, search_policy_orders, blueprint_probs = (
                    research_batch.search_policy_probs.to(device),
                    research_batch.search_policy_orders.long().to(device),
                    research_batch.blueprint_probs.to(device),
                )
                if self.cfg.search_ev_loss is not None:
                    search_policy_evs = research_batch.search_policy_evs.to(device)
            else:
                search_policy_probs = search_policy_orders = None

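        # Forward/backward: accumulate the critic and search-policy terms into a
        # single scalar loss. The *_has_grads flags track which networks actually
        # receive gradients so that only their optimizers are clipped and stepped.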
        with timings("net"):
            loss = torch.tensor(0.0, device=device)
            losses = {}
            if self.is_value_being_trained:
                if self.state.value_net_state is not None:
                    value_net_has_grads = True
                else:
                    main_net_has_grads = True

                # Shape: [T, B, 7].
                _, _, _, predicted_values = self.state.value_model(
                    **nest.map(lambda x: x.flatten(end_dim=1), obs),
                    temperature=1.0,
                    need_policy=False,
                )
                predicted_values = predicted_values.reshape(rewards.shape)

                # Note: if you ever change this, you have to propagate discounting
                # to search_data_loader akin to data_loader.
                assert self.cfg.discounting == 1.0, "Discounting is not supported for ReSearch"

                critic_mses = torch.nn.functional.mse_loss(
                    targets, predicted_values, reduction="none"
                )
                losses["critic"] = critic_mses.mean()
                loss = loss + self.cfg.critic_weight * losses["critic"]

            if do_search_policy_loss:
                main_net_has_grads = True
                if self.cfg.search_ev_loss is not None:
                    policy_loss_targets = evs_to_policy(
                        search_policy_evs,
                        temperature=self.cfg.search_ev_loss.temperature,
                        use_softmax=self.cfg.search_ev_loss.use_softmax,
                    )
                else:
                    policy_loss_targets = search_policy_probs

                (
                    search_policy_loss,
                    search_policy_metrics,
                ) = compute_search_policy_cross_entropy_sampled(
                    self.state.model,
                    obs,
                    search_policy_orders,
                    policy_loss_targets,
                    blueprint_probs,
                    mask=is_search_policy_valid,
                    mse_loss=self.cfg.search_policy_mse,
                    mse_bp_normalized=self.cfg.search_policy_mse_bp_normalized,
                    mse_bp_upper_bound=self.cfg.search_policy_mse_bp_upper_bound,
                    is_move_phase=is_move_phase,
                    using_ddp=self.cfg.use_distributed_data_parallel,
                )
                for k, v in search_policy_metrics.items():
                    counters[k].update(v)

                losses["search_policy"] = search_policy_loss
                loss = loss + search_policy_loss * self.cfg.search_policy_weight
                if (
                    self.cfg.sampled_entropy_weight is not None
                    and self.cfg.sampled_entropy_weight > 0.0
                ):
                    e_loss, e_mean = sample_and_compute_sampled_entropy_loss(self.state.model, obs)
                    losses["policy_entropy_loss"] = e_loss
                    losses["policy_entropy"] = e_mean
                    loss = loss + e_loss * self.cfg.sampled_entropy_weight
            else:
                search_policy_loss = None

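            # One backward pass over the combined loss; optimizer steps stop once
            # trainer.max_updates is reached.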
            self.state.optimizer.zero_grad()
            if self.state.value_net_state is not None:
                self.state.value_net_state.optimizer.zero_grad()

            loss.backward()

            if use_grad_clip:
                if main_net_has_grads:
                    g_norm_tensor = clip_grad_norm_(
                        self.state.model.parameters(), self.cfg.optimizer.grad_clip
                    )
                else:
                    g_norm_tensor = None
                if value_net_has_grads:
                    value_grad_clip = (self.cfg.value_optimizer or self.cfg.optimizer).grad_clip
                    g_norm_value_tensor = clip_grad_norm_(
                        self.state.value_net_state.model.parameters(), value_grad_clip
                    )
                else:
                    g_norm_value_tensor = None

            if (
                not self.cfg.trainer.max_updates
                or self.state.global_step < self.cfg.trainer.max_updates
            ):
                if main_net_has_grads:
                    self.state.net_state.optimizer.step()
                if value_net_has_grads:
                    self.state.value_net_state.optimizer.step()
            # Sync to make sure timing is correct.
            loss.item()

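        # Metric computation is gradient-free; most counters receive weighted
        # (sum, count) updates.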
        with timings("metrics"), torch.no_grad():
            last_count = done.long().sum()

            time_bsz = rewards.shape[0] * rewards.shape[1]
            if self.is_value_being_trained:
                critic_end_mses = critic_mses[done].sum()

            if use_grad_clip:
                if g_norm_tensor is not None:
                    g_norm = g_norm_tensor.item()
                    counters["optim/grad_max"].update(g_norm)
                    counters["optim/grad_mean"].update(g_norm)
                    counters["optim/grad_clip_ratio"].update(
                        int(g_norm >= self.cfg.optimizer.grad_clip - 1e-5)
                    )
                if g_norm_value_tensor is not None:
                    g_norm_value = g_norm_value_tensor.item()
                    counters["optim/grad_value_max"].update(g_norm_value)
                    counters["optim/grad_value_mean"].update(g_norm_value)
                    counters["optim/grad_value_clip_ratio"].update(
                        int(g_norm_value >= value_grad_clip - 1e-5)
                    )
            for key, value in losses.items():
                counters[f"loss/{key}"].update(value)
            counters["loss/total"].update(loss.item())

            explored_on_the_right = research_batch.explored_on_the_right
            if self.is_value_being_trained:
                counters["loss/critic_no_explore"].update(
                    critic_mses[~explored_on_the_right].sum(), explored_on_the_right.long().sum()
                )
                counters["loss/critic_last"].update(critic_end_mses, last_count)
            counters["loss/is_explore"].update(is_explore.long().sum(), (1 - is_dead).sum())
            counters["loss/offpolicy_part"].update(
                explored_on_the_right.long().sum(), (1 - is_dead).sum()
            )
            counters["loss/is_search_policy_valid"].update(
                is_search_policy_valid.float().sum(), is_search_policy_valid.numel()
            )

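            # Entropy diagnostics for the search-policy targets (overall, move
            # phases, 1901 move phases) and a blueprint probability-mass statistic.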
            if do_search_policy_loss:
                # What's the entropy of the search policy?
                counters["loss/entropy_search"].update(
                    compute_search_policy_entropy(
                        search_policy_orders, search_policy_probs, mask=is_search_policy_valid
                    )
                )
                counters["loss/entropy_search_moves"].update(
                    compute_search_policy_entropy(
                        search_policy_orders,
                        search_policy_probs,
                        mask=is_search_policy_valid * is_move_phase,
                    )
                )
                first_two_phases_move_mask = (
                    is_search_policy_valid * is_move_phase * (years == 1901)
                )
                if first_two_phases_move_mask.any():
                    counters["loss/entropy_search_moves_1901"].update(
                        compute_search_policy_entropy(
                            search_policy_orders,
                            search_policy_probs,
                            mask=first_two_phases_move_mask,
                        )
                    )
                if self.cfg.search_ev_loss is not None:
                    counters["loss/entropy_search_from_evs"].update(
                        compute_search_policy_entropy(
                            search_policy_orders, policy_loss_targets, mask=is_search_policy_valid
                        )
                    )
                phase_bp_sums = blueprint_probs.flatten(end_dim=1)[
                    is_search_policy_valid.flatten(end_dim=1)
                ]
                phase_bp_sums = phase_bp_sums.sum(-1).view(-1)
                phase_bp_sums = phase_bp_sums[phase_bp_sums > 1e-10]
                counters["loss/bp_share"].update(phase_bp_sums.sum(), len(phase_bp_sums))

            counters["reward/mean"].update(rewards.sum(), time_bsz)
            # Rewards at the end of episodes. We precompute everything
            # before adding to counters to pipeline things when
            # possible.
            last_rewards = rewards[done]
            last_sum = last_rewards.sum()
            if self.is_value_being_trained:
                # Mean predicted value for dead powers.
                counters["value/mean_dead"].update(
                    (predicted_values * is_dead).sum(), is_dead.sum()
                )
            counters["reward/last"].update(last_sum, last_count)
            counters["reward_solo/last"].update((last_rewards > 0.9).float().sum(), last_count)
            for i, power in enumerate(POWERS):
                power_rewards = last_rewards[..., i]
                counters[f"reward/last_{power}"].update(power_rewards.sum(), last_count)
                counters[f"reward_solo/last_{power}"].update(
                    (power_rewards > 0.9).sum(), last_count
                )

            # Average batch size in phases, i.e., T * B.
            counters["size/batch"].update(time_bsz * self.ectx.ddp_world_size)
            if self.ectx.ddp_world_size > 1:
                counters["size/batch_local"].update(time_bsz)
            # Average number of phases per episode.
            counters["size/episode"].update(time_bsz, last_count)
            # Average number of move phases per episode.
            counters["size/episode_moves"].update(is_move_phase.float().sum(), last_count)

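        # Periodically push the updated weights to the rollout workers.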
        with timings("sync"), torch.no_grad():
            if self.state.global_step % self.cfg.trainer.save_sync_checkpoint_every == 0:
                self.send_model_to_workers()

        # Done outside of the timing context so that the context's own timing is captured.
        for key, value in timings.items():
            counters[f"time/{key}"].update(value)
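
The counters argument is a collections.defaultdict whose values accept both update(value) and update(value, count). The concrete counter class is not part of this listing; below is a minimal sketch of a compatible weighted-mean accumulator. MeanCounter is a hypothetical name, and the real counters presumably aggregate differently per key (e.g. max vs. mean for optim/grad_max and optim/grad_mean).

    import collections

    import torch

    class MeanCounter:
        # Hypothetical accumulator matching the update() calls in do_step_research:
        # it accepts update(value) or update(value, count), with Python numbers or
        # 0-dim tensors for either argument.
        def __init__(self):
            self.total = 0.0
            self.count = 0.0

        def update(self, value, count=1):
            if torch.is_tensor(value):
                value = value.item()
            if torch.is_tensor(count):
                count = count.item()
            self.total += float(value)
            self.count += float(count)

        def mean(self):
            return self.total / max(self.count, 1e-9)

    # Usage sketch (trainer stands in for the object that defines do_step_research):
    # counters = collections.defaultdict(MeanCounter)
    # trainer.do_step_research(counters=counters, use_grad_clip=True)
    # averages = {key: counter.mean() for key, counter in counters.items()}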