in reagent/training/parametric_dqn_trainer.py
def train_step_gen(self, training_batch: rlt.ParametricDqnInput, batch_idx: int):
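    # Yields, in order: the TD loss for the Q-network, the MSE loss for the
    # reward network, and the target-network soft-update result.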
    reward = training_batch.reward
    not_terminal = training_batch.not_terminal.float()
    discount_tensor = torch.full_like(reward, self.gamma)
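    # The discount defaults to a constant gamma; the branches below replace it
    # with gamma raised to the observed time delta or the multi-step count.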
    if self.use_seq_num_diff_as_time_diff:
        assert self.multi_steps is None
        discount_tensor = torch.pow(self.gamma, training_batch.time_diff.float())
    if self.multi_steps is not None:
        # pyre-fixme[16]: Optional type has no attribute `float`.
        discount_tensor = torch.pow(self.gamma, training_batch.step.float())
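    # Q-learning: bootstrap from max_a' Q(s', a') over all possible next actions.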
    if self.maxq_learning:
        # Assuming actions are parametrized in a k-dimensional space
        # tiled_state = (batch_size * max_num_action, state_dim)
        # possible_actions = (batch_size * max_num_action, k)
        # possible_actions_mask = (batch_size, max_num_action)
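        # Recover max_num_action from the tiled action tensor: its first
        # dimension is batch_size * max_num_action.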
        product = training_batch.possible_next_actions.float_features.shape[0]
        batch_size = training_batch.possible_actions_mask.shape[0]
        assert product % batch_size == 0, (
            f"batch_size * max_num_action {product} is "
            f"not divisible by batch_size {batch_size}"
        )
        max_num_action = product // batch_size
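        # Repeat each next state max_num_action times so states and their
        # candidate actions can be scored in a single forward pass.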
        tiled_next_state = training_batch.next_state.get_tiled_batch(max_num_action)
        (
            all_next_q_values,
            all_next_q_values_target,
        ) = self.get_detached_model_outputs(
            tiled_next_state, training_batch.possible_next_actions
        )
        # Compute max a' Q(s', a') over all possible actions using target network
        next_q_values, _ = self.get_max_q_values_with_target(
            all_next_q_values,
            all_next_q_values_target,
            training_batch.possible_next_actions_mask.float(),
        )
        assert (
            len(next_q_values.shape) == 2 and next_q_values.shape[1] == 1
        ), f"{next_q_values.shape}"
    else:
        # SARSA (Use the target network)
        _, next_q_values = self.get_detached_model_outputs(
            training_batch.next_state, training_batch.next_action
        )
        assert (
            len(next_q_values.shape) == 2 and next_q_values.shape[1] == 1
        ), f"{next_q_values.shape}"
    target_q_values = reward + not_terminal * discount_tensor * next_q_values
    assert (
        target_q_values.shape[-1] == 1
    ), f"{target_q_values.shape} doesn't end with 1"
    # Get Q-value of action taken
    q_values = self.q_network(training_batch.state, training_batch.action)
    assert (
        target_q_values.shape == q_values.shape
    ), f"{target_q_values.shape} != {q_values.shape}."
    td_loss = self.q_network_loss(q_values, target_q_values)
    yield td_loss
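    # Reward-network target: the logged reward, optionally concatenated with
    # extra metric columns carried in the batch extras.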
    # pyre-fixme[16]: Optional type has no attribute `metrics`.
    if training_batch.extras.metrics is not None:
        metrics_reward_concat_real_vals = torch.cat(
            (reward, training_batch.extras.metrics), dim=1
        )
    else:
        metrics_reward_concat_real_vals = reward
    # get reward estimates
    reward_estimates = self.reward_network(
        training_batch.state, training_batch.action
    )
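    # Squeeze the trailing dimension so estimates and targets line up for MSE.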
    reward_loss = F.mse_loss(
        reward_estimates.squeeze(-1),
        metrics_reward_concat_real_vals.squeeze(-1),
    )
    yield reward_loss
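    # Log detached, CPU-side losses and Q-values so reporting does not hold
    # onto the computation graph.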
    self.reporter.log(
        td_loss=td_loss.detach().cpu(),
        reward_loss=reward_loss.detach().cpu(),
        logged_rewards=reward,
        model_values_on_logged_actions=q_values.detach().cpu(),
    )
    # Use the soft update rule to update target network
    yield self.soft_update_result()