def train_step_gen()

in reagent/training/parametric_dqn_trainer.py


    def train_step_gen(self, training_batch: rlt.ParametricDqnInput, batch_idx: int):
        reward = training_batch.reward
        not_terminal = training_batch.not_terminal.float()
        discount_tensor = torch.full_like(reward, self.gamma)
        if self.use_seq_num_diff_as_time_diff:
            # Discount by the logged gap in sequence numbers rather than a single step.
            assert self.multi_steps is None
            discount_tensor = torch.pow(self.gamma, training_batch.time_diff.float())
        if self.multi_steps is not None:
            # Multi-step returns: discount by gamma^k, where k is the number of steps taken.
            # pyre-fixme[16]: Optional type has no attribute `float`.
            discount_tensor = torch.pow(self.gamma, training_batch.step.float())

        if self.maxq_learning:
            # Assuming actions are parametrized in a k-dimensional space:
            # tiled_next_state = (batch_size * max_num_action, state_dim)
            # possible_next_actions = (batch_size * max_num_action, k)
            # possible_next_actions_mask = (batch_size, max_num_action)
            product = training_batch.possible_next_actions.float_features.shape[0]
            batch_size = training_batch.possible_actions_mask.shape[0]
            assert product % batch_size == 0, (
                f"batch_size * max_num_action {product} is "
                f"not divisible by batch_size {batch_size}"
            )
            max_num_action = product // batch_size
            tiled_next_state = training_batch.next_state.get_tiled_batch(max_num_action)
            (
                all_next_q_values,
                all_next_q_values_target,
            ) = self.get_detached_model_outputs(
                tiled_next_state, training_batch.possible_next_actions
            )
            # Compute max a' Q(s', a') over all possible actions using target network
            next_q_values, _ = self.get_max_q_values_with_target(
                all_next_q_values,
                all_next_q_values_target,
                training_batch.possible_next_actions_mask.float(),
            )
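            # NOTE (assumption, not verified against the helper's implementation):
            # get_max_q_values_with_target is expected to reshape the flattened
            # (batch_size * max_num_action, 1) values back to (batch_size, max_num_action),
            # ignore entries where possible_next_actions_mask is 0, and return the
            # per-row maximum, so next_q_values comes back with shape (batch_size, 1).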
            assert (
                len(next_q_values.shape) == 2 and next_q_values.shape[1] == 1
            ), f"{next_q_values.shape}"

        else:
            # SARSA (Use the target network)
            _, next_q_values = self.get_detached_model_outputs(
                training_batch.next_state, training_batch.next_action
            )
            assert (
                len(next_q_values.shape) == 2 and next_q_values.shape[1] == 1
            ), f"{next_q_values.shape}"

        # Bellman target: y = r + (1 - terminal) * discount * Q_target(s', a')
        target_q_values = reward + not_terminal * discount_tensor * next_q_values
        assert (
            target_q_values.shape[-1] == 1
        ), f"{target_q_values.shape} doesn't end with 1"

        # Get Q-value of action taken
        q_values = self.q_network(training_batch.state, training_batch.action)
        assert (
            target_q_values.shape == q_values.shape
        ), f"{target_q_values.shape} != {q_values.shape}."
        td_loss = self.q_network_loss(q_values, target_q_values)
        yield td_loss

        # pyre-fixme[16]: Optional type has no attribute `metrics`.
        if training_batch.extras.metrics is not None:
            metrics_reward_concat_real_vals = torch.cat(
                (reward, training_batch.extras.metrics), dim=1
            )
        else:
            metrics_reward_concat_real_vals = reward

        # get reward estimates
        reward_estimates = self.reward_network(
            training_batch.state, training_batch.action
        )
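        # NOTE (assumption): for the MSE below to be well-formed, reward_network is
        # expected to output one column per regression target, i.e. (batch_size, 1)
        # normally, or (batch_size, 1 + num_metrics) when extras.metrics is present.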
        reward_loss = F.mse_loss(
            reward_estimates.squeeze(-1),
            metrics_reward_concat_real_vals.squeeze(-1),
        )
        yield reward_loss

        self.reporter.log(
            td_loss=td_loss.detach().cpu(),
            reward_loss=reward_loss.detach().cpu(),
            logged_rewards=reward,
            model_values_on_logged_actions=q_values.detach().cpu(),
        )

        # Use the soft update rule to update target network
        yield self.soft_update_result()
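
The generator yields three results in order: the TD loss for the Q-network, the MSE loss
for the reward network, and the soft-update step for the target network. Below is a minimal
sketch of how a caller might drive it, assuming one optimizer per yielded loss; the names
consume_train_step, q_optimizer, reward_optimizer, and soft_update are illustrative, not
ReAgent's actual training-loop API.

    import torch

    def consume_train_step(trainer, batch, batch_idx, q_optimizer, reward_optimizer):
        gen = trainer.train_step_gen(batch, batch_idx)

        td_loss = next(gen)        # 1st yield: TD loss for trainer.q_network
        q_optimizer.zero_grad()
        td_loss.backward()
        q_optimizer.step()

        reward_loss = next(gen)    # 2nd yield: MSE loss for trainer.reward_network
        reward_optimizer.zero_grad()
        reward_loss.backward()
        reward_optimizer.step()

        next(gen, None)            # 3rd yield: the target-network soft-update step

    # Conventional Polyak averaging: a sketch of what a soft update typically does,
    # not necessarily what soft_update_result() returns in ReAgent.
    @torch.no_grad()
    def soft_update(online_net, target_net, tau: float = 0.005):
        for p, p_target in zip(online_net.parameters(), target_net.parameters()):
            p_target.mul_(1.0 - tau).add_(tau * p)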