in reagent/training/discrete_crr_trainer.py
def train_step_gen(self, training_batch: rlt.DiscreteDqnInput, batch_idx: int):
"""
IMPORTANT: the input action here is preprocessed according to the
training_batch type, which in this case is DiscreteDqnInput. Hence,
the preprocessor in the DiscreteDqnInputMaker class in
trainer_preprocessor.py is used, which converts the action taken to a
one-hot representation.
"""
self._check_input(training_batch)
state = training_batch.state
action = training_batch.action
next_state = training_batch.next_state
not_terminal = training_batch.not_terminal
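# Apply any configured per-action reward boost to the logged rewards.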
rewards = self.boost_rewards(training_batch.reward, training_batch.action)
# Remember: training_batch.action is in the one-hot format
logged_action_idxs = torch.argmax(action, dim=1, keepdim=True)
discount_tensor = torch.full_like(rewards, self.gamma)
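# Q-values of every action at the next state, from the target Q-network.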
next_q_values = self.q1_network_target(next_state)
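# TD target for the logged transition: reward plus the discounted next-state
# value, with terminal transitions masked via not_terminal (the exact form is
# implemented in compute_target_q_values).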
target_q_values = self.compute_target_q_values(
next_state, rewards, not_terminal, next_q_values
)
q1_loss = self.compute_td_loss(self.q1_network, state, action, target_q_values)
# Show td_loss on the progress bar and in TensorBoard graphs:
self.log(
"td_loss", q1_loss, prog_bar=True, batch_size=training_batch.batch_size()
)
yield q1_loss
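# If a second Q-network is configured, train it toward the same TD target.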
if self.q2_network:
q2_loss = self.compute_td_loss(
self.q2_network, state, action, target_q_values
)
yield q2_loss
all_q_values = self.q1_network(state) # Q-values of all actions
# Note: action_dim (the length of each row of the actor_action
# matrix obtained below) is assumed to be > 1.
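# all_action_scores holds the actor's score for each discrete action.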
all_action_scores = self.actor_network(state).action
logged_action_probs = training_batch.extras.action_probability
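# CRR actor loss, computed from the logged actions, their propensities (if
# logged), the Q-values, and the actor scores; returned both without and
# with the regularization term.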
actor_loss_without_reg, actor_loss = self.compute_actor_loss(
batch_idx, action, logged_action_probs, all_q_values, all_action_scores
)
# self.reporter.log(
# actor_loss=actor_loss,
# actor_q1_value=actor_q1_values,
# )
# Show actor_loss on the progress bar and also in TensorBoard graphs
self.log(
"actor_loss_without_reg",
actor_loss_without_reg,
prog_bar=True,
batch_size=training_batch.batch_size(),
)
self.log(
"actor_loss",
actor_loss,
prog_bar=True,
batch_size=training_batch.batch_size(),
)
yield actor_loss
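# Yield the losses used for counterfactual policy evaluation (CPE) metrics.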
yield from self._calculate_cpes(
training_batch,
state,
next_state,
all_action_scores,
next_q_values.detach(),
logged_action_idxs,
discount_tensor,
not_terminal.float(),
)
# TODO: rename underlying function to get_max_possible_values_and_idxs
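# Argmax action indices over the model's scores, masked by
# possible_actions_mask when max-Q learning is enabled, otherwise by the
# logged (one-hot) action.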
model_action_idxs = self.get_max_q_values(
all_action_scores,
training_batch.possible_actions_mask if self.maxq_learning else action,
)[1]
self.reporter.log(
logged_actions=logged_action_idxs,
td_loss=q1_loss,
logged_propensities=training_batch.extras.action_probability,
logged_rewards=rewards,
model_values=all_action_scores,
model_action_idxs=model_action_idxs,
)
# Use the soft update rule to update the target networks.
# Note: this yield has to be the last one, since SoftUpdate is the last
# optimizer added in the configure_optimizers() function.
result = self.soft_update_result()
yield result