def _simulated_training_input()

in reagent/training/ranking/seq2slate_sim_trainer.py


    def _simulated_training_input(self, training_input: rlt.PreprocessedRankingInput):
        device = training_input.state.float_features.device
        # numerical precision errors can occasionally produce invalid actions,
        # so resample until the output is valid
        valid_output = False
        while not valid_output:
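            # Sample a full-length slate from the current Seq2Slate policy
            # (greedy=False means the slate is drawn stochastically).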
            rank_output = self.seq2slate_net(
                training_input,
                mode=Seq2SlateMode.RANK_MODE,
                tgt_seq_len=self.seq2slate_net.max_tgt_seq_len,
                greedy=False,
            )
            model_propensities = rank_output.ranked_per_seq_probs
            model_actions_with_offset = rank_output.ranked_tgt_out_idx
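            # Ranked indices are offset by 2 (reserved padding / decoder-start
            # symbols); shift back to 0-based candidate indices.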
            model_actions = model_actions_with_offset - 2
            if torch.all(model_actions >= 0):
                valid_output = True

        batch_size = model_actions_with_offset.shape[0]
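        # Re-order the candidate features according to the sampled ranking to
        # build the simulated slate.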
        simulated_slate_features = gather(
            training_input.src_seq.float_features,
            # pyre-fixme[61]: `model_actions` may not be initialized here.
            model_actions,
        )

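        # Lazily load the reward network(s) the first time simulation runs.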
        if not self.reward_name_and_net:
            # device.type also matches indexed devices such as "cuda:0"
            use_gpu = device.type == "cuda"
            self.reward_name_and_net = nn.ModuleDict(
                _load_reward_net(self.sim_param.reward_name_path, use_gpu)
            )

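        # The simulated slate reward is a weighted sum of each reward net's
        # prediction raised to its configured power.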
        sim_slate_reward = torch.zeros(batch_size, 1, device=device)
        for name, reward_net in self.reward_name_and_net.items():
            weight = self.sim_param.reward_name_weight[name]
            power = self.sim_param.reward_name_power[name]
            sr = reward_net(
                training_input.state.float_features,
                training_input.src_seq.float_features,
                simulated_slate_features,
                training_input.src_src_mask,
                model_actions_with_offset,
            ).detach()
            assert sr.ndim == 2, f"Slate reward {name} output should be 2-D tensor"
            sim_slate_reward += weight * (sr ** power)

        # guard-rail reward prediction range
        reward_clamp = self.sim_param.reward_clamp
        if reward_clamp is not None:
            sim_slate_reward = torch.clamp(
                sim_slate_reward, min=reward_clamp.clamp_min, max=reward_clamp.clamp_max
            )
        # guard-rail sequence similarity
        distance_penalty = self.sim_param.distance_penalty
        if distance_penalty is not None:
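            # Add a bonus that shrinks as the sampled slate's swap distance grows,
            # discouraging rankings that drift far from the original candidate order.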
            sim_distance = (
                torch.tensor(
                    # pyre-fixme[16]: `int` has no attribute `__iter__`.
                    [swap_dist(x.tolist()) for x in model_actions],
                    device=device,
                )
                .unsqueeze(1)
                .float()
            )
            sim_slate_reward += distance_penalty * (self.MAX_DISTANCE - sim_distance)

        assert (
            len(sim_slate_reward.shape) == 2 and sim_slate_reward.shape[1] == 1
        ), f"{sim_slate_reward.shape}"

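        # Package the sampled slate, its propensities, and the simulated reward
        # as on-policy training input.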
        on_policy_input = rlt.PreprocessedRankingInput.from_input(
            state=training_input.state.float_features,
            candidates=training_input.src_seq.float_features,
            device=device,
            # pyre-fixme[6]: Expected `Optional[torch.Tensor]` for 4th param but got
            #  `int`.
            # pyre-fixme[61]: `model_actions` may not be initialized here.
            action=model_actions,
            slate_reward=sim_slate_reward,
            # pyre-fixme[61]: `model_propensities` may not be initialized here.
            logged_propensities=model_propensities,
        )
        return on_policy_input
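
For reference, a minimal sketch of the shape of self.sim_param that this method assumes. The field names are taken from the accesses in the code above; the class names and types here are illustrative assumptions, not the actual ReAgent config definitions.

    # Hypothetical illustration only: mirrors the fields _simulated_training_input
    # reads from self.sim_param, not the real ReAgent parameter classes.
    from dataclasses import dataclass, field
    from typing import Dict, Optional

    @dataclass
    class RewardClamp:
        # guard-rail bounds applied to the simulated slate reward
        clamp_min: Optional[float] = None
        clamp_max: Optional[float] = None

    @dataclass
    class SimulationParameters:
        # reward model name -> path to the serialized reward net
        reward_name_path: Dict[str, str] = field(default_factory=dict)
        # per-model weight and exponent used in the weighted reward sum
        reward_name_weight: Dict[str, float] = field(default_factory=dict)
        reward_name_power: Dict[str, float] = field(default_factory=dict)
        # optional guard-rails on reward range and slate distance
        reward_clamp: Optional[RewardClamp] = None
        distance_penalty: Optional[float] = None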