def train_step_gen()

in reagent/training/qrdqn_trainer.py


    def train_step_gen(self, training_batch: rlt.DiscreteDqnInput, batch_idx: int):
        self._check_input(training_batch)

        rewards = self.boost_rewards(training_batch.reward, training_batch.action)
        discount_tensor = torch.full_like(rewards, self.gamma)
        possible_next_actions_mask = training_batch.possible_next_actions_mask.float()
        possible_actions_mask = training_batch.possible_actions_mask.float()

        not_done_mask = training_batch.not_terminal.float()

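        # Swap the constant gamma for gamma^k when sequence-number time diffs
        # or multi-step returns are in use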
        if self.use_seq_num_diff_as_time_diff:
            assert self.multi_steps is None
            discount_tensor = torch.pow(self.gamma, training_batch.time_diff.float())
        if self.multi_steps is not None:
            assert training_batch.step is not None
            discount_tensor = torch.pow(self.gamma, training_batch.step.float())

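        # Quantile values from the target network: (batch, num_actions, num_atoms)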
        next_qf = self.q_network_target(training_batch.next_state)

        if self.maxq_learning:
            # Select distribution corresponding to max valued action
            next_q_values = (
                self.q_network(training_batch.next_state)
                if self.double_q_learning
                else next_qf
            ).mean(dim=2)
            next_action = self.argmax_with_mask(
                next_q_values, possible_next_actions_mask
            )
            next_qf = next_qf[range(rewards.shape[0]), next_action.reshape(-1)]
        else:
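            # SARSA-style update: keep the quantiles of the (one-hot) logged next action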
            next_qf = (next_qf * training_batch.next_action.unsqueeze(-1)).sum(1)

        # Build target distribution
        target_Q = rewards + discount_tensor * not_done_mask * next_qf

        current_qf = self.q_network(training_batch.state)

        # for reporting only
        all_q_values = current_qf.mean(2).detach()

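        # Keep only the quantiles of the logged action:
        # (batch, num_actions, num_atoms) -> (batch, num_atoms)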
        current_qf = (current_qf * training_batch.action.unsqueeze(-1)).sum(1)

        # (batch, atoms) -> (atoms, batch, 1) -> (atoms, batch, atoms)
        td = target_Q.t().unsqueeze(-1) - current_qf
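        # Element-wise Huber penalty reweighted by the asymmetric quantile factor |tau - 1{td < 0}|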
        loss = (
            self.huber(td) * (self.quantiles - (td.detach() < 0).float()).abs()
        ).mean()

        yield loss
        # pyre-fixme[16]: `QRDQNTrainer` has no attribute `loss`.
        self.loss = loss.detach()

        # Get Q-values of next states, used in computing cpe
        all_next_action_scores = (
            self.q_network(training_batch.next_state).detach().mean(dim=2)
        )

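        # Index of the logged one-hot action, consumed by the CPE estimators and the reporter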
        logged_action_idxs = torch.argmax(training_batch.action, dim=1, keepdim=True)
        yield from self._calculate_cpes(
            training_batch,
            training_batch.state,
            training_batch.next_state,
            all_q_values,
            all_next_action_scores,
            logged_action_idxs,
            discount_tensor,
            not_done_mask,
        )

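        # Greedy action under the current Q-values; when maxq_learning is off the mask is the
        # logged one-hot action, so this reduces to the logged action index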
        model_action_idxs = self.argmax_with_mask(
            all_q_values,
            possible_actions_mask if self.maxq_learning else training_batch.action,
        )

        self.reporter.log(
            td_loss=loss,
            logged_actions=logged_action_idxs,
            logged_propensities=training_batch.extras.action_probability,
            logged_rewards=rewards,
            logged_values=None,  # Compute at end of each epoch for CPE
            model_values=all_q_values,
            model_values_on_logged_actions=None,  # Compute at end of each epoch for CPE
            model_action_idxs=model_action_idxs,
        )

        yield self.soft_update_result()
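
The loss above is the quantile regression Huber loss from QR-DQN: each predicted quantile is
regressed against every atom of the target distribution, and the element-wise Huber penalty is
reweighted by the asymmetric factor |tau - 1{td < 0}|. The sketch below is a minimal, standalone
reconstruction of that computation; the function and argument names, the Huber threshold
kappa = 1.0, and the midpoint construction of the quantile fractions are assumptions chosen to
mirror the code above, not part of the trainer's API.

    import torch

    def quantile_huber_loss(target_q, current_qf, kappa=1.0):
        # target_q:   (batch, num_atoms) target quantile samples
        # current_qf: (batch, num_atoms) predicted quantiles of the taken action
        num_atoms = current_qf.shape[1]
        # Quantile midpoints tau_j = (j + 0.5) / N, shape (1, num_atoms)
        quantiles = ((0.5 + torch.arange(num_atoms).float()) / num_atoms).view(1, -1)
        # Pairwise TD errors: (num_atoms, batch, 1) - (batch, num_atoms) broadcasts to
        # (num_atoms, batch, num_atoms); the last dim indexes the current quantiles
        td = target_q.t().unsqueeze(-1) - current_qf
        # Element-wise Huber penalty with threshold kappa
        huber = torch.where(
            td.abs() <= kappa, 0.5 * td.pow(2), kappa * (td.abs() - 0.5 * kappa)
        )
        # Asymmetric quantile weight |tau_j - 1{td < 0}|
        weight = (quantiles - (td.detach() < 0).float()).abs()
        return (huber * weight).mean()

With kappa = 1.0 the penalty reduces to 0.5 * td^2 for |td| < 1 and |td| - 0.5 otherwise, the
usual unit-threshold choice for QR-DQN; taking .mean() over all three dimensions mirrors the
final reduction of the loss above.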