in fairdiplomacy/selfplay/exploit.py [0:0]
def do_step_policy_gradient(self, *, counters: collections.defaultdict, use_grad_clip: bool):
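"""Run one policy-gradient update step.

Fetches a rollout batch, scores it with the current model, applies
V-trace off-policy correction, and optimizes a weighted sum of actor,
critic, and entropy losses. Metrics are accumulated into `counters`.
"""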
device = self.device
timings = TimingCtx()
with timings("data_gen"):
(
(power_ids, obs, rewards, actions, behavior_action_logprobs, done),
rollout_scores_per_power,
) = self.data_loader.get_batch()
with timings("to_cuda"):
actions = actions.to(device)
rewards = rewards.to(device)
power_ids = power_ids.to(device)
obs = {k: v.to(device) for k, v in obs.items()}
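# Candidate order indices are popped so they are not fed to the model;
# they use local, per-position indexing (see the note below).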
cand_actions = obs.pop("cand_indices")
behavior_action_logprobs = behavior_action_logprobs.to(device)
done = done.to(device)
with timings("net"):
# Shape: _, [B, 17], [B, S, 469], [B, 7].
# cand_actions carries the same information as actions,
# but uses local candidate indices that line up with policy_logits.
assert EOS_IDX == -1, "Rewrite the code to remove the assumption"
_, _, policy_logits, sc_values = self.state.model(
**obs,
temperature=1.0,
teacher_force_orders=actions.clamp(0), # EOS_IDX = -1 -> 0
x_power=power_ids.view(-1, 1).repeat(1, MAX_SEQ_LEN),
)
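# Trim the candidate indices to the number of decoded positions in policy_logits.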
cand_actions = cand_actions[:, : policy_logits.shape[1]]
# Shape: [B].
sc_values = sc_values.gather(1, power_ids.unsqueeze(1)).squeeze(1)
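# sc_values now holds the value estimate for the acting power at each step.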
# Delete the absolute order ids so they are not used by accident.
# Only the relative order ids (cand_actions) are used from here on.
del actions
if self.cfg.rollout.do_not_split_rollouts:
# Assumes the episode actually ends here, so there is nothing to bootstrap from.
bootstrap_value = torch.zeros_like(sc_values[-1])
else:
# Reduce the batch size by one: the last step only provides the
# bootstrap value, so drop it from every per-step tensor to avoid
# off-by-one artifacts.
bootstrap_value = sc_values[-1].detach()
sc_values = sc_values[:-1]
cand_actions = cand_actions[:-1]
policy_logits = policy_logits[:-1]
rewards = rewards[:-1]
power_ids = power_ids[:-1]
del obs
behavior_action_logprobs = behavior_action_logprobs[:-1]
done = done[:-1]
# Shape: [B].
discounts = (~done).float() * self.cfg.discounting
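# A step with done=True gets discount 0, so the V-trace recursion below
# does not bootstrap across episode boundaries.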
# Shape: [B, 17].
mask = (cand_actions != EOS_IDX).float()
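# Collapse the per-position order logits into a single log-probability
# per action; EOS positions are masked out.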
# Shape: [B].
policy_action_logprobs = order_logits_to_action_logprobs(
policy_logits, cand_actions, mask
)
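# V-trace (IMPALA-style) off-policy correction. log_rhos are the log
# importance ratios between the learner policy and the behavior policy
# that generated the rollouts.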
vtrace_returns = vtrace_from_logprobs_no_batch(
log_rhos=policy_action_logprobs - behavior_action_logprobs,
discounts=discounts,
rewards=rewards,
values=sc_values,
bootstrap_value=bootstrap_value,
)
critic_mses = 0.5 * ((vtrace_returns.vs.detach() - sc_values) ** 2)
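# Loss terms: policy gradient on the V-trace advantages (actor), value
# regression towards the V-trace targets (critic), and an entropy
# regularizer; they are combined with the configured weights below.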
losses = dict(
actor=compute_policy_gradient_loss(
policy_action_logprobs, vtrace_returns.pg_advantages
),
critic=critic_mses.mean(),
# TODO(akhti): it's incorrect to apply this to
# per-position order distribution instead of action
# distribution.
entropy=compute_entropy_loss(policy_logits, mask),
)
loss = (
losses["actor"]
+ self.cfg.critic_weight * losses["critic"]
+ self.cfg.entropy_weight * losses["entropy"]
)
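# Optional extra entropy regularizer computed from the log-probs of the
# sampled actions.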
if self.cfg.sampled_entropy_weight:
loss = loss + self.cfg.sampled_entropy_weight * compute_sampled_entropy_loss(
policy_action_logprobs
)
self.state.optimizer.zero_grad()
loss.backward()
if use_grad_clip:
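# clip_grad_norm_ returns the total gradient norm measured before clipping;
# it is logged under optim/* below.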
g_norm_tensor = clip_grad_norm_(
self.state.model.parameters(), self.cfg.optimizer.grad_clip
)
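# Once trainer.max_updates is reached the optimizer step is skipped,
# but losses and metrics are still computed.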
if (
not self.cfg.trainer.max_updates
or self.state.global_step < self.cfg.trainer.max_updates
):
self.state.optimizer.step()
# Force a device sync so that the "net" timing is accurate.
loss.item()
with timings("metrics"), torch.no_grad():
last_count = done.long().sum()
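# done marks episode-final steps; the *_last metrics below are restricted to them.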
critic_end_mses = critic_mses[done].sum()
if use_grad_clip:
g_norm = g_norm_tensor.item()
counters["optim/grad_max"].update(g_norm)
counters["optim/grad_mean"].update(g_norm)
counters["optim/grad_clip_ratio"].update(
int(g_norm >= self.cfg.optimizer.grad_clip - 1e-5)
)
for key, value in losses.items():
counters[f"loss/{key}"].update(value)
counters["loss/total"].update(loss.item())
for power_id, rollout_scores in rollout_scores_per_power.items():
prefix = f"score_{POWERS[power_id]}" if power_id is not None else "score"
for key, value in rollout_scores.items():
if key != "num_games":
counters[f"{prefix}/{key}"].update(value, rollout_scores["num_games"])
else:
counters[f"{prefix}/{key}"].update(value)
counters["loss/critic_last"].update(critic_end_mses, last_count)
counters["reward/mean"].update(rewards.sum(), len(rewards))
# Rewards at the end of episodes. Everything is precomputed before
# being added to the counters so that GPU work can be pipelined
# where possible.
last_rewards = rewards[done]
last_sum = last_rewards.sum()
# Shape: [num_powers, num_dones].
last_power_masks = (
power_ids[done].unsqueeze(0)
== torch.arange(len(POWERS), device=power_ids.device).unsqueeze(1)
).float()
last_power_rewards = (last_power_masks * last_rewards.unsqueeze(0)).sum(1)
last_power_counts = last_power_masks.sum(1)
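# Per-power sums and counts of episode-final rewards, logged separately
# for each power below.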
counters["reward/last"].update(last_sum, last_count)
for power, reward, counts in zip(
POWERS, last_power_rewards.cpu(), last_power_counts.cpu()
):
counters[f"reward/last_{power}"].update(reward, counts)
# To match the entropy loss we do not negate the log-probs, so this
# is an estimate of the negative entropy.
counters["loss/entropy_sampled"].update(policy_action_logprobs.mean())
# Measure how far off-policy the rollout data is.
counters["loss/rho"].update(vtrace_returns.rhos.sum(), vtrace_returns.rhos.numel())
counters["loss/rhos_clipped"].update(
vtrace_returns.clipped_rhos.sum(), vtrace_returns.clipped_rhos.numel()
)
bsz = len(rewards)
counters["size/batch"].update(bsz)
counters["size/episode"].update(bsz, last_count)
with timings("sync"), torch.no_grad():
if self.state.global_step % self.cfg.trainer.save_sync_checkpoint_every == 0:
self.send_model_to_workers()
# Done outside of the timing context so that the context's own time is captured.
for key, value in timings.items():
counters[f"time/{key}"].update(value)