def train_step_gen()

in reagent/training/c51_trainer.py [0:0]


    def train_step_gen(self, training_batch: rlt.DiscreteDqnInput, batch_idx: int):
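        # C51 training step: build the projected categorical target
        # distribution from the target network, then minimize the
        # cross-entropy between it and the online network's predicted
        # distribution for the logged action.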
        rewards = self.boost_rewards(training_batch.reward, training_batch.action)
        discount_tensor = torch.full_like(rewards, self.gamma)
        possible_next_actions_mask = training_batch.possible_next_actions_mask.float()
        possible_actions_mask = training_batch.possible_actions_mask.float()

        not_terminal = training_batch.not_terminal.float()

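        # Adjust the discount: either gamma^(time_diff) when sequence-number
        # differences are used as time deltas, or gamma^(step) for
        # multi-step returns.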
        if self.use_seq_num_diff_as_time_diff:
            assert self.multi_steps is None
            discount_tensor = torch.pow(self.gamma, training_batch.time_diff.float())
        if self.multi_steps is not None:
            assert training_batch.step is not None
            discount_tensor = torch.pow(self.gamma, training_batch.step.float())

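        # Next-state return distributions from the target network,
        # shape (batch, num_actions, num_atoms).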
        next_dist = self.q_network_target.log_dist(training_batch.next_state).exp()

        if self.maxq_learning:
            # Select the distribution corresponding to the max-valued action
            if self.double_q_learning:
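                # Double Q-learning: select the action with the online
                # network's Q-values; the target network's distribution
                # (next_dist) is still used for evaluation below.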
                next_q_values = (
                    self.q_network.log_dist(training_batch.next_state).exp()
                    * self.support
                ).sum(2)
            else:
                next_q_values = (next_dist * self.support).sum(2)

            next_action = self.argmax_with_mask(
                next_q_values, possible_next_actions_mask
            )
            next_dist = next_dist[range(rewards.shape[0]), next_action.reshape(-1)]
        else:
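            # SARSA-style update: use the distribution of the logged next action.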
            next_dist = (next_dist * training_batch.next_action.unsqueeze(-1)).sum(1)

        # Build target distribution
        target_Q = rewards + discount_tensor * not_terminal * self.support
        target_Q = target_Q.clamp(self.qmin, self.qmax)

        # Rescale to indices [0, 1, ..., N-1]
        b = (target_Q - self.qmin) / self.scale_support
        lo = b.floor().to(torch.int64)
        up = b.ceil().to(torch.int64)

        # Handle the corner case l == b == u.
        # Without the adjustment below, the projection would give 0 signal,
        # whereas we want m to add p(s_t+n, a*) at the index l == b == u,
        # so we artificially separate l and u:
        # (1) If 0 < l == u < N-1, set l = l-1, so b - l = 1.
        # (2) If l == u == 0, set u = 1, so u - b = 1.
        # (3) If l == u == N-1, set l = N-2, so b - l = 1.
        # The first line below handles (1) and (3).
        lo[(up > 0) * (lo == up)] -= 1
        # Note: l has already been adjusted, so l == u is only still possible
        # when u == 0, in which case we set u = 1. (The first condition in the
        # next line is probably unnecessary, but harmless.)
        up[(lo < (self.num_atoms - 1)) * (lo == up)] += 1

        # distribute the probabilities
        # m_l = m_l + p(s_t+n, a*)(u - b)
        # m_u = m_u + p(s_t+n, a*)(b - l)
        m = torch.zeros_like(next_dist)
        # pyre-fixme[16]: `Tensor` has no attribute `scatter_add_`.
        m.scatter_add_(dim=1, index=lo, src=next_dist * (up.float() - b))
        m.scatter_add_(dim=1, index=up, src=next_dist * (b - lo.float()))
        log_dist = self.q_network.log_dist(training_batch.state)

        # for reporting only
        all_q_values = (log_dist.exp() * self.support).sum(2).detach()
        model_action_idxs = self.argmax_with_mask(
            all_q_values,
            possible_actions_mask if self.maxq_learning else training_batch.action,
        )

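        # Predicted log-distribution for the logged (one-hot) action.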
        log_dist = (log_dist * training_batch.action.unsqueeze(-1)).sum(1)

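        # Cross-entropy between the projected target distribution m and the
        # model's predicted log-distribution.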
        loss = -(m * log_dist).sum(1).mean()

        if batch_idx % self.trainer.log_every_n_steps == 0:
            self.reporter.log(
                td_loss=loss,
                logged_actions=torch.argmax(training_batch.action, dim=1, keepdim=True),
                logged_propensities=training_batch.extras.action_probability,
                logged_rewards=rewards,
                model_values=all_q_values,
                model_action_idxs=model_action_idxs,
            )
            self.log(
                "td_loss", loss, prog_bar=True, batch_size=training_batch.batch_size()
            )

        yield loss
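        # The second yield drives the soft update of the target network
        # toward the online network.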
        result = self.soft_update_result()
        yield result
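
For reference, the snippet below is a minimal, self-contained sketch of the categorical projection performed above. The parameter values (qmin, qmax, num_atoms, gamma) and the random batch are illustrative assumptions rather than values taken from the trainer; only the projection logic mirrors the code.

    import torch

    # Illustrative parameters (assumptions, not from any trainer config).
    qmin, qmax, num_atoms, gamma = -10.0, 10.0, 51, 0.99
    support = torch.linspace(qmin, qmax, num_atoms)       # atoms z_0 ... z_{N-1}
    scale_support = (qmax - qmin) / (num_atoms - 1)

    # Fake batch: next-state distribution p(s_t+n, a*), rewards, terminal flags.
    batch_size = 4
    next_dist = torch.softmax(torch.randn(batch_size, num_atoms), dim=1)
    rewards = torch.randn(batch_size, 1)
    not_terminal = torch.ones(batch_size, 1)

    # Bellman backup applied to every atom, clamped to the support range.
    target_Q = (rewards + gamma * not_terminal * support).clamp(qmin, qmax)

    # Fractional indices into the support, split between floor and ceil bins.
    b = (target_Q - qmin) / scale_support
    lo = b.floor().to(torch.int64)
    up = b.ceil().to(torch.int64)
    lo[(up > 0) * (lo == up)] -= 1
    up[(lo < (num_atoms - 1)) * (lo == up)] += 1

    # Distribute each atom's probability mass between the two neighboring bins.
    m = torch.zeros_like(next_dist)
    m.scatter_add_(dim=1, index=lo, src=next_dist * (up.float() - b))
    m.scatter_add_(dim=1, index=up, src=next_dist * (b - lo.float()))

    # The projected target is still a valid distribution over the atoms.
    assert torch.allclose(m.sum(dim=1), torch.ones(batch_size), atol=1e-5)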