in reagent/training/ranking/seq2slate_trainer.py [0:0]
def training_step(self, batch: rlt.PreprocessedRankingInput, batch_idx: int):
    assert type(batch) is rlt.PreprocessedRankingInput

    batch_size = batch.state.float_features.shape[0]

    reward = batch.slate_reward
    assert reward is not None
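
    # Retrieve the optimizers (manual optimization): the Seq2Slate (RL)
    # optimizer is always at index 0; the baseline optimizer, when configured,
    # is at index 1.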
    optimizers = self.optimizers()
    if self.baseline_net:
        assert len(optimizers) == 2
        baseline_opt = optimizers[1]
    else:
        assert len(optimizers) == 1
    rl_opt = optimizers[0]
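
    # The baseline predicts the expected slate reward and is subtracted from
    # the reward below, acting as a variance-reducing control variate for
    # REINFORCE. Without it, b is zero and the advantage is the raw reward.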
    if self.baseline_net:
        # Train baseline
        b = self.baseline_net(batch)
        baseline_loss = 1.0 / batch_size * torch.sum((b - reward) ** 2)
        baseline_opt.zero_grad()
        self.manual_backward(baseline_loss)
        baseline_opt.step()
    else:
        b = torch.zeros_like(reward)
        baseline_loss = torch.zeros(1)

    # Train Seq2Slate using REINFORCE
    # Log probabilities of the target (logged) sequences under the current model
    model_propensities = torch.exp(
        self.seq2slate_net(
            batch, mode=Seq2SlateMode.PER_SEQ_LOG_PROB_MODE
        ).log_probs
    )
    b = b.detach()
    assert (
        b.shape == reward.shape == model_propensities.shape
    ), f"{b.shape} {reward.shape} {model_propensities.shape}"

    impt_smpl, clamped_impt_smpl = self._compute_impt_smpl(
        model_propensities, batch.tgt_out_probs
    )
    assert (
        impt_smpl.shape == clamped_impt_smpl.shape == reward.shape
    ), f"{impt_smpl.shape} {clamped_impt_smpl.shape} {reward.shape}"
    # The gradient should flow only through model_propensities
    assert (
        not reward.requires_grad
        # pyre-fixme[16]: `Optional` has no attribute `requires_grad`.
        and not batch.tgt_out_probs.requires_grad
        and impt_smpl.requires_grad
        and clamped_impt_smpl.requires_grad
        and not b.requires_grad
    )
    # Negate because we run gradient descent but want to maximize reward
    batch_obj_loss = -clamped_impt_smpl * (reward - b)
    obj_loss = torch.mean(batch_obj_loss)

    # Conditions to update the Seq2Slate policy:
    # 1. there is no baseline, or
    # 2. the baseline is present and has passed the warm-up stage.
    # Additionally, the optimizer steps only every policy_gradient_interval
    # minibatches, accumulating gradients in between.
    if (
        self.baseline_net is None
        or (self.all_batches_processed + 1) >= self.baseline_warmup_num_batches
    ):
        self.manual_backward(obj_loss)
        if (self.all_batches_processed + 1) % self.policy_gradient_interval == 0:
            rl_opt.step()
            rl_opt.zero_grad()
    else:
        logger.info("Not updating the RL model: still in the baseline warm-up phase")
    ips_loss = torch.mean(-impt_smpl * reward).cpu().detach().numpy()
    clamped_ips_loss = (
        torch.mean(-clamped_impt_smpl * reward).cpu().detach().numpy()
    )
    baseline_loss = baseline_loss.detach().cpu().numpy().item()
    advantage = (reward - b).detach().cpu().numpy()
    logged_slate_rank_probs = model_propensities.detach().cpu().numpy()

    if (self.all_batches_processed + 1) % self.print_interval == 0:
        logger.info(
            "{} batch: ips_loss={}, clamped_ips_loss={}, baseline_loss={}, max_ips={}, mean_ips={}, grad_update={}".format(
                self.all_batches_processed + 1,
                ips_loss,
                clamped_ips_loss,
                baseline_loss,
                torch.max(impt_smpl),
                torch.mean(impt_smpl),
                (self.all_batches_processed + 1) % self.policy_gradient_interval
                == 0,
            )
        )
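    # Forward per-batch metrics to the reporter for aggregation.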
    self.reporter.log(
        train_ips_score=torch.tensor(ips_loss).reshape(1),
        train_clamped_ips_score=torch.tensor(clamped_ips_loss).reshape(1),
        train_baseline_loss=torch.tensor(baseline_loss).reshape(1),
        train_logged_slate_rank_probs=torch.FloatTensor(logged_slate_rank_probs),
        train_ips_ratio=impt_smpl,
        train_clamped_ips_ratio=clamped_impt_smpl,
        train_advantages=advantage,
    )
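

# The helper `self._compute_impt_smpl(...)` called above is not included in
# this excerpt. Below is a minimal, hypothetical sketch of what such a helper
# typically computes for off-policy REINFORCE: the per-sequence importance
# ratio between the model's propensities and the logged propensities, plus a
# clamped copy that caps the ratio to limit gradient variance. The standalone
# function form and the `clamp_max` parameter are illustrative assumptions,
# not the repository's actual implementation.
def _compute_impt_smpl_sketch(model_propensities, logged_propensities, clamp_max=10.0):
    # Both tensors are expected to have shape (batch_size, 1).
    impt_smpl = model_propensities / logged_propensities
    clamped_impt_smpl = torch.clamp(impt_smpl, max=clamp_max)
    return impt_smpl, clamped_impt_smpl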