def do_step_policy_gradient()

in fairdiplomacy/selfplay/exploit.py [0:0]


    def do_step_policy_gradient(self, *, counters: collections.defaultdict, use_grad_clip: bool):
        device = self.device
        timings = TimingCtx()
        with timings("data_gen"):
            (
                (power_ids, obs, rewards, actions, behavior_action_logprobs, done),
                rollout_scores_per_power,
            ) = self.data_loader.get_batch()

        with timings("to_cuda"):
            actions = actions.to(device)
            rewards = rewards.to(device)
            power_ids = power_ids.to(device)
            obs = {k: v.to(device) for k, v in obs.items()}
            cand_actions = obs.pop("cand_indices")
            behavior_action_logprobs = behavior_action_logprobs.to(device)
            done = done.to(device)

        with timings("net"):
            # Shape: _, [B, 17], [B, S, 469], [B, 7].
            # policy_cand_actions has the same information as actions,
            # but uses local indices to match policy logits.
            assert EOS_IDX == -1, "Rewrite the code to remove the assumption"
            _, _, policy_logits, sc_values = self.state.model(
                **obs,
                temperature=1.0,
                teacher_force_orders=actions.clamp(0),  # EOS_IDX = -1 -> 0
                x_power=power_ids.view(-1, 1).repeat(1, MAX_SEQ_LEN),
            )
            cand_actions = cand_actions[:, : policy_logits.shape[1]]

            # Shape: [B].
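            # Pick the value-head output for the power acting at each
            # position; the other powers' value estimates are discarded.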
            sc_values = sc_values.gather(1, power_ids.unsqueeze(1)).squeeze(1)

            # Removing absolute order ids so they are not used by accident.
            # Relative order ids (cand_actions) are used from here on.
            del actions

            if self.cfg.rollout.do_not_split_rollouts:
                # Assumes that the episode actually ends here.
                bootstrap_value = torch.zeros_like(sc_values[-1])
            else:
                # Reducing the batch size by one. Deleting tensors that we do
                # not bother to adjust, so stale values are not used by accident.
                bootstrap_value = sc_values[-1].detach()
                sc_values = sc_values[:-1]
                cand_actions = cand_actions[:-1]
                policy_logits = policy_logits[:-1]
                rewards = rewards[:-1]
                power_ids = power_ids[:-1]
                del obs
                behavior_action_logprobs = behavior_action_logprobs[:-1]
                done = done[:-1]

            # Shape: [B].
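            # Discounts are zero wherever an episode ended, so returns are
            # not bootstrapped across episode boundaries.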
            discounts = (~done).float() * self.cfg.discounting

            # Shape: [B, 17].
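            # Order slots filled with EOS_IDX are padding; mask them out so
            # they do not contribute to the action log-probabilities.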
            mask = (cand_actions != EOS_IDX).float()

            # Shape: [B].
            policy_action_logprobs = order_logits_to_action_logprobs(
                policy_logits, cand_actions, mask
            )

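            # V-trace (IMPALA-style) off-policy correction: importance ratios
            # rho = exp(log pi - log mu) reweight the targets to account for
            # the gap between the behavior policy that generated the rollout
            # and the current policy.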
            vtrace_returns = vtrace_from_logprobs_no_batch(
                log_rhos=policy_action_logprobs - behavior_action_logprobs,
                discounts=discounts,
                rewards=rewards,
                values=sc_values,
                bootstrap_value=bootstrap_value,
            )

            critic_mses = 0.5 * ((vtrace_returns.vs.detach() - sc_values) ** 2)

            losses = dict(
                actor=compute_policy_gradient_loss(
                    policy_action_logprobs, vtrace_returns.pg_advantages
                ),
                critic=critic_mses.mean(),
                # TODO(akhti): it's incorrect to apply this to
                # per-position order distribution instead of action
                # distribution.
                entropy=compute_entropy_loss(policy_logits, mask),
            )

            loss = (
                losses["actor"]
                + self.cfg.critic_weight * losses["critic"]
                + self.cfg.entropy_weight * losses["entropy"]
            )
            if self.cfg.sampled_entropy_weight:
                loss = loss + self.cfg.sampled_entropy_weight * compute_sampled_entropy_loss(
                    policy_action_logprobs
                )

            self.state.optimizer.zero_grad()
            loss.backward()

            if use_grad_clip:
                g_norm_tensor = clip_grad_norm_(
                    self.state.model.parameters(), self.cfg.optimizer.grad_clip
                )

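            # Stop updating the weights once max_updates is reached (if set),
            # while still computing losses and metrics.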
            if (
                not self.cfg.trainer.max_updates
                or self.state.global_step < self.cfg.trainer.max_updates
            ):
                self.state.optimizer.step()
            # Sync to make sure timing is correct.
            loss.item()

        with timings("metrics"), torch.no_grad():
            last_count = done.long().sum()
            critic_end_mses = critic_mses[done].sum()

            if use_grad_clip:
                g_norm = g_norm_tensor.item()
                counters["optim/grad_max"].update(g_norm)
                counters["optim/grad_mean"].update(g_norm)
                counters["optim/grad_clip_ratio"].update(
                    int(g_norm >= self.cfg.optimizer.grad_clip - 1e-5)
                )
            for key, value in losses.items():
                counters[f"loss/{key}"].update(value)
            counters["loss/total"].update(loss.item())
            for power_id, rollout_scores in rollout_scores_per_power.items():
                prefix = f"score_{POWERS[power_id]}" if power_id is not None else "score"
                for key, value in rollout_scores.items():
                    if key != "num_games":
                        counters[f"{prefix}/{key}"].update(value, rollout_scores["num_games"])
                    else:
                        counters[f"{prefix}/{key}"].update(value)

            counters["loss/critic_last"].update(critic_end_mses, last_count)

            counters["reward/mean"].update(rewards.sum(), len(rewards))
            # Rewards at the ends of episodes. Everything is precomputed
            # before being added to the counters so that the work can be
            # pipelined where possible.
            last_rewards = rewards[done]
            last_sum = last_rewards.sum()
            # Shape: [num_powers, num_dones].
            last_power_masks = (
                power_ids[done].unsqueeze(0)
                == torch.arange(len(POWERS), device=power_ids.device).unsqueeze(1)
            ).float()
            last_power_rewards = (last_power_masks * last_rewards.unsqueeze(0)).sum(1)
            last_power_counts = last_power_masks.sum(1)
            counters["reward/last"].update(last_sum, last_count)
            for power, reward, counts in zip(
                POWERS, last_power_rewards.cpu(), last_power_counts.cpu()
            ):
                counters[f"reward/last_{power}"].update(reward, counts)
            # To match entropy loss we don't negate logprobs. So this
            # is an estimate of the negative entropy.
            counters["loss/entropy_sampled"].update(policy_action_logprobs.mean())

            # Measure off-policiness.
            counters["loss/rho"].update(vtrace_returns.rhos.sum(), vtrace_returns.rhos.numel())
            counters["loss/rhos_clipped"].update(
                vtrace_returns.clipped_rhos.sum(), vtrace_returns.clipped_rhos.numel()
            )

            bsz = len(rewards)
            counters["size/batch"].update(bsz)
            counters["size/episode"].update(bsz, last_count)

        with timings("sync"), torch.no_grad():
            if self.state.global_step % self.cfg.trainer.save_sync_checkpoint_every == 0:
                self.send_model_to_workers()

        # Done outside the timing contexts so that their recorded time is complete.
        for key, value in timings.items():
            counters[f"time/{key}"].update(value)