in reagent/training/slate_q_trainer.py [0:0]
def train_step_gen(self, training_batch: rlt.SlateQInput, batch_idx: int):
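    """One SlateQ training step.

    Yields the 1-step TD loss for the online Q-network first, then the
    soft-update result used to move the target network toward it.
    """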
    assert isinstance(
        training_batch, rlt.SlateQInput
    ), f"learning input is a {type(training_batch)}"

    reward = training_batch.reward
    reward_mask = training_batch.reward_mask

    discount_tensor = torch.full_like(reward, self.gamma)
    # Adjust the discount factor by the time_diff if the discount_time_scale is
    # provided and the time_diff exists in the training_batch.
    if self.discount_time_scale and training_batch.time_diff is not None:
        discount_tensor = discount_tensor ** (
            training_batch.time_diff / self.discount_time_scale
        )
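    # E.g. with gamma = 0.9, time_diff = 2, and discount_time_scale = 1 (values
    # chosen only for illustration), the effective discount is 0.9 ** 2 = 0.81.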

    next_action = (
        self._get_maxq_next_action(training_batch.next_state)
        if self.rl_parameters.maxq_learning
        else training_batch.next_action
    )

    terminal_mask = (~training_batch.not_terminal.to(torch.bool)).squeeze(1)
    next_action_docs = self._action_docs(
        training_batch.next_state,
        next_action,
        terminal_mask=terminal_mask,
    )
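    # SlateQ decomposition: the Q-value of the whole next slate is the sum over
    # its candidate docs of the per-doc Q-value (from the target network)
    # weighted by that doc's value from self._get_docs_value(...).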
    next_q_values = torch.sum(
        self._get_unmasked_q_values(
            self.q_network_target,
            training_batch.next_state,
            next_action_docs,
        )
        * self._get_docs_value(next_action_docs),
        dim=1,
        keepdim=True,
    )

    # If not single selection, divide max-Q by the actual slate size.
    if not self.single_selection:
        next_q_values = next_q_values / self._get_avg_by_slate_size(training_batch)
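
    # 1-step TD target: reward + discount * Q_target(s', a'), with the
    # bootstrapped term zeroed out on terminal transitions via not_terminal.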
    filtered_max_q_vals = next_q_values * training_batch.not_terminal.float()
    target_q_values = reward + (discount_tensor * filtered_max_q_vals)
    # Apply the reward mask only in the single-selection case.
    if self.single_selection:
        target_q_values = target_q_values[reward_mask]

    # Get Q-value of action taken
    action_docs = self._action_docs(training_batch.state, training_batch.action)
    q_values = self._get_unmasked_q_values(
        self.q_network, training_batch.state, action_docs
    )
    if self.single_selection:
        q_values = q_values[reward_mask]

    all_action_scores = q_values.detach()
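
    # TD loss: mean-squared error between the online network's Q-values and the
    # bootstrapped targets computed above.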
    value_loss = F.mse_loss(q_values, target_q_values)
    yield value_loss

    if not self.single_selection:
        all_action_scores = all_action_scores.sum(dim=1, keepdim=True)

    # Logging at the end to schedule all the CUDA operations first.
    self.reporter.log(
        td_loss=value_loss,
        model_values_on_logged_actions=all_action_scores,
    )

    # Use the soft update rule to update the target network.
    result = self.soft_update_result()
    self.log(
        "td_loss", value_loss, prog_bar=True, batch_size=training_batch.batch_size()
    )
    yield result
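
# Illustrative only -- not part of slate_q_trainer.py. A minimal sketch, under the
# assumption that each value yielded by train_step_gen corresponds to one optimizer
# step (first the TD loss for the online Q-network, then the soft-update result for
# the target network), of how a manual-optimization loop could consume the generator.
# `trainer`, `q_optimizer`, and `data_loader` are hypothetical stand-ins for whatever
# the surrounding training framework provides.
def _example_consume_train_step_gen(trainer, q_optimizer, data_loader):
    for batch_idx, batch in enumerate(data_loader):
        step_gen = trainer.train_step_gen(batch, batch_idx)

        # First yield: the TD loss on the online Q-network; backprop and step.
        value_loss = next(step_gen)
        q_optimizer.zero_grad()
        value_loss.backward()
        q_optimizer.step()

        # Second yield: the soft-update result. In the real framework a matching
        # soft-update step would move q_network_target toward q_network here.
        next(step_gen)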