in fairdiplomacy/selfplay/exploit.py [0:0]
def do_step_research(self, *, counters: collections.defaultdict, use_grad_clip: bool):
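"""Run a single ReSearch training step.

Fetches a rollout batch, computes the critic and/or search-policy losses,
backpropagates, clips and steps the optimizers, records metrics, and
periodically sends the updated model to the rollout workers.
"""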
device = self.device
timings = TimingCtx()
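# Fetch the next chunk of ReSearch rollouts from the data loader.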
with timings("data_gen"):
research_batch: ReSearchRolloutBatch = self.data_loader.get_batch() # type: ignore
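# Train the policy head on search targets only on a random fraction of
# steps, controlled by search_policy_update_prob.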
do_search_policy_loss = (
self.is_policy_being_trained
and self.random.random() <= self.cfg.search_policy_update_prob
)
main_net_has_grads = value_net_has_grads = False
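# Each DDP rank consumes an equal slice of the configured rollout batch size.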
per_dataloader_batch_size = self.cfg.search_rollout.batch_size // self.ectx.ddp_world_size
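# Move the batch to the training device and sanity-check tensor shapes.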
with timings("to_cuda"):
rewards = research_batch.rewards.to(device)
if self.cfg.search_rollout.buffer.shuffle:
assert list(rewards.shape) == [
1,
self.cfg.search_rollout.chunk_length * per_dataloader_batch_size,
len(POWERS),
], rewards.shape
else:
assert list(rewards.shape) == [
self.cfg.search_rollout.chunk_length,
per_dataloader_batch_size,
len(POWERS),
], rewards.shape
obs = {k: v.to(device) for k, v in research_batch.observations.items()}
done = research_batch.done.to(device)
is_search_policy_valid = research_batch.is_search_policy_valid.to(device)
is_explore = research_batch.is_explore.to(device)
is_dead = (research_batch.scores < 1e-3).float().to(device)
targets = research_batch.targets.to(device)
is_move_phase = (research_batch.phase_type == ord("M")).to(device)
years = research_batch.years.to(device)
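# Degenerate batch: every timestep is an explore step, so skip this update.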
if is_explore.all():
logging.warning("Whole batch of explore!!! Skipping")
return
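# Tensors needed for the search-policy loss: target probabilities over the
# sampled search orders, the orders themselves, and their blueprint probabilities.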
if do_search_policy_loss:
search_policy_probs, search_policy_orders, blueprint_probs = (
research_batch.search_policy_probs.to(device),
research_batch.search_policy_orders.long().to(device),
research_batch.blueprint_probs.to(device),
)
if self.cfg.search_ev_loss is not None:
search_policy_evs = research_batch.search_policy_evs.to(device)
else:
search_policy_probs = search_policy_orders = None
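# Forward passes and loss computation; all weighted terms are accumulated into loss.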
with timings("net"):
loss = torch.tensor(0.0, device=device)
losses = {}
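# Critic loss: regress the predicted per-power values onto the rollout targets.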
if self.is_value_being_trained:
if self.state.value_net_state is not None:
value_net_has_grads = True
else:
main_net_has_grads = True
# Shape: [T, B, 7].
_, _, _, predicted_values = self.state.value_model(
**nest.map(lambda x: x.flatten(end_dim=1), obs),
temperature=1.0,
need_policy=False,
)
predicted_values = predicted_values.reshape(rewards.shape)
# Note: if you ever change this, you have to propagate the discounting
# to search_data_loader akin to data_loader.
assert self.cfg.discounting == 1.0, "Discounting is not supported for ReSearch"
critic_mses = torch.nn.functional.mse_loss(
targets, predicted_values, reduction="none"
)
losses["critic"] = critic_mses.mean()
loss = loss + self.cfg.critic_weight * losses["critic"]
if do_search_policy_loss:
main_net_has_grads = True
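# Distillation targets: either expected values converted to a distribution,
# or the raw search policy probabilities.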
if self.cfg.search_ev_loss is not None:
policy_loss_targets = evs_to_policy(
search_policy_evs,
temperature=self.cfg.search_ev_loss.temperature,
use_softmax=self.cfg.search_ev_loss.use_softmax,
)
else:
policy_loss_targets = search_policy_probs
(
search_policy_loss,
search_policy_metrics,
) = compute_search_policy_cross_entropy_sampled(
self.state.model,
obs,
search_policy_orders,
policy_loss_targets,
blueprint_probs,
mask=is_search_policy_valid,
mse_loss=self.cfg.search_policy_mse,
mse_bp_normalized=self.cfg.search_policy_mse_bp_normalized,
mse_bp_upper_bound=self.cfg.search_policy_mse_bp_upper_bound,
is_move_phase=is_move_phase,
using_ddp=self.cfg.use_distributed_data_parallel,
)
for k, v in search_policy_metrics.items():
counters[k].update(v)
losses["search_policy"] = search_policy_loss
loss = loss + search_policy_loss * self.cfg.search_policy_weight
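# Optional entropy regularization computed on sampled actions.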
if (
self.cfg.sampled_entropy_weight is not None
and self.cfg.sampled_entropy_weight > 0.0
):
e_loss, e_mean = sample_and_compute_sampled_entropy_loss(self.state.model, obs)
losses["policy_entropy_loss"] = e_loss
losses["policy_entropy"] = e_mean
loss = loss + e_loss * self.cfg.sampled_entropy_weight
else:
search_policy_loss = None
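# Zero gradients, backprop the combined loss, then (optionally) clip and step
# only the networks that actually received gradients.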
self.state.optimizer.zero_grad()
if self.state.value_net_state is not None:
self.state.value_net_state.optimizer.zero_grad()
loss.backward()
if use_grad_clip:
if main_net_has_grads:
g_norm_tensor = clip_grad_norm_(
self.state.model.parameters(), self.cfg.optimizer.grad_clip
)
else:
g_norm_tensor = None
if value_net_has_grads:
value_grad_clip = (self.cfg.value_optimizer or self.cfg.optimizer).grad_clip
g_norm_value_tensor = clip_grad_norm_(
self.state.value_net_state.model.parameters(), value_grad_clip
)
else:
g_norm_value_tensor = None
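# Respect trainer.max_updates: once reached, losses and metrics are still
# computed but no optimizer steps are applied.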
if (
not self.cfg.trainer.max_updates
or self.state.global_step < self.cfg.trainer.max_updates
):
if main_net_has_grads:
self.state.net_state.optimizer.step()
if value_net_has_grads:
self.state.value_net_state.optimizer.step()
# Sync to make sure timing is correct.
loss.item()
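# Everything below is diagnostic only and does not affect the gradients.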
with timings("metrics"), torch.no_grad():
last_count = done.long().sum()
time_bsz = rewards.shape[0] * rewards.shape[1]
if self.is_value_being_trained:
critic_end_mses = critic_mses[done].sum()
if use_grad_clip:
if g_norm_tensor is not None:
g_norm = g_norm_tensor.item()
counters["optim/grad_max"].update(g_norm)
counters["optim/grad_mean"].update(g_norm)
counters["optim/grad_clip_ratio"].update(
int(g_norm >= self.cfg.optimizer.grad_clip - 1e-5)
)
if g_norm_value_tensor is not None:
g_norm_value = g_norm_value_tensor.item()
counters["optim/grad_value_max"].update(g_norm_value)
counters["optim/grad_value_mean"].update(g_norm_value)
counters["optim/grad_value_clip_ratio"].update(
int(g_norm_value >= value_grad_clip - 1e-5)
)
for key, value in losses.items():
counters[f"loss/{key}"].update(value)
counters["loss/total"].update(loss.item())
explored_on_the_right = research_batch.explored_on_the_right
if self.is_value_being_trained:
counters["loss/critic_no_explore"].update(
critic_mses[~explored_on_the_right].sum(), explored_on_the_right.long().sum()
)
counters["loss/critic_last"].update(critic_end_mses, last_count)
counters["loss/is_explore"].update(is_explore.long().sum(), (1 - is_dead).sum())
counters["loss/offpolicy_part"].update(
explored_on_the_right.long().sum(), (1 - is_dead).sum()
)
counters["loss/is_search_policy_valid"].update(
is_search_policy_valid.float().sum(), is_search_policy_valid.numel()
)
if do_search_policy_loss:
# Entropy of the search policy.
counters["loss/entropy_search"].update(
compute_search_policy_entropy(
search_policy_orders, search_policy_probs, mask=is_search_policy_valid
)
)
counters["loss/entropy_search_moves"].update(
compute_search_policy_entropy(
search_policy_orders,
search_policy_probs,
mask=is_search_policy_valid * is_move_phase,
)
)
first_two_phases_move_mask = (
is_search_policy_valid * is_move_phase * (years == 1901)
)
if first_two_phases_move_mask.any():
counters["loss/entropy_search_moves_1901"].update(
compute_search_policy_entropy(
search_policy_orders,
search_policy_probs,
mask=first_two_phases_move_mask,
)
)
if self.cfg.search_ev_loss is not None:
counters["loss/entropy_search_from_evs"].update(
compute_search_policy_entropy(
search_policy_orders, policy_loss_targets, mask=is_search_policy_valid
)
)
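# Total blueprint probability mass on the sampled search orders (bp_share);
# entries with near-zero mass are dropped.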
phase_bp_sums = blueprint_probs.flatten(end_dim=1)[
is_search_policy_valid.flatten(end_dim=1)
]
phase_bp_sums = phase_bp_sums.sum(-1).view(-1)
phase_bp_sums = phase_bp_sums[phase_bp_sums > 1e-10]
counters["loss/bp_share"].update(phase_bp_sums.sum(), len(phase_bp_sums))
counters["reward/mean"].update(rewards.sum(), time_bsz)
# Rewards at the end of episodes. We precompute everything
# before adding to counters to pipeline things when
# possible.
last_rewards = rewards[done]
last_sum = last_rewards.sum()
if self.is_value_being_trained:
# Mean predicted value for dead powers.
counters["value/mean_dead"].update(
(predicted_values * is_dead).sum(), is_dead.sum()
)
counters["reward/last"].update(last_sum, last_count)
counters["reward_solo/last"].update((last_rewards > 0.9).float().sum(), last_count)
for i, power in enumerate(POWERS):
power_rewards = last_rewards[..., i]
counters[f"reward/last_{power}"].update(power_rewards.sum(), last_count)
counters[f"reward_solo/last_{power}"].update(
(power_rewards > 0.9).sum(), last_count
)
# Average batch size in phases, i.e., T * B.
counters["size/batch"].update(time_bsz * self.ectx.ddp_world_size)
if self.ectx.ddp_world_size > 1:
counters["size/batch_local"].update(time_bsz)
# Average number of phases per episode.
counters["size/episode"].update(time_bsz, last_count)
# Average number of move phases per episode.
counters["size/episode_moves"].update(is_move_phase.float().sum(), last_count)
with timings("sync"), torch.no_grad():
if self.state.global_step % self.cfg.trainer.save_sync_checkpoint_every == 0:
self.send_model_to_workers()
# Done outside of the timings context so that the context's own time is captured.
for key, value in timings.items():
counters[f"time/{key}"].update(value)