in reagent/training/qrdqn_trainer.py [0:0]
def train_step_gen(self, training_batch: rlt.DiscreteDqnInput, batch_idx: int):
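    """Single QR-DQN training step.

    Yields, in order: the quantile-regression TD loss, the CPE losses produced
    by _calculate_cpes, and finally the target-network soft-update result.
    """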
    self._check_input(training_batch)

    rewards = self.boost_rewards(training_batch.reward, training_batch.action)
    discount_tensor = torch.full_like(rewards, self.gamma)
    possible_next_actions_mask = training_batch.possible_next_actions_mask.float()
    possible_actions_mask = training_batch.possible_actions_mask.float()
    not_done_mask = training_batch.not_terminal.float()
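    # Discounting: start from a constant gamma per sample; the branches below
    # override it with gamma ** time_diff (when sequence-number differences are
    # used as the time delta) or gamma ** step (multi-step returns).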
    if self.use_seq_num_diff_as_time_diff:
        assert self.multi_steps is None
        discount_tensor = torch.pow(self.gamma, training_batch.time_diff.float())
    if self.multi_steps is not None:
        assert training_batch.step is not None
        discount_tensor = torch.pow(self.gamma, training_batch.step.float())
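    # Quantile distribution over next-state returns from the target network;
    # expected shape (batch, num_actions, num_atoms).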
    next_qf = self.q_network_target(training_batch.next_state)

    if self.maxq_learning:
        # Select distribution corresponding to max valued action
        next_q_values = (
            self.q_network(training_batch.next_state)
            if self.double_q_learning
            else next_qf
        ).mean(dim=2)
        next_action = self.argmax_with_mask(
            next_q_values, possible_next_actions_mask
        )
        next_qf = next_qf[range(rewards.shape[0]), next_action.reshape(-1)]
    else:
        next_qf = (next_qf * training_batch.next_action.unsqueeze(-1)).sum(1)
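    # next_qf now holds the quantile distribution of the selected next action,
    # shape (batch, num_atoms).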
    # Build target distribution
    target_Q = rewards + discount_tensor * not_done_mask * next_qf

    current_qf = self.q_network(training_batch.state)

    # for reporting only
    all_q_values = current_qf.mean(2).detach()

    current_qf = (current_qf * training_batch.action.unsqueeze(-1)).sum(1)

    # (batch, atoms) -> (atoms, batch, 1) -> (atoms, batch, atoms)
    td = target_Q.t().unsqueeze(-1) - current_qf
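    # Quantile regression Huber loss: weight the Huber loss of each pairwise TD
    # error by |tau_i - 1{td < 0}| and average over batch and quantile pairs.
    # self.quantiles is assumed to hold the quantile midpoints
    # tau_i = (i + 0.5) / num_atoms.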
    loss = (
        self.huber(td) * (self.quantiles - (td.detach() < 0).float()).abs()
    ).mean()

    yield loss
    # pyre-fixme[16]: `QRDQNTrainer` has no attribute `loss`.
    self.loss = loss.detach()
    # Get Q-values of next states, used in computing CPE (counterfactual policy evaluation)
    all_next_action_scores = (
        self.q_network(training_batch.next_state).detach().mean(dim=2)
    )
    logged_action_idxs = torch.argmax(training_batch.action, dim=1, keepdim=True)

    yield from self._calculate_cpes(
        training_batch,
        training_batch.state,
        training_batch.next_state,
        all_q_values,
        all_next_action_scores,
        logged_action_idxs,
        discount_tensor,
        not_done_mask,
    )
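    # Action indices chosen by the current Q-values: masked greedy actions under
    # max-Q learning, otherwise restricted to the logged action; used only for
    # reporting below.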
    model_action_idxs = self.argmax_with_mask(
        all_q_values,
        possible_actions_mask if self.maxq_learning else training_batch.action,
    )

    self.reporter.log(
        td_loss=loss,
        logged_actions=logged_action_idxs,
        logged_propensities=training_batch.extras.action_probability,
        logged_rewards=rewards,
        logged_values=None,  # Compute at end of each epoch for CPE
        model_values=all_q_values,
        model_values_on_logged_actions=None,  # Compute at end of each epoch for CPE
        model_action_idxs=model_action_idxs,
    )
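    # Final yield triggers the soft (Polyak-style) update of the target network.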
    yield self.soft_update_result()