in reagent/training/c51_trainer.py
def train_step_gen(self, training_batch: rlt.DiscreteDqnInput, batch_idx: int):
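# C51 (categorical DQN) training step: project the target return distribution
# onto the fixed support and minimize cross-entropy against the online network's
# predicted distribution.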
rewards = self.boost_rewards(training_batch.reward, training_batch.action)
discount_tensor = torch.full_like(rewards, self.gamma)
possible_next_actions_mask = training_batch.possible_next_actions_mask.float()
possible_actions_mask = training_batch.possible_actions_mask.float()
not_terminal = training_batch.not_terminal.float()
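# Optionally raise gamma to the observed time delta (or to the multi-step count)
# so the discount reflects the actual number of steps between the two states.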
if self.use_seq_num_diff_as_time_diff:
assert self.multi_steps is None
discount_tensor = torch.pow(self.gamma, training_batch.time_diff.float())
if self.multi_steps is not None:
assert training_batch.step is not None
discount_tensor = torch.pow(self.gamma, training_batch.step.float())
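# Next-state return distribution from the target network,
# shape (batch_size, num_actions, num_atoms).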
next_dist = self.q_network_target.log_dist(training_batch.next_state).exp()
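# Choose which next-action distribution to bootstrap from:
# greedy over masked actions (Q-learning) vs. the logged next action (SARSA-style).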
if self.maxq_learning:
# Select distribution corresponding to max valued action
if self.double_q_learning:
next_q_values = (
self.q_network.log_dist(training_batch.next_state).exp()
* self.support
).sum(2)
else:
next_q_values = (next_dist * self.support).sum(2)
next_action = self.argmax_with_mask(
next_q_values, possible_next_actions_mask
)
next_dist = next_dist[range(rewards.shape[0]), next_action.reshape(-1)]
else:
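# SARSA-style: bootstrap from the logged next action's distribution.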
next_dist = (next_dist * training_batch.next_action.unsqueeze(-1)).sum(1)
# Build target distribution
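# Shift and scale the support: Tz = r + discount * not_terminal * z for each atom z,
# then clamp to the support range [qmin, qmax].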
target_Q = rewards + discount_tensor * not_terminal * self.support
target_Q = target_Q.clamp(self.qmin, self.qmax)
# rescale target_Q onto support indices [0, 1, ..., N-1]
b = (target_Q - self.qmin) / self.scale_support
lo = b.floor().to(torch.int64)
up = b.ceil().to(torch.int64)
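# lo and up are the integer support indices bracketing each projected atom b.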
# handle corner cases of l == b == u
# without the following, it would give 0 signal, whereas we want
# m to add p(s_t+n, a*) to index l == b == u.
# So we artificially adjust l and u.
# (1) If 0 < l == u < N-1, we make l = l-1, so b - l = 1
# (2) If 0 == l == u, we make u = 1, so u - b = 1
# (3) If l == u == N-1, we make l = N-2, so b - l = 1
# This first line handles (1) and (3).
lo[(up > 0) * (lo == up)] -= 1
# Note: l has already been adjusted above, so l == u can now only happen
# when u == 0, in which case we set u = 1.
# (The first condition in the next line is therefore likely redundant.)
up[(lo < (self.num_atoms - 1)) * (lo == up)] += 1
# distribute the probabilities
# m_l = m_l + p(s_t+n, a*)(u - b)
# m_u = m_u + p(s_t+n, a*)(b - l)
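# m accumulates the projected target distribution, shape (batch_size, num_atoms).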
m = torch.zeros_like(next_dist)
# pyre-fixme[16]: `Tensor` has no attribute `scatter_add_`.
m.scatter_add_(dim=1, index=lo, src=next_dist * (up.float() - b))
m.scatter_add_(dim=1, index=up, src=next_dist * (b - lo.float()))
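# Online network's log-distribution for the current state,
# shape (batch_size, num_actions, num_atoms).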
log_dist = self.q_network.log_dist(training_batch.state)
# for reporting only
all_q_values = (log_dist.exp() * self.support).sum(2).detach()
model_action_idxs = self.argmax_with_mask(
all_q_values,
possible_actions_mask if self.maxq_learning else training_batch.action,
)
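# Keep only the distribution of the action actually taken (actions are one-hot).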
log_dist = (log_dist * training_batch.action.unsqueeze(-1)).sum(1)
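# Cross-entropy between the projected target m and the predicted log-distribution.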
loss = -(m * log_dist).sum(1).mean()
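# Periodic reporting of batch statistics.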
if batch_idx % self.trainer.log_every_n_steps == 0:
self.reporter.log(
td_loss=loss,
logged_actions=torch.argmax(training_batch.action, dim=1, keepdim=True),
logged_propensities=training_batch.extras.action_probability,
logged_rewards=rewards,
model_values=all_q_values,
model_action_idxs=model_action_idxs,
)
self.log(
"td_loss", loss, prog_bar=True, batch_size=training_batch.batch_size()
)
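# First yield: the TD loss to optimize; second yield: the target-network
# soft update produced by soft_update_result().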
yield loss
result = self.soft_update_result()
yield result